upgrade with faster-whisper upstream

This commit is contained in:
2024-11-15 01:04:49 +08:00
parent 8ae81a124d
commit 4a5ba38f5e
4 changed files with 61 additions and 71 deletions

View File

@@ -1,3 +1,5 @@
import dataclasses
import faster_whisper
import tqdm
import json
from fastapi.responses import StreamingResponse
@@ -7,7 +9,7 @@ import io
import hashlib
import argparse
import uvicorn
from typing import Annotated, Any, BinaryIO, Literal, Generator, Tuple, Iterable
from typing import Annotated, Any, BinaryIO, Literal, Generator, Tuple, Iterable, Union
from fastapi import (
File,
HTTPException,
@@ -40,16 +42,13 @@ Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics")
ccc = opencc.OpenCC("t2s.json")
print(f"Loading model to device {args.device}...")
transcriber = Transcribe(
model_path=args.model,
model = faster_whisper.WhisperModel(
model_size_or_path=args.model,
device=args.device,
device_index=0,
compute_type="default",
threads=args.threads,
cache_directory=args.cache_dir,
cpu_threads=args.threads,
local_files_only=args.local_files_only,
)
print(f"Model loaded to device {transcriber.model.model.device}")
print(f"Model loaded to device {model.model.device}")
# allow all cors
@@ -62,56 +61,62 @@ app.add_middleware(
)
def stream_writer(generator: Generator[dict[str, Any], Any, None]):
def stream_writer(generator: Generator[Segment, Any, None]):
for segment in generator:
yield "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n"
yield "data: [DONE]\n\n"
def text_writer(generator: Generator[dict[str, Any], Any, None]):
def text_writer(generator: Generator[Segment, Any, None]):
for segment in generator:
yield segment["text"].strip() + "\n"
yield segment.text.strip() + "\n"
def tsv_writer(generator: Generator[dict[str, Any], Any, None]):
def tsv_writer(generator: Generator[Segment, Any, None]):
yield "start\tend\ttext\n"
for i, segment in enumerate(generator):
start_time = str(round(1000 * segment["start"]))
end_time = str(round(1000 * segment["end"]))
text = segment["text"].strip()
start_time = str(round(1000 * segment.start))
end_time = str(round(1000 * segment.end))
text = segment.text.strip()
yield f"{start_time}\t{end_time}\t{text}\n"
def srt_writer(generator: Generator[dict[str, Any], Any, None]):
def srt_writer(generator: Generator[Segment, Any, None]):
for i, segment in enumerate(generator):
start_time = format_timestamp(
segment["start"], decimal_marker=",", always_include_hours=True
segment.start, decimal_marker=",", always_include_hours=True
)
end_time = format_timestamp(
segment["end"], decimal_marker=",", always_include_hours=True
segment.end, decimal_marker=",", always_include_hours=True
)
text = segment["text"].strip()
text = segment.text.strip()
yield f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
def vtt_writer(generator: Generator[dict[str, Any], Any, None]):
def vtt_writer(generator: Generator[Segment, Any, None]):
yield "WEBVTT\n\n"
for i, segment in enumerate(generator):
start_time = format_timestamp(segment["start"])
end_time = format_timestamp(segment["end"])
text = segment["text"].strip()
start_time = format_timestamp(segment.start)
end_time = format_timestamp(segment.end)
text = segment.text.strip()
yield f"{start_time} --> {end_time}\n{text}\n\n"
@dataclasses.dataclass
class JsonResult(TranscriptionInfo):
segments: list[Segment]
text: str
def build_json_result(
generator: Iterable[dict],
info: dict,
) -> dict[str, Any]:
generator: Iterable[Segment],
info: TranscriptionInfo,
) -> JsonResult:
segments = [i for i in generator]
return info | {
"text": "\n".join(i["text"] for i in segments),
"segments": segments,
}
return JsonResult(
text="\n".join(i.text for i in segments),
segments=segments,
**dataclasses.asdict(info)
)
def stream_builder(
@@ -121,8 +126,8 @@ def stream_builder(
language: str | None,
initial_prompt: str = "",
repetition_penalty: float = 1.0,
) -> Tuple[Generator[dict, None, None], dict]:
segments, info = transcriber.model.transcribe(
) -> Tuple[Generator[Segment, None, None], TranscriptionInfo]:
segments, info = model.transcribe(
audio=audio,
language=language,
task=task,
@@ -142,20 +147,9 @@ def stream_builder(
start, end, text = segment.start, segment.end, segment.text
pbar.update(end - last_pos)
last_pos = end
data = segment._asdict()
if data.get('words') is not None:
data["words"] = [i._asdict() for i in data["words"]]
if info.language == "zh":
data["text"] = ccc.convert(data["text"])
yield data
yield segment
info_dict = info._asdict()
if info_dict['transcription_options'] is not None:
info_dict['transcription_options'] = info_dict['transcription_options']._asdict()
if info_dict['vad_options'] is not None:
info_dict['vad_options'] = info_dict['vad_options']._asdict()
return wrap(), info_dict
return wrap(), info
@app.websocket("/k6nele/status")
@@ -223,13 +217,11 @@ async def konele_ws(
)
result = build_json_result(generator, info)
text = result.get("text", "")
await websocket.send_json(
{
"status": 0,
"segment": 0,
"result": {"hypotheses": [{"transcript": text}], "final": True},
"result": {"hypotheses": [{"transcript": result.text}], "final": True},
"id": md5,
}
)
@@ -286,17 +278,15 @@ async def translateapi(
)
result = build_json_result(generator, info)
text = result.get("text", "")
return {
"status": 0,
"hypotheses": [{"utterance": text}],
"hypotheses": [{"utterance": result.text}],
"id": md5,
}
@app.post("/v1/audio/transcriptions")
@app.post("/v1/audio/translations")
@app.post("/v1/audio/transcriptions", response_model=Union[JsonResult, str])
@app.post("/v1/audio/translations", response_model=Union[JsonResult, str])
async def transcription(
request: Request,
file: UploadFile = File(...),