add: openai streaming response
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
import tqdm
|
||||
import json
|
||||
from fastapi.responses import StreamingResponse
|
||||
import wave
|
||||
import pydub
|
||||
import io
|
||||
@@ -257,6 +260,9 @@ async def transcription(
|
||||
file: UploadFile = File(...),
|
||||
prompt: str = Form(""),
|
||||
response_format: str = Form("json"),
|
||||
task: str = Form("transcribe"),
|
||||
language: str = Form("und"),
|
||||
vad_filter: bool = Form(False),
|
||||
):
|
||||
"""Transcription endpoint
|
||||
|
||||
@@ -266,10 +272,62 @@ async def transcription(
|
||||
# timestamp as filename, keep original extension
|
||||
options = get_options(initial_prompt=prompt)
|
||||
|
||||
# special function for streaming response (OpenAI API does not have this)
|
||||
if response_format == "stream":
|
||||
|
||||
def gen():
|
||||
segments, info = transcriber.model.transcribe(
|
||||
audio=io.BytesIO(file.file.read()),
|
||||
language=None if language == "und" else language, # type: ignore
|
||||
task=task,
|
||||
beam_size=options.beam_size,
|
||||
best_of=options.best_of,
|
||||
patience=options.patience,
|
||||
length_penalty=options.length_penalty,
|
||||
repetition_penalty=options.repetition_penalty,
|
||||
no_repeat_ngram_size=options.no_repeat_ngram_size,
|
||||
temperature=options.temperature,
|
||||
compression_ratio_threshold=options.compression_ratio_threshold,
|
||||
log_prob_threshold=options.log_prob_threshold,
|
||||
no_speech_threshold=options.no_speech_threshold,
|
||||
condition_on_previous_text=options.condition_on_previous_text,
|
||||
prompt_reset_on_temperature=options.prompt_reset_on_temperature,
|
||||
initial_prompt=options.initial_prompt,
|
||||
suppress_blank=options.suppress_blank,
|
||||
suppress_tokens=options.suppress_tokens,
|
||||
word_timestamps=True
|
||||
if options.print_colors
|
||||
else options.word_timestamps,
|
||||
prepend_punctuations=options.prepend_punctuations,
|
||||
append_punctuations=options.append_punctuations,
|
||||
vad_filter=vad_filter,
|
||||
vad_parameters=None,
|
||||
)
|
||||
print(
|
||||
"Detected language '%s' with probability %f"
|
||||
% (info.language, info.language_probability)
|
||||
)
|
||||
last_pos = 0
|
||||
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
|
||||
for segment in segments:
|
||||
start, end, text = segment.start, segment.end, segment.text
|
||||
pbar.update(end - last_pos)
|
||||
last_pos = end
|
||||
data = segment._asdict()
|
||||
data["total"] = info.duration
|
||||
data["text"] = ccc.convert(data["text"])
|
||||
yield "data: " + json.dumps(data, ensure_ascii=False) + "\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
gen(),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
result: Any = transcriber.inference(
|
||||
audio=io.BytesIO(file.file.read()),
|
||||
task="transcribe",
|
||||
language=None, # type: ignore
|
||||
task=task,
|
||||
language=None if language == "und" else language, # type: ignore
|
||||
verbose=False,
|
||||
live=False,
|
||||
options=options,
|
||||
|
||||
Reference in New Issue
Block a user