From cda5691715749bd0d2e7e16dbb7bec7ac4eeba41 Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Mon, 25 Dec 2023 18:26:40 +0800 Subject: [PATCH] add: openai streaming response --- whisper_fastapi.py | 62 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/whisper_fastapi.py b/whisper_fastapi.py index b3a5c0d..b4d9adc 100644 --- a/whisper_fastapi.py +++ b/whisper_fastapi.py @@ -1,3 +1,6 @@ +import tqdm +import json +from fastapi.responses import StreamingResponse import wave import pydub import io @@ -257,6 +260,9 @@ async def transcription( file: UploadFile = File(...), prompt: str = Form(""), response_format: str = Form("json"), + task: str = Form("transcribe"), + language: str = Form("und"), + vad_filter: bool = Form(False), ): """Transcription endpoint @@ -266,10 +272,62 @@ async def transcription( # timestamp as filename, keep original extension options = get_options(initial_prompt=prompt) + # special function for streaming response (OpenAI API does not have this) + if response_format == "stream": + + def gen(): + segments, info = transcriber.model.transcribe( + audio=io.BytesIO(file.file.read()), + language=None if language == "und" else language, # type: ignore + task=task, + beam_size=options.beam_size, + best_of=options.best_of, + patience=options.patience, + length_penalty=options.length_penalty, + repetition_penalty=options.repetition_penalty, + no_repeat_ngram_size=options.no_repeat_ngram_size, + temperature=options.temperature, + compression_ratio_threshold=options.compression_ratio_threshold, + log_prob_threshold=options.log_prob_threshold, + no_speech_threshold=options.no_speech_threshold, + condition_on_previous_text=options.condition_on_previous_text, + prompt_reset_on_temperature=options.prompt_reset_on_temperature, + initial_prompt=options.initial_prompt, + suppress_blank=options.suppress_blank, + suppress_tokens=options.suppress_tokens, + word_timestamps=True + if options.print_colors + else options.word_timestamps, + prepend_punctuations=options.prepend_punctuations, + append_punctuations=options.append_punctuations, + vad_filter=vad_filter, + vad_parameters=None, + ) + print( + "Detected language '%s' with probability %f" + % (info.language, info.language_probability) + ) + last_pos = 0 + with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar: + for segment in segments: + start, end, text = segment.start, segment.end, segment.text + pbar.update(end - last_pos) + last_pos = end + data = segment._asdict() + data["total"] = info.duration + data["text"] = ccc.convert(data["text"]) + yield "data: " + json.dumps(data, ensure_ascii=False) + "\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse( + gen(), + media_type="text/event-stream", + ) + result: Any = transcriber.inference( audio=io.BytesIO(file.file.read()), - task="transcribe", - language=None, # type: ignore + task=task, + language=None if language == "und" else language, # type: ignore verbose=False, live=False, options=options,