apply streaming response to all formats

2024-01-13 12:41:40 +08:00
parent cda5691715
commit 939663863b


@@ -1,3 +1,4 @@
+from faster_whisper import vad
 import tqdm
 import json
 from fastapi.responses import StreamingResponse
@@ -7,10 +8,19 @@ import io
 import hashlib
 import argparse
 import uvicorn
-from typing import Annotated, Any, Literal
-from fastapi import File, Query, UploadFile, Form, FastAPI, Request, WebSocket, Response
+from typing import Annotated, Any, BinaryIO, Literal, Generator
+from fastapi import (
+    File,
+    HTTPException,
+    Query,
+    UploadFile,
+    Form,
+    FastAPI,
+    Request,
+    WebSocket,
+)
 from fastapi.middleware.cors import CORSMiddleware
-from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions
+from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
 from src.whisper_ctranslate2.writers import format_timestamp
 import opencc
 from prometheus_fastapi_instrumentator import Instrumentator
@@ -50,19 +60,28 @@ app.add_middleware(
 )
 
 
-def generate_tsv(result: dict[str, list[Any]]):
-    tsv = "start\tend\ttext\n"
-    for i, segment in enumerate(result["segments"]):
+def stream_writer(generator: Generator[dict[str, Any], Any, None]):
+    for segment in generator:
+        yield "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n"
+    yield "data: [DONE]\n\n"
+
+
+def text_writer(generator: Generator[dict[str, Any], Any, None]):
+    for segment in generator:
+        yield segment["text"].strip() + "\n"
+
+
+def tsv_writer(generator: Generator[dict[str, Any], Any, None]):
+    yield "start\tend\ttext\n"
+    for i, segment in enumerate(generator):
         start_time = str(round(1000 * segment["start"]))
         end_time = str(round(1000 * segment["end"]))
         text = segment["text"]
-        tsv += f"{start_time}\t{end_time}\t{text}\n"
-    return tsv
+        yield f"{start_time}\t{end_time}\t{text}\n"
 
 
-def generate_srt(result: dict[str, list[Any]]):
-    srt = ""
-    for i, segment in enumerate(result["segments"], start=1):
+def srt_writer(generator: Generator[dict[str, Any], Any, None]):
+    for i, segment in enumerate(generator):
         start_time = format_timestamp(
             segment["start"], decimal_marker=",", always_include_hours=True
         )
@@ -70,48 +89,74 @@ def generate_srt(result: dict[str, list[Any]]):
             segment["end"], decimal_marker=",", always_include_hours=True
         )
         text = segment["text"]
-        srt += f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
-    return srt
+        yield f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
 
 
-def generate_vtt(result: dict[str, list[Any]]):
-    vtt = "WEBVTT\n\n"
-    for segment in result["segments"]:
+def vtt_writer(generator: Generator[dict[str, Any], Any, None]):
+    yield "WEBVTT\n\n"
+    for i, segment in enumerate(generator):
         start_time = format_timestamp(segment["start"])
         end_time = format_timestamp(segment["end"])
         text = segment["text"]
-        vtt += f"{start_time} --> {end_time}\n{text}\n\n"
-    return vtt
+        yield f"{start_time} --> {end_time}\n{text}\n\n"
 
 
-def get_options(*, initial_prompt=""):
-    options = TranscriptionOptions(
+def build_json_result(
+    generator: Generator[dict[str, Any], Any, None]
+) -> dict[str, Any]:
+    segments = [i for i in generator]
+    return {
+        "text": "\n".join(i["text"] for i in segments),
+        "segments": segments,
+    }
+
+
+def stream_builder(
+    audio: BinaryIO,
+    task: str,
+    vad_filter: bool,
+    language: str | None,
+    initial_prompt: str = "",
+):
+    segments, info = transcriber.model.transcribe(
+        audio=audio,
+        language=language,
+        task=task,
         beam_size=5,
         best_of=5,
         patience=1.0,
-        length_penalty=1.0,
-        log_prob_threshold=-1.0,
-        no_speech_threshold=0.6,
-        compression_ratio_threshold=2.4,
-        condition_on_previous_text=True,
-        temperature=[0.0, 1.0 + 1e-6, 0.2],
-        suppress_tokens=[],
-        word_timestamps=True,
-        print_colors=False,
-        prepend_punctuations="\"'“¿([{-",
-        append_punctuations="\"'.。,!?::”)]}、",
-        vad_filter=False,
-        vad_threshold=None,
-        vad_min_speech_duration_ms=None,
-        vad_max_speech_duration_s=None,
-        vad_min_silence_duration_ms=None,
-        initial_prompt=initial_prompt,
+        length_penalty=-1.0,
         repetition_penalty=1.0,
         no_repeat_ngram_size=0,
+        temperature=[0.0, 1.0 + 1e-6, 0.2],
+        compression_ratio_threshold=2.4,
+        log_prob_threshold=-1.0,
+        no_speech_threshold=0.6,
+        condition_on_previous_text=True,
         prompt_reset_on_temperature=False,
+        initial_prompt=initial_prompt,
         suppress_blank=False,
+        suppress_tokens=[],
+        word_timestamps=True,
+        prepend_punctuations="\"'“¿([{-",
+        append_punctuations="\"'.。,!?::”)]}、",
+        vad_filter=vad_filter,
+        vad_parameters=None,
     )
-    return options
+    print(
+        "Detected language '%s' with probability %f"
+        % (info.language, info.language_probability)
+    )
+    last_pos = 0
+    with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
+        for segment in segments:
+            start, end, text = segment.start, segment.end, segment.text
+            pbar.update(end - last_pos)
+            last_pos = end
+            data = segment._asdict()
+            data["total"] = info.duration
+            data["text"] = ccc.convert(data["text"])
+            yield data
 
 
 @app.websocket("/k6nele/status")
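
For readers skimming the diff: the new layout is a single stream_builder generator that yields one dict per transcribed segment, and the *_writer generators reshape that stream into each output format, so StreamingResponse can flush output as segments arrive. Below is a minimal, self-contained sketch of that same pattern; fake_segments and the printed demo at the end are stand-ins for illustration only, not part of the commit.

from typing import Any, Generator
import json


def fake_segments() -> Generator[dict[str, Any], Any, None]:
    # stand-in for stream_builder(): one dict per transcribed segment
    yield {"start": 0.0, "end": 1.2, "text": "hello"}
    yield {"start": 1.2, "end": 2.5, "text": "world"}


def stream_writer(generator):
    # same shape as the SSE writer added in this commit
    for segment in generator:
        yield "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n"
    yield "data: [DONE]\n\n"


def tsv_writer(generator):
    # same shape as the TSV writer added in this commit
    yield "start\tend\ttext\n"
    for segment in generator:
        yield f"{round(1000 * segment['start'])}\t{round(1000 * segment['end'])}\t{segment['text']}\n"


# each writer is lazy: nothing is formatted until the response iterates it
print("".join(stream_writer(fake_segments())))
print("".join(tsv_writer(fake_segments())))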
@@ -131,6 +176,7 @@ async def konele_ws(
     task: Literal["transcribe", "translate"] = "transcribe",
     lang: str = "und",
     initial_prompt: str = "",
+    vad_filter: bool = False,
     content_type: Annotated[str, Query(alias="content-type")] = "audio/x-raw",
 ):
     await websocket.accept()
@@ -169,18 +215,16 @@ async def konele_ws(
     file_obj.seek(0)
 
-    options = get_options(initial_prompt=initial_prompt)
-    result = transcriber.inference(
+    generator = stream_builder(
         audio=file_obj,
         task=task,
-        language=lang if lang != "und" else None,  # type: ignore
-        verbose=False,
-        live=False,
-        options=options,
+        vad_filter=vad_filter,
+        language=None if lang == "und" else lang,
+        initial_prompt=initial_prompt,
     )
+    result = build_json_result(generator)
 
     text = result.get("text", "")
-    text = ccc.convert(text)
 
     print("result", text)
 
     await websocket.send_json(
@@ -201,6 +245,7 @@ async def translateapi(
     task: Literal["transcribe", "translate"] = "transcribe",
     lang: str = "und",
     initial_prompt: str = "",
+    vad_filter: bool = False,
 ):
     content_type = request.headers.get("Content-Type", "")
     print("downloading request file", content_type)
@@ -234,18 +279,16 @@ async def translateapi(
     file_obj.seek(0)
 
-    options = get_options(initial_prompt=initial_prompt)
-    result = transcriber.inference(
+    generator = stream_builder(
         audio=file_obj,
         task=task,
-        language=lang if lang != "und" else None,  # type: ignore
-        verbose=False,
-        live=False,
-        options=options,
+        vad_filter=vad_filter,
+        language=None if lang == "und" else lang,
+        initial_prompt=initial_prompt,
     )
+    result = build_json_result(generator)
 
     text = result.get("text", "")
-    text = ccc.convert(text)
 
     print("result", text)
 
     return {
@@ -270,84 +313,31 @@ async def transcription(
     """
     # timestamp as filename, keep original extension
-    options = get_options(initial_prompt=prompt)
+    generator = stream_builder(
+        audio=io.BytesIO(file.file.read()),
+        task=task,
+        vad_filter=vad_filter,
+        language=None if language == "und" else language,
+    )
 
     # special function for streaming response (OpenAI API does not have this)
     if response_format == "stream":
-
-        def gen():
-            segments, info = transcriber.model.transcribe(
-                audio=io.BytesIO(file.file.read()),
-                language=None if language == "und" else language,  # type: ignore
-                task=task,
-                beam_size=options.beam_size,
-                best_of=options.best_of,
-                patience=options.patience,
-                length_penalty=options.length_penalty,
-                repetition_penalty=options.repetition_penalty,
-                no_repeat_ngram_size=options.no_repeat_ngram_size,
-                temperature=options.temperature,
-                compression_ratio_threshold=options.compression_ratio_threshold,
-                log_prob_threshold=options.log_prob_threshold,
-                no_speech_threshold=options.no_speech_threshold,
-                condition_on_previous_text=options.condition_on_previous_text,
-                prompt_reset_on_temperature=options.prompt_reset_on_temperature,
-                initial_prompt=options.initial_prompt,
-                suppress_blank=options.suppress_blank,
-                suppress_tokens=options.suppress_tokens,
-                word_timestamps=True
-                if options.print_colors
-                else options.word_timestamps,
-                prepend_punctuations=options.prepend_punctuations,
-                append_punctuations=options.append_punctuations,
-                vad_filter=vad_filter,
-                vad_parameters=None,
-            )
-            print(
-                "Detected language '%s' with probability %f"
-                % (info.language, info.language_probability)
-            )
-            last_pos = 0
-            with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
-                for segment in segments:
-                    start, end, text = segment.start, segment.end, segment.text
-                    pbar.update(end - last_pos)
-                    last_pos = end
-                    data = segment._asdict()
-                    data["total"] = info.duration
-                    data["text"] = ccc.convert(data["text"])
-                    yield "data: " + json.dumps(data, ensure_ascii=False) + "\n\n"
-            yield "data: [DONE]\n\n"
-
         return StreamingResponse(
-            gen(),
+            stream_writer(generator),
             media_type="text/event-stream",
         )
-
-    result: Any = transcriber.inference(
-        audio=io.BytesIO(file.file.read()),
-        task=task,
-        language=None if language == "und" else language,  # type: ignore
-        verbose=False,
-        live=False,
-        options=options,
-    )
-    if response_format == "json":
-        return result
+    elif response_format == "json":
+        return build_json_result(generator)
     elif response_format == "text":
-        return Response(
-            content="\n".join(s["text"] for s in result["segments"]),
-            media_type="plain/text",
-        )
+        return StreamingResponse(text_writer(generator), media_type="text/plain")
     elif response_format == "tsv":
-        return Response(content=generate_tsv(result), media_type="plain_text")
+        return StreamingResponse(tsv_writer(generator), media_type="text/plain")
     elif response_format == "srt":
-        return Response(content=generate_srt(result), media_type="plain_text")
+        return StreamingResponse(srt_writer(generator), media_type="text/plain")
     elif response_format == "vtt":
-        return generate_vtt(result)
-    return {"error": "Invalid response_format"}
+        return StreamingResponse(vtt_writer(generator), media_type="text/plain")
+    raise HTTPException(400, "Invalid response_format")
 
 
 uvicorn.run(app, host=args.host, port=args.port)
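
For completeness, a rough client-side sketch of consuming the new "stream" response format. The host, port, and the /v1/audio/transcriptions path are assumptions (the route decorator is outside this diff), and the form field names are taken from the handler's parameters; adjust all of them to the actual deployment.

import json
import requests

# NOTE: base URL and endpoint path are guesses; only "file" and
# "response_format" are visible in this diff.
with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": f},
        data={"response_format": "stream"},
        stream=True,
    )

# read the SSE stream line by line; each "data:" line is one segment dict
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    segment = json.loads(payload)
    print(segment["start"], segment["end"], segment["text"])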