Expose TranscriptionInfo from stream_builder by returning it alongside the segment generator

This commit is contained in:
2024-07-10 18:49:09 +08:00
parent 47f9e7e873
commit bc5fdec819

View File

@@ -7,7 +7,7 @@ import io
import hashlib
import argparse
import uvicorn
from typing import Annotated, Any, BinaryIO, Literal, Generator
from typing import Annotated, Any, BinaryIO, Literal, Generator, Tuple, Iterable
from fastapi import (
File,
HTTPException,
@@ -21,6 +21,7 @@ from fastapi import (
from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
from src.whisper_ctranslate2.writers import format_timestamp
from faster_whisper.transcribe import Segment, TranscriptionInfo
import opencc
from prometheus_fastapi_instrumentator import Instrumentator
@@ -116,7 +117,7 @@ def stream_builder(
vad_filter: bool,
language: str | None,
initial_prompt: str = "",
):
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
segments, info = transcriber.model.transcribe(
audio=audio,
language=language,
@@ -127,16 +128,19 @@ def stream_builder(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
last_pos = 0
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
for segment in segments:
start, end, text = segment.start, segment.end, segment.text
pbar.update(end - last_pos)
last_pos = end
data = segment._asdict()
data["total"] = info.duration
data["text"] = ccc.convert(data["text"])
yield data
def wrap():
    """Lazily yield one dict per transcribed segment.

    Each yielded dict is ``segment._asdict()`` with two changes: a
    ``"total"`` key holding ``info.duration`` is added, and ``"text"`` is
    replaced by the result of ``ccc.convert`` on the original text.

    ``segments``, ``info`` and ``ccc`` are not parameters; they are read
    from the surrounding scope. Because this is a generator, nothing is
    pulled from ``segments`` until the caller starts iterating.
    """
    last_pos = 0
    # The bar is created with disable=True, so update() has no visible
    # effect; the bookkeeping is preserved as in the original.
    with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
        for segment in segments:
            # Advance by the span between the previous segment's end and
            # this segment's end.
            end = segment.end
            pbar.update(end - last_pos)
            last_pos = end
            data = segment._asdict()
            data["total"] = info.duration
            data["text"] = ccc.convert(data["text"])
            yield data
return wrap(), info
@app.websocket("/k6nele/status")
@@ -293,7 +297,7 @@ async def transcription(
"""
# timestamp as filename, keep original extension
generator = stream_builder(
generator, info = stream_builder(
audio=io.BytesIO(file.file.read()),
task=task,
vad_filter=vad_filter,