expose transcript info from stream_builder
This commit is contained in:
@@ -7,7 +7,7 @@ import io
|
|||||||
import hashlib
|
import hashlib
|
||||||
import argparse
|
import argparse
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from typing import Annotated, Any, BinaryIO, Literal, Generator
|
from typing import Annotated, Any, BinaryIO, Literal, Generator, Tuple, Iterable
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
File,
|
File,
|
||||||
HTTPException,
|
HTTPException,
|
||||||
@@ -21,6 +21,7 @@ from fastapi import (
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
|
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
|
||||||
from src.whisper_ctranslate2.writers import format_timestamp
|
from src.whisper_ctranslate2.writers import format_timestamp
|
||||||
|
from faster_whisper.transcribe import Segment, TranscriptionInfo
|
||||||
import opencc
|
import opencc
|
||||||
from prometheus_fastapi_instrumentator import Instrumentator
|
from prometheus_fastapi_instrumentator import Instrumentator
|
||||||
|
|
||||||
@@ -116,7 +117,7 @@ def stream_builder(
|
|||||||
vad_filter: bool,
|
vad_filter: bool,
|
||||||
language: str | None,
|
language: str | None,
|
||||||
initial_prompt: str = "",
|
initial_prompt: str = "",
|
||||||
):
|
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
|
||||||
segments, info = transcriber.model.transcribe(
|
segments, info = transcriber.model.transcribe(
|
||||||
audio=audio,
|
audio=audio,
|
||||||
language=language,
|
language=language,
|
||||||
@@ -127,16 +128,19 @@ def stream_builder(
|
|||||||
"Detected language '%s' with probability %f"
|
"Detected language '%s' with probability %f"
|
||||||
% (info.language, info.language_probability)
|
% (info.language, info.language_probability)
|
||||||
)
|
)
|
||||||
last_pos = 0
|
def wrap():
|
||||||
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
|
last_pos = 0
|
||||||
for segment in segments:
|
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
|
||||||
start, end, text = segment.start, segment.end, segment.text
|
for segment in segments:
|
||||||
pbar.update(end - last_pos)
|
start, end, text = segment.start, segment.end, segment.text
|
||||||
last_pos = end
|
pbar.update(end - last_pos)
|
||||||
data = segment._asdict()
|
last_pos = end
|
||||||
data["total"] = info.duration
|
data = segment._asdict()
|
||||||
data["text"] = ccc.convert(data["text"])
|
data["total"] = info.duration
|
||||||
yield data
|
data["text"] = ccc.convert(data["text"])
|
||||||
|
yield data
|
||||||
|
|
||||||
|
return wrap(), info
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/k6nele/status")
|
@app.websocket("/k6nele/status")
|
||||||
@@ -293,7 +297,7 @@ async def transcription(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# timestamp as filename, keep original extension
|
# timestamp as filename, keep original extension
|
||||||
generator = stream_builder(
|
generator, info = stream_builder(
|
||||||
audio=io.BytesIO(file.file.read()),
|
audio=io.BytesIO(file.file.read()),
|
||||||
task=task,
|
task=task,
|
||||||
vad_filter=vad_filter,
|
vad_filter=vad_filter,
|
||||||
|
|||||||
Reference in New Issue
Block a user