diff --git a/whisper_fastapi.py b/whisper_fastapi.py index b4d9adc..605b13d 100644 --- a/whisper_fastapi.py +++ b/whisper_fastapi.py @@ -1,3 +1,4 @@ +from faster_whisper import vad import tqdm import json from fastapi.responses import StreamingResponse @@ -7,10 +8,19 @@ import io import hashlib import argparse import uvicorn -from typing import Annotated, Any, Literal -from fastapi import File, Query, UploadFile, Form, FastAPI, Request, WebSocket, Response +from typing import Annotated, Any, BinaryIO, Literal, Generator +from fastapi import ( + File, + HTTPException, + Query, + UploadFile, + Form, + FastAPI, + Request, + WebSocket, +) from fastapi.middleware.cors import CORSMiddleware -from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions +from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe from src.whisper_ctranslate2.writers import format_timestamp import opencc from prometheus_fastapi_instrumentator import Instrumentator @@ -50,19 +60,28 @@ app.add_middleware( ) -def generate_tsv(result: dict[str, list[Any]]): - tsv = "start\tend\ttext\n" - for i, segment in enumerate(result["segments"]): +def stream_writer(generator: Generator[dict[str, Any], Any, None]): + for segment in generator: + yield "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n" + yield "data: [DONE]\n\n" + + +def text_writer(generator: Generator[dict[str, Any], Any, None]): + for segment in generator: + yield segment["text"].strip() + "\n" + + +def tsv_writer(generator: Generator[dict[str, Any], Any, None]): + yield "start\tend\ttext\n" + for i, segment in enumerate(generator): start_time = str(round(1000 * segment["start"])) end_time = str(round(1000 * segment["end"])) text = segment["text"] - tsv += f"{start_time}\t{end_time}\t{text}\n" - return tsv + yield f"{start_time}\t{end_time}\t{text}\n" -def generate_srt(result: dict[str, list[Any]]): - srt = "" - for i, segment in enumerate(result["segments"], start=1): +def srt_writer(generator: Generator[dict[str, Any], Any, None]): + for i, segment in enumerate(generator): start_time = format_timestamp( segment["start"], decimal_marker=",", always_include_hours=True ) @@ -70,48 +89,74 @@ def generate_srt(result: dict[str, list[Any]]): segment["end"], decimal_marker=",", always_include_hours=True ) text = segment["text"] - srt += f"{i}\n{start_time} --> {end_time}\n{text}\n\n" - return srt + yield f"{i}\n{start_time} --> {end_time}\n{text}\n\n" -def generate_vtt(result: dict[str, list[Any]]): - vtt = "WEBVTT\n\n" - for segment in result["segments"]: +def vtt_writer(generator: Generator[dict[str, Any], Any, None]): + yield "WEBVTT\n\n" + for i, segment in enumerate(generator): start_time = format_timestamp(segment["start"]) end_time = format_timestamp(segment["end"]) text = segment["text"] - vtt += f"{start_time} --> {end_time}\n{text}\n\n" - return vtt + yield f"{start_time} --> {end_time}\n{text}\n\n" -def get_options(*, initial_prompt=""): - options = TranscriptionOptions( +def build_json_result( + generator: Generator[dict[str, Any], Any, None] +) -> dict[str, Any]: + segments = [i for i in generator] + return { + "text": "\n".join(i["text"] for i in segments), + "segments": segments, + } + + +def stream_builder( + audio: BinaryIO, + task: str, + vad_filter: bool, + language: str | None, + initial_prompt: str = "", +): + segments, info = transcriber.model.transcribe( + audio=audio, + language=language, + task=task, beam_size=5, best_of=5, patience=1.0, - length_penalty=1.0, - log_prob_threshold=-1.0, - no_speech_threshold=0.6, - compression_ratio_threshold=2.4, - condition_on_previous_text=True, - temperature=[0.0, 1.0 + 1e-6, 0.2], - suppress_tokens=[], - word_timestamps=True, - print_colors=False, - prepend_punctuations="\"'“¿([{-", - append_punctuations="\"'.。,,!!??::”)]}、", - vad_filter=False, - vad_threshold=None, - vad_min_speech_duration_ms=None, - vad_max_speech_duration_s=None, - vad_min_silence_duration_ms=None, - initial_prompt=initial_prompt, + length_penalty=-1.0, repetition_penalty=1.0, no_repeat_ngram_size=0, + temperature=[0.0, 1.0 + 1e-6, 0.2], + compression_ratio_threshold=2.4, + log_prob_threshold=-1.0, + no_speech_threshold=0.6, + condition_on_previous_text=True, prompt_reset_on_temperature=False, + initial_prompt=initial_prompt, suppress_blank=False, + suppress_tokens=[], + word_timestamps=True, + prepend_punctuations="\"'“¿([{-", + append_punctuations="\"'.。,,!!??::”)]}、", + vad_filter=vad_filter, + vad_parameters=None, ) - return options + print( + "Detected language '%s' with probability %f" + % (info.language, info.language_probability) + ) + last_pos = 0 + with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar: + for segment in segments: + start, end, text = segment.start, segment.end, segment.text + pbar.update(end - last_pos) + last_pos = end + data = segment._asdict() + data["total"] = info.duration + data["text"] = ccc.convert(data["text"]) + yield data @app.websocket("/k6nele/status") @@ -131,6 +176,7 @@ async def konele_ws( task: Literal["transcribe", "translate"] = "transcribe", lang: str = "und", initial_prompt: str = "", + vad_filter: bool = False, content_type: Annotated[str, Query(alias="content-type")] = "audio/x-raw", ): await websocket.accept() @@ -169,18 +215,16 @@ async def konele_ws( file_obj.seek(0) - options = get_options(initial_prompt=initial_prompt) - - result = transcriber.inference( + generator = stream_builder( audio=file_obj, task=task, - language=lang if lang != "und" else None, # type: ignore - verbose=False, - live=False, - options=options, + vad_filter=vad_filter, + language=None if lang == "und" else lang, + initial_prompt=initial_prompt, ) + result = build_json_result(generator) + text = result.get("text", "") - text = ccc.convert(text) print("result", text) await websocket.send_json( @@ -201,6 +245,7 @@ async def translateapi( task: Literal["transcribe", "translate"] = "transcribe", lang: str = "und", initial_prompt: str = "", + vad_filter: bool = False, ): content_type = request.headers.get("Content-Type", "") print("downloading request file", content_type) @@ -234,18 +279,16 @@ async def translateapi( file_obj.seek(0) - options = get_options(initial_prompt=initial_prompt) - - result = transcriber.inference( + generator = stream_builder( audio=file_obj, task=task, - language=lang if lang != "und" else None, # type: ignore - verbose=False, - live=False, - options=options, + vad_filter=vad_filter, + language=None if lang == "und" else lang, + initial_prompt=initial_prompt, ) + result = build_json_result(generator) + text = result.get("text", "") - text = ccc.convert(text) print("result", text) return { @@ -270,84 +313,31 @@ async def transcription( """ # timestamp as filename, keep original extension - options = get_options(initial_prompt=prompt) + generator = stream_builder( + audio=io.BytesIO(file.file.read()), + task=task, + vad_filter=vad_filter, + language=None if language == "und" else language, + ) # special function for streaming response (OpenAI API does not have this) if response_format == "stream": - - def gen(): - segments, info = transcriber.model.transcribe( - audio=io.BytesIO(file.file.read()), - language=None if language == "und" else language, # type: ignore - task=task, - beam_size=options.beam_size, - best_of=options.best_of, - patience=options.patience, - length_penalty=options.length_penalty, - repetition_penalty=options.repetition_penalty, - no_repeat_ngram_size=options.no_repeat_ngram_size, - temperature=options.temperature, - compression_ratio_threshold=options.compression_ratio_threshold, - log_prob_threshold=options.log_prob_threshold, - no_speech_threshold=options.no_speech_threshold, - condition_on_previous_text=options.condition_on_previous_text, - prompt_reset_on_temperature=options.prompt_reset_on_temperature, - initial_prompt=options.initial_prompt, - suppress_blank=options.suppress_blank, - suppress_tokens=options.suppress_tokens, - word_timestamps=True - if options.print_colors - else options.word_timestamps, - prepend_punctuations=options.prepend_punctuations, - append_punctuations=options.append_punctuations, - vad_filter=vad_filter, - vad_parameters=None, - ) - print( - "Detected language '%s' with probability %f" - % (info.language, info.language_probability) - ) - last_pos = 0 - with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar: - for segment in segments: - start, end, text = segment.start, segment.end, segment.text - pbar.update(end - last_pos) - last_pos = end - data = segment._asdict() - data["total"] = info.duration - data["text"] = ccc.convert(data["text"]) - yield "data: " + json.dumps(data, ensure_ascii=False) + "\n\n" - yield "data: [DONE]\n\n" - return StreamingResponse( - gen(), + stream_writer(generator), media_type="text/event-stream", ) - - result: Any = transcriber.inference( - audio=io.BytesIO(file.file.read()), - task=task, - language=None if language == "und" else language, # type: ignore - verbose=False, - live=False, - options=options, - ) - - if response_format == "json": - return result + elif response_format == "json": + return build_json_result(generator) elif response_format == "text": - return Response( - content="\n".join(s["text"] for s in result["segments"]), - media_type="plain/text", - ) + return StreamingResponse(text_writer(generator), media_type="text/plain") elif response_format == "tsv": - return Response(content=generate_tsv(result), media_type="plain_text") + return StreamingResponse(tsv_writer(generator), media_type="text/plain") elif response_format == "srt": - return Response(content=generate_srt(result), media_type="plain_text") + return StreamingResponse(srt_writer(generator), media_type="text/plain") elif response_format == "vtt": - return generate_vtt(result) + return StreamingResponse(vtt_writer(generator), media_type="text/plain") - return {"error": "Invalid response_format"} + raise HTTPException(400, "Invailed response_format") uvicorn.run(app, host=args.host, port=args.port)