expose transcript info from stream_builder
This commit is contained in:
@@ -7,7 +7,7 @@ import io
|
||||
import hashlib
|
||||
import argparse
|
||||
import uvicorn
|
||||
from typing import Annotated, Any, BinaryIO, Literal, Generator
|
||||
from typing import Annotated, Any, BinaryIO, Literal, Generator, Tuple, Iterable
|
||||
from fastapi import (
|
||||
File,
|
||||
HTTPException,
|
||||
@@ -21,6 +21,7 @@ from fastapi import (
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
|
||||
from src.whisper_ctranslate2.writers import format_timestamp
|
||||
from faster_whisper.transcribe import Segment, TranscriptionInfo
|
||||
import opencc
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
@@ -116,7 +117,7 @@ def stream_builder(
|
||||
vad_filter: bool,
|
||||
language: str | None,
|
||||
initial_prompt: str = "",
|
||||
):
|
||||
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
|
||||
segments, info = transcriber.model.transcribe(
|
||||
audio=audio,
|
||||
language=language,
|
||||
@@ -127,16 +128,19 @@ def stream_builder(
|
||||
"Detected language '%s' with probability %f"
|
||||
% (info.language, info.language_probability)
|
||||
)
|
||||
last_pos = 0
|
||||
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
|
||||
for segment in segments:
|
||||
start, end, text = segment.start, segment.end, segment.text
|
||||
pbar.update(end - last_pos)
|
||||
last_pos = end
|
||||
data = segment._asdict()
|
||||
data["total"] = info.duration
|
||||
data["text"] = ccc.convert(data["text"])
|
||||
yield data
|
||||
def wrap():
|
||||
last_pos = 0
|
||||
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
|
||||
for segment in segments:
|
||||
start, end, text = segment.start, segment.end, segment.text
|
||||
pbar.update(end - last_pos)
|
||||
last_pos = end
|
||||
data = segment._asdict()
|
||||
data["total"] = info.duration
|
||||
data["text"] = ccc.convert(data["text"])
|
||||
yield data
|
||||
|
||||
return wrap(), info
|
||||
|
||||
|
||||
@app.websocket("/k6nele/status")
|
||||
@@ -293,7 +297,7 @@ async def transcription(
|
||||
"""
|
||||
|
||||
# timestamp as filename, keep original extension
|
||||
generator = stream_builder(
|
||||
generator, info = stream_builder(
|
||||
audio=io.BytesIO(file.file.read()),
|
||||
task=task,
|
||||
vad_filter=vad_filter,
|
||||
|
||||
Reference in New Issue
Block a user