diff --git a/requirements.txt b/requirements.txt index 77e1bec..f89445c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ fastapi uvicorn whisper_ctranslate2 opencc +prometheus-fastapi-instrumentator diff --git a/whisper_fastapi.py b/whisper_fastapi.py index 16f1355..aa7039e 100644 --- a/whisper_fastapi.py +++ b/whisper_fastapi.py @@ -4,11 +4,12 @@ import hashlib import argparse import uvicorn from typing import Any -from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket +from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket, Response from fastapi.middleware.cors import CORSMiddleware from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions from src.whisper_ctranslate2.writers import format_timestamp import opencc +from prometheus_fastapi_instrumentator import Instrumentator parser = argparse.ArgumentParser() parser.add_argument("--host", default="0.0.0.0", type=str) @@ -17,6 +18,8 @@ parser.add_argument("--model", default="large-v2", type=str) parser.add_argument("--cache_dir", default=None, type=str) args = parser.parse_args() app = FastAPI() +# Instrument your app with default metrics and expose the metrics +Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics") ccc = opencc.OpenCC("t2s.json") print("Loading model...") @@ -211,7 +214,7 @@ async def translateapi( async def transcription( file: UploadFile = File(...), prompt: str = Form(""), - response_type: str = Form("json"), + response_format: str = Form("json"), ): """Transcription endpoint @@ -230,18 +233,18 @@ async def transcription( options=options, ) - if response_type == "json": - return result - elif response_type == "text": - return result["text"].strip() - elif response_type == "tsv": - return generate_tsv(result) - elif response_type == "srt": - return generate_srt(result) - elif response_type == "vtt": + if response_format == "json": + return result  # dict is JSON-encoded by FastAPI; Response(content=...) accepts only str/bytes 
+ elif response_format == "text": + return Response(content='\n'.join(s['text'] for s in result['segments']), media_type="text/plain") + elif response_format == "tsv": + return Response(content=generate_tsv(result), media_type='text/plain') + elif response_format == "srt": + return Response(content=generate_srt(result), media_type='text/plain') + elif response_format == "vtt": return generate_vtt(result) - return {"error": "Invalid response_type"} + return {"error": "Invalid response_format"} uvicorn.run(app, host=args.host, port=args.port)