expose transcript info from stream_builder

This commit is contained in:
2024-07-10 18:49:09 +08:00
parent 47f9e7e873
commit bc5fdec819

View File

@@ -7,7 +7,7 @@ import io
import hashlib import hashlib
import argparse import argparse
import uvicorn import uvicorn
from typing import Annotated, Any, BinaryIO, Literal, Generator from typing import Annotated, Any, BinaryIO, Literal, Generator, Tuple, Iterable
from fastapi import ( from fastapi import (
File, File,
HTTPException, HTTPException,
@@ -21,6 +21,7 @@ from fastapi import (
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
from src.whisper_ctranslate2.writers import format_timestamp from src.whisper_ctranslate2.writers import format_timestamp
from faster_whisper.transcribe import Segment, TranscriptionInfo
import opencc import opencc
from prometheus_fastapi_instrumentator import Instrumentator from prometheus_fastapi_instrumentator import Instrumentator
@@ -116,7 +117,7 @@ def stream_builder(
vad_filter: bool, vad_filter: bool,
language: str | None, language: str | None,
initial_prompt: str = "", initial_prompt: str = "",
): ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
segments, info = transcriber.model.transcribe( segments, info = transcriber.model.transcribe(
audio=audio, audio=audio,
language=language, language=language,
@@ -127,16 +128,19 @@ def stream_builder(
"Detected language '%s' with probability %f" "Detected language '%s' with probability %f"
% (info.language, info.language_probability) % (info.language, info.language_probability)
) )
last_pos = 0 def wrap():
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar: last_pos = 0
for segment in segments: with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
start, end, text = segment.start, segment.end, segment.text for segment in segments:
pbar.update(end - last_pos) start, end, text = segment.start, segment.end, segment.text
last_pos = end pbar.update(end - last_pos)
data = segment._asdict() last_pos = end
data["total"] = info.duration data = segment._asdict()
data["text"] = ccc.convert(data["text"]) data["total"] = info.duration
yield data data["text"] = ccc.convert(data["text"])
yield data
return wrap(), info
@app.websocket("/k6nele/status") @app.websocket("/k6nele/status")
@@ -293,7 +297,7 @@ async def transcription(
""" """
# timestamp as filename, keep original extension # timestamp as filename, keep original extension
generator = stream_builder( generator, info = stream_builder(
audio=io.BytesIO(file.file.read()), audio=io.BytesIO(file.file.read()),
task=task, task=task,
vad_filter=vad_filter, vad_filter=vad_filter,