add metrics, fix response type

This commit is contained in:
2023-11-07 22:06:36 +08:00
parent 046f4017d0
commit 144f95cf37
2 changed files with 16 additions and 12 deletions

View File

@@ -3,3 +3,4 @@ fastapi
uvicorn uvicorn
whisper_ctranslate2 whisper_ctranslate2
opencc opencc
prometheus-fastapi-instrumentator

View File

@@ -4,11 +4,12 @@ import hashlib
import argparse import argparse
import uvicorn import uvicorn
from typing import Any from typing import Any
from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket, Response
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions
from src.whisper_ctranslate2.writers import format_timestamp from src.whisper_ctranslate2.writers import format_timestamp
import opencc import opencc
from prometheus_fastapi_instrumentator import Instrumentator
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0", type=str) parser.add_argument("--host", default="0.0.0.0", type=str)
@@ -17,6 +18,8 @@ parser.add_argument("--model", default="large-v2", type=str)
parser.add_argument("--cache_dir", default=None, type=str) parser.add_argument("--cache_dir", default=None, type=str)
args = parser.parse_args() args = parser.parse_args()
app = FastAPI() app = FastAPI()
# Instrument your app with default metrics and expose the metrics
Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics")
ccc = opencc.OpenCC("t2s.json") ccc = opencc.OpenCC("t2s.json")
print("Loading model...") print("Loading model...")
@@ -211,7 +214,7 @@ async def translateapi(
async def transcription( async def transcription(
file: UploadFile = File(...), file: UploadFile = File(...),
prompt: str = Form(""), prompt: str = Form(""),
response_type: str = Form("json"), response_format: str = Form("json"),
): ):
"""Transcription endpoint """Transcription endpoint
@@ -230,18 +233,18 @@ async def transcription(
options=options, options=options,
) )
if response_type == "json": if response_format == "json":
return result return Response(content=result, media_type="application/json")
elif response_type == "text": elif response_format == "text":
return result["text"].strip() return Response(content='\n'.join(s['text'] for s in result['segments']), media_type="plain/text")
elif response_type == "tsv": elif response_format == "tsv":
return generate_tsv(result) return Response(content=generate_tsv(result), media_type='plain_text')
elif response_type == "srt": elif response_format == "srt":
return generate_srt(result) return Response(content=generate_srt(result), media_type='plain_text')
elif response_type == "vtt": elif response_format == "vtt":
return generate_vtt(result) return generate_vtt(result)
return {"error": "Invalid response_type"} return {"error": "Invalid response_format"}
uvicorn.run(app, host=args.host, port=args.port) uvicorn.run(app, host=args.host, port=args.port)