From bc5fdec819a6909997055b7fa6b6ce8d9da2f8f1 Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Wed, 10 Jul 2024 18:49:09 +0800 Subject: [PATCH] expose transcript info from stream_builder --- whisper_fastapi.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/whisper_fastapi.py b/whisper_fastapi.py index e21ba92..40f7b7e 100644 --- a/whisper_fastapi.py +++ b/whisper_fastapi.py @@ -7,7 +7,7 @@ import io import hashlib import argparse import uvicorn -from typing import Annotated, Any, BinaryIO, Literal, Generator +from typing import Annotated, Any, BinaryIO, Literal, Generator, Tuple, Iterable from fastapi import ( File, HTTPException, @@ -21,6 +21,7 @@ from fastapi import ( from fastapi.middleware.cors import CORSMiddleware from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe from src.whisper_ctranslate2.writers import format_timestamp +from faster_whisper.transcribe import Segment, TranscriptionInfo import opencc from prometheus_fastapi_instrumentator import Instrumentator @@ -116,7 +117,7 @@ def stream_builder( vad_filter: bool, language: str | None, initial_prompt: str = "", -): +) -> Tuple[Iterable[Segment], TranscriptionInfo]: segments, info = transcriber.model.transcribe( audio=audio, language=language, @@ -127,16 +128,19 @@ def stream_builder( "Detected language '%s' with probability %f" % (info.language, info.language_probability) ) - last_pos = 0 - with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar: - for segment in segments: - start, end, text = segment.start, segment.end, segment.text - pbar.update(end - last_pos) - last_pos = end - data = segment._asdict() - data["total"] = info.duration - data["text"] = ccc.convert(data["text"]) - yield data + def wrap(): + last_pos = 0 + with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar: + for segment in segments: + start, end, text = segment.start, segment.end, segment.text + pbar.update(end - last_pos) + last_pos = end + data = segment._asdict() + data["total"] = info.duration + data["text"] = ccc.convert(data["text"]) + yield data + + return wrap(), info @app.websocket("/k6nele/status") @@ -293,7 +297,7 @@ async def transcription( """ # timestamp as filename, keep original extension - generator = stream_builder( + generator, info = stream_builder( audio=io.BytesIO(file.file.read()), task=task, vad_filter=vad_filter,