Compare commits

10 Commits

2 changed files with 74 additions and 78 deletions

View File

@@ -1,48 +1,48 @@
-annotated-types==0.6.0
+annotated-types==0.7.0
-anyio==3.7.1
+anyio==4.4.0
-av==10.0.0
+av==12.2.0
-certifi==2023.7.22
+certifi==2024.7.4
 cffi==1.16.0
 charset-normalizer==3.3.2
 click==8.1.7
 coloredlogs==15.0.1
-ctranslate2==3.21.0
+ctranslate2==4.3.1
-fastapi==0.104.1
+fastapi==0.111.0
-faster-whisper==0.9.0
+faster-whisper==1.0.3
-filelock==3.13.1
+filelock==3.15.4
-flatbuffers==23.5.26
+flatbuffers==24.3.25
-fsspec==2023.10.0
+fsspec==2024.6.1
 h11==0.14.0
 httptools==0.6.1
-huggingface-hub==0.17.3
+huggingface-hub==0.23.4
 humanfriendly==10.0
-idna==3.4
+idna==3.7
 mpmath==1.3.0
-numpy==1.26.2
+numpy==1.26.4
-onnxruntime==1.16.2
+onnxruntime==1.18.1
 OpenCC==1.1.7
-packaging==23.2
+packaging==24.1
 prometheus-client==0.18.0
-prometheus-fastapi-instrumentator==6.1.0
+prometheus-fastapi-instrumentator==7.0.0
-protobuf==4.25.0
+protobuf==5.27.2
-pycparser==2.21
+pycparser==2.22
-pydantic==2.5.0
+pydantic==2.8.2
-pydantic_core==2.14.1
+pydantic_core==2.20.1
 pydub==0.25.1
-python-dotenv==1.0.0
+python-dotenv==1.0.1
-python-multipart==0.0.6
+python-multipart==0.0.9
 PyYAML==6.0.1
-requests==2.31.0
+requests==2.32.3
-sniffio==1.3.0
+sniffio==1.3.1
-sounddevice==0.4.6
+sounddevice==0.4.7
-starlette==0.27.0
+starlette==0.37.2
-sympy==1.12
+sympy==1.12.1
-tokenizers==0.14.1
+tokenizers==0.19.1
-tqdm==4.66.1
+tqdm==4.66.4
-typing_extensions==4.8.0
+typing_extensions==4.12.2
-urllib3==2.1.0
+urllib3==2.2.2
-uvicorn==0.24.0.post1
+uvicorn==0.30.1
 uvloop==0.19.0
-watchfiles==0.21.0
+watchfiles==0.22.0
 websockets==12.0
-whisper-ctranslate2==0.3.2
+whisper-ctranslate2==0.4.5
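
A quick sanity check for the dependency bumps above (a sketch, not part of the change): it verifies a few installed versions against the new pins using only importlib.metadata.

from importlib.metadata import version

# Spot-check a few of the upgraded pins from the requirements diff above.
expected = {
    "faster-whisper": "1.0.3",
    "ctranslate2": "4.3.1",
    "fastapi": "0.111.0",
    "whisper-ctranslate2": "0.4.5",
}
for name, pin in expected.items():
    installed = version(name)
    print(f"{name}: installed {installed}, pinned {pin}, match={installed == pin}")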

View File

@@ -1,4 +1,3 @@
-from faster_whisper import vad
 import tqdm
 import json
 from fastapi.responses import StreamingResponse
@@ -8,7 +7,7 @@ import io
 import hashlib
 import argparse
 import uvicorn
-from typing import Annotated, Any, BinaryIO, Literal, Generator
+from typing import Annotated, Any, BinaryIO, Literal, Generator, Tuple, Iterable
 from fastapi import (
     File,
     HTTPException,
@@ -22,20 +21,21 @@ from fastapi import (
 from fastapi.middleware.cors import CORSMiddleware
 from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
 from src.whisper_ctranslate2.writers import format_timestamp
-import opencc
+from faster_whisper.transcribe import Segment, TranscriptionInfo
 from prometheus_fastapi_instrumentator import Instrumentator
 parser = argparse.ArgumentParser()
 parser.add_argument("--host", default="0.0.0.0", type=str)
 parser.add_argument("--port", default=5000, type=int)
-parser.add_argument("--model", default="large-v2", type=str)
+parser.add_argument("--model", default="large-v3", type=str)
 parser.add_argument("--device", default="auto", type=str)
 parser.add_argument("--cache_dir", default=None, type=str)
+parser.add_argument("--local_files_only", default=False, type=bool)
+parser.add_argument("--threads", default=4, type=int)
 args = parser.parse_args()
 app = FastAPI()
 # Instrument your app with default metrics and expose the metrics
 Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics")
-ccc = opencc.OpenCC("t2s.json")
 print("Loading model...")
 transcriber = Transcribe(
@@ -43,9 +43,9 @@ transcriber = Transcribe(
     device=args.device,
     device_index=0,
     compute_type="default",
-    threads=1,
+    threads=args.threads,
     cache_directory=args.cache_dir,
-    local_files_only=False,
+    local_files_only=args.local_files_only,
 )
 print("Model loaded!")
@@ -102,10 +102,11 @@ def vtt_writer(generator: Generator[dict[str, Any], Any, None]):
 def build_json_result(
-    generator: Generator[dict[str, Any], Any, None]
+    generator: Iterable[Segment],
+    info: dict,
 ) -> dict[str, Any]:
     segments = [i for i in generator]
-    return {
+    return info | {
         "text": "\n".join(i["text"] for i in segments),
         "segments": segments,
     }
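
The rewritten build_json_result leans on the dict union operator: `return info | {...}` overlays the joined text and the segment list onto the transcription-info dict. A minimal illustration of that merge, with placeholder values rather than real transcriber output:

# Sketch of the `info | {...}` merge (PEP 584 dict union); values are placeholders.
info = {"language": "en", "language_probability": 0.97, "duration": 4.2}
segments = [{"text": "hello"}, {"text": "world"}]
result = info | {
    "text": "\n".join(s["text"] for s in segments),
    "segments": segments,
}
print(result)  # the info keys plus "text" and "segments"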
@@ -117,46 +118,39 @@ def stream_builder(
     vad_filter: bool,
     language: str | None,
     initial_prompt: str = "",
-):
+    repetition_penalty: float = 1.0,
+) -> Tuple[Iterable[dict], dict]:
     segments, info = transcriber.model.transcribe(
         audio=audio,
         language=language,
         task=task,
-        beam_size=5,
-        best_of=5,
-        patience=1.0,
-        length_penalty=-1.0,
-        repetition_penalty=1.0,
-        no_repeat_ngram_size=0,
-        temperature=[0.0, 1.0 + 1e-6, 0.2],
-        compression_ratio_threshold=2.4,
-        log_prob_threshold=-1.0,
-        no_speech_threshold=0.6,
-        condition_on_previous_text=True,
-        prompt_reset_on_temperature=False,
         initial_prompt=initial_prompt,
-        suppress_blank=False,
-        suppress_tokens=[],
         word_timestamps=True,
-        prepend_punctuations="\"'“¿([{-",
-        append_punctuations="\"'.。,!?::”)]}、",
-        vad_filter=vad_filter,
-        vad_parameters=None,
+        repetition_penalty=repetition_penalty,
     )
     print(
         "Detected language '%s' with probability %f"
         % (info.language, info.language_probability)
     )
-    last_pos = 0
-    with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
-        for segment in segments:
-            start, end, text = segment.start, segment.end, segment.text
-            pbar.update(end - last_pos)
-            last_pos = end
-            data = segment._asdict()
-            data["total"] = info.duration
-            data["text"] = ccc.convert(data["text"])
-            yield data
+    def wrap():
+        last_pos = 0
+        with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
+            for segment in segments:
+                start, end, text = segment.start, segment.end, segment.text
+                pbar.update(end - last_pos)
+                last_pos = end
+                data = segment._asdict()
+                if data.get('words') is not None:
+                    data["words"] = [i._asdict() for i in data["words"]]
+                yield data
+    info_dict = info._asdict()
+    if info_dict['transcription_options'] is not None:
+        info_dict['transcription_options'] = info_dict['transcription_options']._asdict()
+    if info_dict['vad_options'] is not None:
+        info_dict['vad_options'] = info_dict['vad_options']._asdict()
+    return wrap(), info_dict
 @app.websocket("/k6nele/status")
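
stream_builder now returns a pair: the inner wrap() generator yielding JSON-safe segment dicts, and the TranscriptionInfo flattened to a plain dict. A hedged usage sketch of that contract (the file name and penalty value are illustrative only):

# Consume the (generator, info) pair the same way the endpoints below do.
with open("sample.wav", "rb") as audio_file:  # hypothetical input file
    segments_gen, info_dict = stream_builder(
        audio=audio_file,
        task="transcribe",
        vad_filter=False,
        language=None,  # autodetect
        repetition_penalty=1.0,
    )
    result = build_json_result(segments_gen, info_dict)  # consume while the file is open
print(result["language"], result["text"])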
@@ -215,14 +209,14 @@ async def konele_ws(
     file_obj.seek(0)
-    generator = stream_builder(
+    generator, info = stream_builder(
         audio=file_obj,
         task=task,
         vad_filter=vad_filter,
         language=None if lang == "und" else lang,
         initial_prompt=initial_prompt,
     )
-    result = build_json_result(generator)
+    result = build_json_result(generator, info)
     text = result.get("text", "")
     print("result", text)
@@ -279,14 +273,14 @@ async def translateapi(
     file_obj.seek(0)
-    generator = stream_builder(
+    generator, info = stream_builder(
         audio=file_obj,
         task=task,
         vad_filter=vad_filter,
         language=None if lang == "und" else lang,
         initial_prompt=initial_prompt,
     )
-    result = build_json_result(generator)
+    result = build_json_result(generator, info)
     text = result.get("text", "")
     print("result", text)
@@ -306,6 +300,7 @@ async def transcription(
     task: str = Form("transcribe"),
     language: str = Form("und"),
     vad_filter: bool = Form(False),
+    repetition_penalty: float = Form(1.0),
 ):
     """Transcription endpoint
@@ -313,11 +308,12 @@ async def transcription(
""" """
# timestamp as filename, keep original extension # timestamp as filename, keep original extension
generator = stream_builder( generator, info = stream_builder(
audio=io.BytesIO(file.file.read()), audio=io.BytesIO(file.file.read()),
task=task, task=task,
vad_filter=vad_filter, vad_filter=vad_filter,
language=None if language == "und" else language, language=None if language == "und" else language,
repetition_penalty=repetition_penalty,
) )
# special function for streaming response (OpenAI API does not have this) # special function for streaming response (OpenAI API does not have this)
@@ -327,7 +323,7 @@ async def transcription(
             media_type="text/event-stream",
         )
     elif response_format == "json":
-        return build_json_result(generator)
+        return build_json_result(generator, info)
     elif response_format == "text":
         return StreamingResponse(text_writer(generator), media_type="text/plain")
     elif response_format == "tsv":
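
For the updated transcription endpoint, the new repetition_penalty form field rides along with the existing ones. A hedged client sketch using requests (already pinned in the requirements above); the route path and port are assumptions based on the defaults, while the form fields mirror the parameters visible in this diff:

# Hypothetical client call exercising the new repetition_penalty field.
import requests

with open("sample.wav", "rb") as f:  # placeholder audio file
    resp = requests.post(
        "http://localhost:5000/v1/audio/transcriptions",  # assumed route
        files={"file": f},
        data={
            "task": "transcribe",
            "language": "und",
            "vad_filter": "false",
            "response_format": "json",
            "repetition_penalty": "1.1",
        },
        timeout=600,
    )
resp.raise_for_status()
print(resp.json().get("text", ""))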