apply streaming response to all formats
@@ -1,3 +1,4 @@
+from faster_whisper import vad
 import tqdm
 import json
 from fastapi.responses import StreamingResponse
@@ -7,10 +8,19 @@ import io
 import hashlib
 import argparse
 import uvicorn
-from typing import Annotated, Any, Literal
-from fastapi import File, Query, UploadFile, Form, FastAPI, Request, WebSocket, Response
+from typing import Annotated, Any, BinaryIO, Literal, Generator
+from fastapi import (
+    File,
+    HTTPException,
+    Query,
+    UploadFile,
+    Form,
+    FastAPI,
+    Request,
+    WebSocket,
+)
 from fastapi.middleware.cors import CORSMiddleware
-from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions
+from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
 from src.whisper_ctranslate2.writers import format_timestamp
 import opencc
 from prometheus_fastapi_instrumentator import Instrumentator
@@ -50,19 +60,28 @@ app.add_middleware(
 )
 
 
-def generate_tsv(result: dict[str, list[Any]]):
-    tsv = "start\tend\ttext\n"
-    for i, segment in enumerate(result["segments"]):
+def stream_writer(generator: Generator[dict[str, Any], Any, None]):
+    for segment in generator:
+        yield "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n"
+    yield "data: [DONE]\n\n"
+
+
+def text_writer(generator: Generator[dict[str, Any], Any, None]):
+    for segment in generator:
+        yield segment["text"].strip() + "\n"
+
+
+def tsv_writer(generator: Generator[dict[str, Any], Any, None]):
+    yield "start\tend\ttext\n"
+    for i, segment in enumerate(generator):
         start_time = str(round(1000 * segment["start"]))
         end_time = str(round(1000 * segment["end"]))
         text = segment["text"]
-        tsv += f"{start_time}\t{end_time}\t{text}\n"
-    return tsv
+        yield f"{start_time}\t{end_time}\t{text}\n"
 
 
-def generate_srt(result: dict[str, list[Any]]):
-    srt = ""
-    for i, segment in enumerate(result["segments"], start=1):
+def srt_writer(generator: Generator[dict[str, Any], Any, None]):
+    for i, segment in enumerate(generator):
         start_time = format_timestamp(
             segment["start"], decimal_marker=",", always_include_hours=True
         )
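The writers above all follow the same pattern: instead of building one large string and returning it, each one yields chunks that FastAPI's StreamingResponse can flush to the client as they are produced. A minimal, self-contained sketch of that pattern follows; it is not this project's code, and the /demo route and fake segment values are invented for illustration.

# Sketch of the generator + StreamingResponse pattern used by the writers above.
# Assumptions: the /demo route and FAKE_SEGMENTS are made up for this example.
from typing import Any, Generator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

FAKE_SEGMENTS = [
    {"start": 0.0, "end": 1.2, "text": "hello"},
    {"start": 1.2, "end": 2.5, "text": "world"},
]


def demo_tsv_writer(segments: list[dict[str, Any]]) -> Generator[str, None, None]:
    # Header first, then one row per segment, each yielded as its own chunk.
    yield "start\tend\ttext\n"
    for segment in segments:
        yield f"{round(1000 * segment['start'])}\t{round(1000 * segment['end'])}\t{segment['text']}\n"


@app.get("/demo")
def demo() -> StreamingResponse:
    # StreamingResponse drains the generator lazily, so chunks reach the
    # client before the whole body has been rendered.
    return StreamingResponse(demo_tsv_writer(FAKE_SEGMENTS), media_type="text/plain")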
@@ -70,48 +89,74 @@ def generate_srt(result: dict[str, list[Any]]):
             segment["end"], decimal_marker=",", always_include_hours=True
         )
         text = segment["text"]
-        srt += f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
-    return srt
+        yield f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
 
 
-def generate_vtt(result: dict[str, list[Any]]):
-    vtt = "WEBVTT\n\n"
-    for segment in result["segments"]:
+def vtt_writer(generator: Generator[dict[str, Any], Any, None]):
+    yield "WEBVTT\n\n"
+    for i, segment in enumerate(generator):
         start_time = format_timestamp(segment["start"])
         end_time = format_timestamp(segment["end"])
         text = segment["text"]
-        vtt += f"{start_time} --> {end_time}\n{text}\n\n"
-    return vtt
+        yield f"{start_time} --> {end_time}\n{text}\n\n"
 
 
-def get_options(*, initial_prompt=""):
-    options = TranscriptionOptions(
+def build_json_result(
+    generator: Generator[dict[str, Any], Any, None]
+) -> dict[str, Any]:
+    segments = [i for i in generator]
+    return {
+        "text": "\n".join(i["text"] for i in segments),
+        "segments": segments,
+    }
+
+
+def stream_builder(
+    audio: BinaryIO,
+    task: str,
+    vad_filter: bool,
+    language: str | None,
+    initial_prompt: str = "",
+):
+    segments, info = transcriber.model.transcribe(
+        audio=audio,
+        language=language,
+        task=task,
         beam_size=5,
         best_of=5,
         patience=1.0,
-        length_penalty=1.0,
-        log_prob_threshold=-1.0,
-        no_speech_threshold=0.6,
-        compression_ratio_threshold=2.4,
-        condition_on_previous_text=True,
-        temperature=[0.0, 1.0 + 1e-6, 0.2],
-        suppress_tokens=[],
-        word_timestamps=True,
-        print_colors=False,
-        prepend_punctuations="\"'“¿([{-",
-        append_punctuations="\"'.。,,!!??::”)]}、",
-        vad_filter=False,
-        vad_threshold=None,
-        vad_min_speech_duration_ms=None,
-        vad_max_speech_duration_s=None,
-        vad_min_silence_duration_ms=None,
-        initial_prompt=initial_prompt,
+        length_penalty=-1.0,
         repetition_penalty=1.0,
         no_repeat_ngram_size=0,
+        temperature=[0.0, 1.0 + 1e-6, 0.2],
+        compression_ratio_threshold=2.4,
+        log_prob_threshold=-1.0,
+        no_speech_threshold=0.6,
+        condition_on_previous_text=True,
         prompt_reset_on_temperature=False,
+        initial_prompt=initial_prompt,
         suppress_blank=False,
+        suppress_tokens=[],
+        word_timestamps=True,
+        prepend_punctuations="\"'“¿([{-",
+        append_punctuations="\"'.。,,!!??::”)]}、",
+        vad_filter=vad_filter,
+        vad_parameters=None,
     )
-    return options
+    print(
+        "Detected language '%s' with probability %f"
+        % (info.language, info.language_probability)
+    )
+    last_pos = 0
+    with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
+        for segment in segments:
+            start, end, text = segment.start, segment.end, segment.text
+            pbar.update(end - last_pos)
+            last_pos = end
+            data = segment._asdict()
+            data["total"] = info.duration
+            data["text"] = ccc.convert(data["text"])
+            yield data
 
 
 @app.websocket("/k6nele/status")
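stream_builder yields one plain dict per decoded segment (the segment's _asdict() plus a "total" duration field), and stream_writer frames each dict as a server-sent event. A small sketch of what one such event looks like on the wire; the segment values below are made up, and real dicts carry every field of the underlying segment tuple.

# Illustration only: invented values, same framing as stream_writer above.
import json

segment = {"start": 0.0, "end": 2.34, "text": " hello world", "total": 17.5}

sse_event = "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n"
print(sse_event, end="")
# data: {"start": 0.0, "end": 2.34, "text": " hello world", "total": 17.5}
# (a final "data: [DONE]" event marks the end of the stream)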
@@ -131,6 +176,7 @@ async def konele_ws(
     task: Literal["transcribe", "translate"] = "transcribe",
     lang: str = "und",
     initial_prompt: str = "",
+    vad_filter: bool = False,
     content_type: Annotated[str, Query(alias="content-type")] = "audio/x-raw",
 ):
     await websocket.accept()
@@ -169,18 +215,16 @@ async def konele_ws(
 
     file_obj.seek(0)
 
-    options = get_options(initial_prompt=initial_prompt)
-
-    result = transcriber.inference(
+    generator = stream_builder(
         audio=file_obj,
         task=task,
-        language=lang if lang != "und" else None,  # type: ignore
-        verbose=False,
-        live=False,
-        options=options,
+        vad_filter=vad_filter,
+        language=None if lang == "und" else lang,
+        initial_prompt=initial_prompt,
     )
+    result = build_json_result(generator)
 
     text = result.get("text", "")
-    text = ccc.convert(text)
     print("result", text)
+
     await websocket.send_json(
@@ -201,6 +245,7 @@ async def translateapi(
     task: Literal["transcribe", "translate"] = "transcribe",
     lang: str = "und",
     initial_prompt: str = "",
+    vad_filter: bool = False,
 ):
     content_type = request.headers.get("Content-Type", "")
     print("downloading request file", content_type)
@@ -234,18 +279,16 @@ async def translateapi(
 
     file_obj.seek(0)
 
-    options = get_options(initial_prompt=initial_prompt)
-
-    result = transcriber.inference(
+    generator = stream_builder(
         audio=file_obj,
         task=task,
-        language=lang if lang != "und" else None,  # type: ignore
-        verbose=False,
-        live=False,
-        options=options,
+        vad_filter=vad_filter,
+        language=None if lang == "und" else lang,
+        initial_prompt=initial_prompt,
     )
+    result = build_json_result(generator)
 
     text = result.get("text", "")
-    text = ccc.convert(text)
     print("result", text)
+
     return {
@@ -270,84 +313,31 @@ async def transcription(
     """
 
     # timestamp as filename, keep original extension
-    options = get_options(initial_prompt=prompt)
+    generator = stream_builder(
+        audio=io.BytesIO(file.file.read()),
+        task=task,
+        vad_filter=vad_filter,
+        language=None if language == "und" else language,
+    )
 
     # special function for streaming response (OpenAI API does not have this)
     if response_format == "stream":
-
-        def gen():
-            segments, info = transcriber.model.transcribe(
-                audio=io.BytesIO(file.file.read()),
-                language=None if language == "und" else language,  # type: ignore
-                task=task,
-                beam_size=options.beam_size,
-                best_of=options.best_of,
-                patience=options.patience,
-                length_penalty=options.length_penalty,
-                repetition_penalty=options.repetition_penalty,
-                no_repeat_ngram_size=options.no_repeat_ngram_size,
-                temperature=options.temperature,
-                compression_ratio_threshold=options.compression_ratio_threshold,
-                log_prob_threshold=options.log_prob_threshold,
-                no_speech_threshold=options.no_speech_threshold,
-                condition_on_previous_text=options.condition_on_previous_text,
-                prompt_reset_on_temperature=options.prompt_reset_on_temperature,
-                initial_prompt=options.initial_prompt,
-                suppress_blank=options.suppress_blank,
-                suppress_tokens=options.suppress_tokens,
-                word_timestamps=True
-                if options.print_colors
-                else options.word_timestamps,
-                prepend_punctuations=options.prepend_punctuations,
-                append_punctuations=options.append_punctuations,
-                vad_filter=vad_filter,
-                vad_parameters=None,
-            )
-            print(
-                "Detected language '%s' with probability %f"
-                % (info.language, info.language_probability)
-            )
-            last_pos = 0
-            with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
-                for segment in segments:
-                    start, end, text = segment.start, segment.end, segment.text
-                    pbar.update(end - last_pos)
-                    last_pos = end
-                    data = segment._asdict()
-                    data["total"] = info.duration
-                    data["text"] = ccc.convert(data["text"])
-                    yield "data: " + json.dumps(data, ensure_ascii=False) + "\n\n"
-            yield "data: [DONE]\n\n"
-
         return StreamingResponse(
-            gen(),
+            stream_writer(generator),
             media_type="text/event-stream",
         )
-
-    result: Any = transcriber.inference(
-        audio=io.BytesIO(file.file.read()),
-        task=task,
-        language=None if language == "und" else language,  # type: ignore
-        verbose=False,
-        live=False,
-        options=options,
-    )
-
-    if response_format == "json":
-        return result
+    elif response_format == "json":
+        return build_json_result(generator)
     elif response_format == "text":
-        return Response(
-            content="\n".join(s["text"] for s in result["segments"]),
-            media_type="plain/text",
-        )
+        return StreamingResponse(text_writer(generator), media_type="text/plain")
    elif response_format == "tsv":
-        return Response(content=generate_tsv(result), media_type="plain_text")
+        return StreamingResponse(tsv_writer(generator), media_type="text/plain")
     elif response_format == "srt":
-        return Response(content=generate_srt(result), media_type="plain_text")
+        return StreamingResponse(srt_writer(generator), media_type="text/plain")
     elif response_format == "vtt":
-        return generate_vtt(result)
+        return StreamingResponse(vtt_writer(generator), media_type="text/plain")
 
-    return {"error": "Invalid response_format"}
+    raise HTTPException(400, "Invalid response_format")
 
 
 uvicorn.run(app, host=args.host, port=args.port)
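For completeness, a hedged example of consuming the new stream format from Python. The endpoint URL below is an assumption (the route decorators sit outside this diff), and the form field names simply mirror the handler parameters visible above (file, task, response_format); the SSE framing itself, "data: {...}" events ending with "data: [DONE]", comes from stream_writer.

# Assumption: the transcription endpoint is reachable at this URL; adjust the
# path, port, and field binding to match the actual route, which this diff omits.
import json

import requests

URL = "http://localhost:8000/v1/audio/transcriptions"  # hypothetical

with open("sample.wav", "rb") as f:
    resp = requests.post(
        URL,
        files={"file": ("sample.wav", f, "audio/wav")},
        data={"response_format": "stream", "task": "transcribe"},
        stream=True,  # let requests hand us the body chunk by chunk
    )

for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue  # skip the blank separator lines between SSE events
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    segment = json.loads(payload)
    print(segment["start"], segment["end"], segment["text"])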