apply streaming response to all formats

2024-01-13 12:41:40 +08:00
parent cda5691715
commit 939663863b

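For context, the "stream" response format introduced here emits server-sent events: one JSON-encoded segment per "data:" line, closed by "data: [DONE]". A minimal client sketch, assuming the server listens on http://localhost:8000 and exposes an OpenAI-style /v1/audio/transcriptions route (the route itself is not shown in this diff):

import json
import requests  # any HTTP client that supports streaming bodies works

# Hypothetical host, route, and file name; adjust to your deployment.
url = "http://localhost:8000/v1/audio/transcriptions"

with open("sample.wav", "rb") as f:
    resp = requests.post(
        url,
        files={"file": f},
        data={"response_format": "stream", "task": "transcribe"},
        stream=True,  # read the SSE body incrementally instead of buffering it
    )

for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue  # blank lines separate SSE events
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    segment = json.loads(payload)
    print(segment["start"], segment["end"], segment["text"])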

@@ -1,3 +1,4 @@
from faster_whisper import vad
import tqdm
import json
from fastapi.responses import StreamingResponse
@@ -7,10 +8,19 @@ import io
import hashlib
import argparse
import uvicorn
from typing import Annotated, Any, Literal
from fastapi import File, Query, UploadFile, Form, FastAPI, Request, WebSocket, Response
from typing import Annotated, Any, BinaryIO, Literal, Generator
from fastapi import (
File,
HTTPException,
Query,
UploadFile,
Form,
FastAPI,
Request,
WebSocket,
)
from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
from src.whisper_ctranslate2.writers import format_timestamp
import opencc
from prometheus_fastapi_instrumentator import Instrumentator
@@ -50,19 +60,28 @@ app.add_middleware(
)
def generate_tsv(result: dict[str, list[Any]]):
tsv = "start\tend\ttext\n"
for i, segment in enumerate(result["segments"]):
def stream_writer(generator: Generator[dict[str, Any], Any, None]):
for segment in generator:
yield "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n"
yield "data: [DONE]\n\n"
def text_writer(generator: Generator[dict[str, Any], Any, None]):
for segment in generator:
yield segment["text"].strip() + "\n"
def tsv_writer(generator: Generator[dict[str, Any], Any, None]):
yield "start\tend\ttext\n"
for i, segment in enumerate(generator):
start_time = str(round(1000 * segment["start"]))
end_time = str(round(1000 * segment["end"]))
text = segment["text"]
tsv += f"{start_time}\t{end_time}\t{text}\n"
return tsv
yield f"{start_time}\t{end_time}\t{text}\n"
def generate_srt(result: dict[str, list[Any]]):
srt = ""
for i, segment in enumerate(result["segments"], start=1):
def srt_writer(generator: Generator[dict[str, Any], Any, None]):
for i, segment in enumerate(generator):
start_time = format_timestamp(
segment["start"], decimal_marker=",", always_include_hours=True
)
@@ -70,48 +89,74 @@ def generate_srt(result: dict[str, list[Any]]):
segment["end"], decimal_marker=",", always_include_hours=True
)
text = segment["text"]
srt += f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
return srt
yield f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
def generate_vtt(result: dict[str, list[Any]]):
vtt = "WEBVTT\n\n"
for segment in result["segments"]:
def vtt_writer(generator: Generator[dict[str, Any], Any, None]):
yield "WEBVTT\n\n"
for i, segment in enumerate(generator):
start_time = format_timestamp(segment["start"])
end_time = format_timestamp(segment["end"])
text = segment["text"]
vtt += f"{start_time} --> {end_time}\n{text}\n\n"
return vtt
yield f"{start_time} --> {end_time}\n{text}\n\n"
def get_options(*, initial_prompt=""):
options = TranscriptionOptions(
def build_json_result(
generator: Generator[dict[str, Any], Any, None]
) -> dict[str, Any]:
segments = [i for i in generator]
return {
"text": "\n".join(i["text"] for i in segments),
"segments": segments,
}
def stream_builder(
audio: BinaryIO,
task: str,
vad_filter: bool,
language: str | None,
initial_prompt: str = "",
):
segments, info = transcriber.model.transcribe(
audio=audio,
language=language,
task=task,
beam_size=5,
best_of=5,
patience=1.0,
length_penalty=1.0,
log_prob_threshold=-1.0,
no_speech_threshold=0.6,
compression_ratio_threshold=2.4,
condition_on_previous_text=True,
temperature=[0.0, 1.0 + 1e-6, 0.2],
suppress_tokens=[],
word_timestamps=True,
print_colors=False,
prepend_punctuations="\"'“¿([{-",
append_punctuations="\"'.。,!?::”)]}、",
vad_filter=False,
vad_threshold=None,
vad_min_speech_duration_ms=None,
vad_max_speech_duration_s=None,
vad_min_silence_duration_ms=None,
initial_prompt=initial_prompt,
length_penalty=-1.0,
repetition_penalty=1.0,
no_repeat_ngram_size=0,
temperature=[0.0, 1.0 + 1e-6, 0.2],
compression_ratio_threshold=2.4,
log_prob_threshold=-1.0,
no_speech_threshold=0.6,
condition_on_previous_text=True,
prompt_reset_on_temperature=False,
initial_prompt=initial_prompt,
suppress_blank=False,
suppress_tokens=[],
word_timestamps=True,
prepend_punctuations="\"'“¿([{-",
append_punctuations="\"'.。,!?::”)]}、",
vad_filter=vad_filter,
vad_parameters=None,
)
return options
print(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
last_pos = 0
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
for segment in segments:
start, end, text = segment.start, segment.end, segment.text
pbar.update(end - last_pos)
last_pos = end
data = segment._asdict()
data["total"] = info.duration
data["text"] = ccc.convert(data["text"])
yield data
@app.websocket("/k6nele/status")
@@ -131,6 +176,7 @@ async def konele_ws(
task: Literal["transcribe", "translate"] = "transcribe",
lang: str = "und",
initial_prompt: str = "",
vad_filter: bool = False,
content_type: Annotated[str, Query(alias="content-type")] = "audio/x-raw",
):
await websocket.accept()
@@ -169,18 +215,16 @@ async def konele_ws(
file_obj.seek(0)
options = get_options(initial_prompt=initial_prompt)
result = transcriber.inference(
generator = stream_builder(
audio=file_obj,
task=task,
language=lang if lang != "und" else None, # type: ignore
verbose=False,
live=False,
options=options,
vad_filter=vad_filter,
language=None if lang == "und" else lang,
initial_prompt=initial_prompt,
)
result = build_json_result(generator)
text = result.get("text", "")
text = ccc.convert(text)
print("result", text)
await websocket.send_json(
@@ -201,6 +245,7 @@ async def translateapi(
task: Literal["transcribe", "translate"] = "transcribe",
lang: str = "und",
initial_prompt: str = "",
vad_filter: bool = False,
):
content_type = request.headers.get("Content-Type", "")
print("downloading request file", content_type)
@@ -234,18 +279,16 @@ async def translateapi(
file_obj.seek(0)
options = get_options(initial_prompt=initial_prompt)
result = transcriber.inference(
generator = stream_builder(
audio=file_obj,
task=task,
language=lang if lang != "und" else None, # type: ignore
verbose=False,
live=False,
options=options,
vad_filter=vad_filter,
language=None if lang == "und" else lang,
initial_prompt=initial_prompt,
)
result = build_json_result(generator)
text = result.get("text", "")
text = ccc.convert(text)
print("result", text)
return {
@@ -270,84 +313,31 @@ async def transcription(
"""
# timestamp as filename, keep original extension
options = get_options(initial_prompt=prompt)
generator = stream_builder(
audio=io.BytesIO(file.file.read()),
task=task,
vad_filter=vad_filter,
language=None if language == "und" else language,
)
# special function for streaming response (OpenAI API does not have this)
if response_format == "stream":
def gen():
segments, info = transcriber.model.transcribe(
audio=io.BytesIO(file.file.read()),
language=None if language == "und" else language, # type: ignore
task=task,
beam_size=options.beam_size,
best_of=options.best_of,
patience=options.patience,
length_penalty=options.length_penalty,
repetition_penalty=options.repetition_penalty,
no_repeat_ngram_size=options.no_repeat_ngram_size,
temperature=options.temperature,
compression_ratio_threshold=options.compression_ratio_threshold,
log_prob_threshold=options.log_prob_threshold,
no_speech_threshold=options.no_speech_threshold,
condition_on_previous_text=options.condition_on_previous_text,
prompt_reset_on_temperature=options.prompt_reset_on_temperature,
initial_prompt=options.initial_prompt,
suppress_blank=options.suppress_blank,
suppress_tokens=options.suppress_tokens,
word_timestamps=True
if options.print_colors
else options.word_timestamps,
prepend_punctuations=options.prepend_punctuations,
append_punctuations=options.append_punctuations,
vad_filter=vad_filter,
vad_parameters=None,
)
print(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
last_pos = 0
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
for segment in segments:
start, end, text = segment.start, segment.end, segment.text
pbar.update(end - last_pos)
last_pos = end
data = segment._asdict()
data["total"] = info.duration
data["text"] = ccc.convert(data["text"])
yield "data: " + json.dumps(data, ensure_ascii=False) + "\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
gen(),
stream_writer(generator),
media_type="text/event-stream",
)
result: Any = transcriber.inference(
audio=io.BytesIO(file.file.read()),
task=task,
language=None if language == "und" else language, # type: ignore
verbose=False,
live=False,
options=options,
)
if response_format == "json":
return result
elif response_format == "json":
return build_json_result(generator)
elif response_format == "text":
return Response(
content="\n".join(s["text"] for s in result["segments"]),
media_type="plain/text",
)
return StreamingResponse(text_writer(generator), media_type="text/plain")
elif response_format == "tsv":
return Response(content=generate_tsv(result), media_type="plain_text")
return StreamingResponse(tsv_writer(generator), media_type="text/plain")
elif response_format == "srt":
return Response(content=generate_srt(result), media_type="plain_text")
return StreamingResponse(srt_writer(generator), media_type="text/plain")
elif response_format == "vtt":
return generate_vtt(result)
return StreamingResponse(vtt_writer(generator), media_type="text/plain")
return {"error": "Invalid response_format"}
raise HTTPException(400, "Invalid response_format")
uvicorn.run(app, host=args.host, port=args.port)
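The other formats follow the same pattern as the stream format: a segment generator from the model is piped through a small writer generator, and FastAPI's StreamingResponse forwards each yielded chunk as soon as it is produced instead of buffering the whole transcript. A self-contained sketch of that pattern, with a stand-in generator replacing stream_builder:

from typing import Any, Generator
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

def fake_segments() -> Generator[dict[str, Any], Any, None]:
    # Stand-in for stream_builder(); in the real app segments come from faster-whisper.
    yield {"start": 0.0, "end": 1.2, "text": " hello"}
    yield {"start": 1.2, "end": 2.4, "text": " world"}

def demo_text_writer(generator: Generator[dict[str, Any], Any, None]):
    # Same shape as text_writer above: consume segments lazily, yield formatted chunks.
    for segment in generator:
        yield segment["text"].strip() + "\n"

@app.get("/demo")
def demo():
    # Each yielded chunk is sent to the client immediately.
    return StreamingResponse(demo_text_writer(fake_segments()), media_type="text/plain")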