add metrics, fix response type
This commit is contained in:
@@ -3,3 +3,4 @@ fastapi
|
||||
uvicorn
|
||||
whisper_ctranslate2
|
||||
opencc
|
||||
prometheus-fastapi-instrumentator
|
||||
|
||||
@@ -4,11 +4,12 @@ import hashlib
|
||||
import argparse
|
||||
import uvicorn
|
||||
from typing import Any
|
||||
from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket
|
||||
from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions
|
||||
from src.whisper_ctranslate2.writers import format_timestamp
|
||||
import opencc
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", default="0.0.0.0", type=str)
|
||||
@@ -17,6 +18,8 @@ parser.add_argument("--model", default="large-v2", type=str)
|
||||
parser.add_argument("--cache_dir", default=None, type=str)
|
||||
args = parser.parse_args()
|
||||
app = FastAPI()
|
||||
# Instrument your app with default metrics and expose the metrics
|
||||
Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics")
|
||||
ccc = opencc.OpenCC("t2s.json")
|
||||
|
||||
print("Loading model...")
|
||||
@@ -211,7 +214,7 @@ async def translateapi(
|
||||
async def transcription(
|
||||
file: UploadFile = File(...),
|
||||
prompt: str = Form(""),
|
||||
response_type: str = Form("json"),
|
||||
response_format: str = Form("json"),
|
||||
):
|
||||
"""Transcription endpoint
|
||||
|
||||
@@ -230,18 +233,18 @@ async def transcription(
|
||||
options=options,
|
||||
)
|
||||
|
||||
if response_type == "json":
|
||||
return result
|
||||
elif response_type == "text":
|
||||
return result["text"].strip()
|
||||
elif response_type == "tsv":
|
||||
return generate_tsv(result)
|
||||
elif response_type == "srt":
|
||||
return generate_srt(result)
|
||||
elif response_type == "vtt":
|
||||
if response_format == "json":
|
||||
return Response(content=result, media_type="application/json")
|
||||
elif response_format == "text":
|
||||
return Response(content='\n'.join(s['text'] for s in result['segments']), media_type="plain/text")
|
||||
elif response_format == "tsv":
|
||||
return Response(content=generate_tsv(result), media_type='plain_text')
|
||||
elif response_format == "srt":
|
||||
return Response(content=generate_srt(result), media_type='plain_text')
|
||||
elif response_format == "vtt":
|
||||
return generate_vtt(result)
|
||||
|
||||
return {"error": "Invalid response_type"}
|
||||
return {"error": "Invalid response_format"}
|
||||
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
Reference in New Issue
Block a user