diff --git a/whisper_fastapi.py b/whisper_fastapi.py
index 40f7b7e..7f22cd7 100644
--- a/whisper_fastapi.py
+++ b/whisper_fastapi.py
@@ -22,7 +22,6 @@ from fastapi.middleware.cors import CORSMiddleware
 from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
 from src.whisper_ctranslate2.writers import format_timestamp
 from faster_whisper.transcribe import Segment, TranscriptionInfo
-import opencc
 from prometheus_fastapi_instrumentator import Instrumentator
 
 parser = argparse.ArgumentParser()
@@ -35,7 +34,6 @@ args = parser.parse_args()
 app = FastAPI()
 # Instrument your app with default metrics and expose the metrics
 Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics")
-ccc = opencc.OpenCC("t2s.json")
 
 print("Loading model...")
 transcriber = Transcribe(
@@ -102,12 +100,15 @@ def vtt_writer(generator: Generator[dict[str, Any], Any, None]):
 
 
 def build_json_result(
-    generator: Generator[dict[str, Any], Any, None]
+    generator: Iterable[Segment],
+    info: TranscriptionInfo,
 ) -> dict[str, Any]:
     segments = [i for i in generator]
     return {
         "text": "\n".join(i["text"] for i in segments),
         "segments": segments,
+        "language": info.language,
+        "language_probability": info.language_probability,
     }
 
 
@@ -137,7 +138,6 @@ def stream_builder(
             last_pos = end
             data = segment._asdict()
             data["total"] = info.duration
-            data["text"] = ccc.convert(data["text"])
             yield data
 
     return wrap(), info
@@ -199,14 +199,14 @@ async def konele_ws(
     file_obj.seek(0)
 
-    generator = stream_builder(
+    generator, info = stream_builder(
         audio=file_obj,
         task=task,
         vad_filter=vad_filter,
         language=None if lang == "und" else lang,
         initial_prompt=initial_prompt,
     )
-    result = build_json_result(generator)
+    result = build_json_result(generator, info)
 
     text = result.get("text", "")
     print("result", text)
 
@@ -263,14 +263,14 @@ async def translateapi(
     file_obj.seek(0)
 
-    generator = stream_builder(
+    generator, info = stream_builder(
         audio=file_obj,
         task=task,
         vad_filter=vad_filter,
         language=None if lang == "und" else lang,
         initial_prompt=initial_prompt,
     )
-    result = build_json_result(generator)
+    result = build_json_result(generator, info)
 
     text = result.get("text", "")
     print("result", text)
 
@@ -311,7 +311,7 @@ async def transcription(
             media_type="text/event-stream",
         )
     elif response_format == "json":
-        return build_json_result(generator)
+        return build_json_result(generator, info)
     elif response_format == "text":
         return StreamingResponse(text_writer(generator), media_type="text/plain")
     elif response_format == "tsv":
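
Note (not part of the diff): a minimal sketch of the updated call pattern, assuming the rest of whisper_fastapi.py matches the context lines shown above. stream_builder now returns a (generator, TranscriptionInfo) pair, build_json_result takes both so the JSON payload can report the detected language, and the opencc-based traditional-to-simplified conversion of segment text is removed.

    # Sketch only; variable names (file_obj, task, lang, ...) follow the hunks above.
    generator, info = stream_builder(
        audio=file_obj,
        task=task,
        vad_filter=vad_filter,
        language=None if lang == "und" else lang,
        initial_prompt=initial_prompt,
    )
    result = build_json_result(generator, info)
    # result now carries, in addition to "text" and "segments":
    #   "language": info.language
    #   "language_probability": info.language_probability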