add metrics, fix response type

This commit is contained in:
2023-11-07 22:06:36 +08:00
parent 046f4017d0
commit 144f95cf37
2 changed files with 16 additions and 12 deletions

View File

@@ -4,11 +4,12 @@ import hashlib
import argparse
import uvicorn
from typing import Any
from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket
from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket, Response
from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions
from src.whisper_ctranslate2.writers import format_timestamp
import opencc
from prometheus_fastapi_instrumentator import Instrumentator
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0", type=str)
@@ -17,6 +18,8 @@ parser.add_argument("--model", default="large-v2", type=str)
parser.add_argument("--cache_dir", default=None, type=str)
args = parser.parse_args()
app = FastAPI()
# Instrument your app with default metrics and expose the metrics
Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics")
ccc = opencc.OpenCC("t2s.json")
print("Loading model...")
@@ -211,7 +214,7 @@ async def translateapi(
async def transcription(
file: UploadFile = File(...),
prompt: str = Form(""),
response_type: str = Form("json"),
response_format: str = Form("json"),
):
"""Transcription endpoint
@@ -230,18 +233,18 @@ async def transcription(
options=options,
)
if response_type == "json":
return result
elif response_type == "text":
return result["text"].strip()
elif response_type == "tsv":
return generate_tsv(result)
elif response_type == "srt":
return generate_srt(result)
elif response_type == "vtt":
if response_format == "json":
return Response(content=result, media_type="application/json")
elif response_format == "text":
return Response(content='\n'.join(s['text'] for s in result['segments']), media_type="plain/text")
elif response_format == "tsv":
return Response(content=generate_tsv(result), media_type='plain_text')
elif response_format == "srt":
return Response(content=generate_srt(result), media_type='plain_text')
elif response_format == "vtt":
return generate_vtt(result)
return {"error": "Invalid response_type"}
return {"error": "Invalid response_format"}
uvicorn.run(app, host=args.host, port=args.port)