Compare commits

...

23 Commits

SHA1 Message Date
72a8c736e3 revert faster-whisper to v1.0.3 2024-09-12 01:29:01 +08:00
b4fb0f217b strip text on tsv and srt output 2024-09-04 18:03:01 +08:00
1a5dbc65e0 update faster-whisper to heimoshuiyu(prompt) patched version 2024-09-04 18:02:44 +08:00
ea8fc74ed2 add start-podman.sh 2024-09-04 17:45:59 +08:00
c6948654a4 add .dockerignore 2024-08-08 18:13:42 +08:00
ffefb2f09e Add new WebSocket and POST endpoints 2024-08-08 18:12:13 +08:00
1c8a685e9e fix: typo 2024-08-08 17:47:31 +08:00
ed1e51fefa add docker 2024-08-08 17:24:48 +08:00
042800721d Update model loading message 2024-08-08 17:20:52 +08:00
f71ef945db Remove print statements and unnecessary code 2024-08-04 16:59:16 +08:00
1c93201250 fix: build_json_result 2024-08-04 16:57:22 +08:00
2ecdc4e607 Fix text conversion for Chinese language 2024-08-04 16:51:17 +08:00
204ccb8f3d Revert "disable zh-hans to zh-cn convertion" (This reverts commit 6decabefae.) 2024-08-04 16:39:20 +08:00
d86ed9be69 Refactor build_json_result and stream_builder functions 2024-07-11 22:51:12 +08:00
e8ae8bf9c5 Add repetition_penalty parameter to transcription endpoint 2024-07-11 21:48:46 +08:00
931b578899 Add repetition_penalty parameter to stream_builder function 2024-07-11 21:45:38 +08:00
4ed1c695fe fix "words" in json response 2024-07-11 21:14:07 +08:00
8b766d0ce1 Update model and add new arguments 2024-07-10 22:14:47 +08:00
b22fea55ac upgrade dependencies 2024-07-10 20:51:04 +08:00
6decabefae disable zh-hans to zh-cn convertion 2024-07-10 19:03:16 +08:00
bc5fdec819 expose transcript info from stream_builder 2024-07-10 18:49:09 +08:00
47f9e7e873 delete transcribe params to use faster-whisper's default options 2024-07-10 18:32:37 +08:00
f7b5e8dc69 type: import vad 2024-01-14 01:30:19 +08:00
7 changed files with 144 additions and 103 deletions

1 .dockerignore Normal file

@@ -0,0 +1 @@
/venv

19 Dockerfile Normal file

@@ -0,0 +1,19 @@
FROM docker.io/nvidia/cuda:12.0.0-cudnn8-runtime-ubuntu22.04
RUN apt-get update && \
apt-get install -y ffmpeg python3 python3-pip git && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 5000
# Start whisper_fastapi.py
ENTRYPOINT ["python3", "whisper_fastapi.py"]
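
The image above can be built locally before using the start scripts; a minimal sketch, assuming the same tag that start-docker.sh pulls:

docker build -t docker.io/heimoshuiyu/whisper-fastapi:latest .  # tag assumed from start-docker.sh below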

@@ -4,4 +4,5 @@ uvicorn[standard]
whisper_ctranslate2
opencc
prometheus-fastapi-instrumentator
git+https://github.com/heimoshuiyu/faster-whisper@prompt
pydub
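
The patched dependency added above can also be installed directly into a local environment; a sketch, assuming pip and git are available:

pip3 install 'git+https://github.com/heimoshuiyu/faster-whisper@prompt'  # same URL and branch as the requirements line above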

@@ -1,48 +1,49 @@
-annotated-types==0.6.0
-anyio==3.7.1
-av==10.0.0
-certifi==2023.7.22
-cffi==1.16.0
-charset-normalizer==3.3.2
-click==8.1.7
-coloredlogs==15.0.1
-ctranslate2==3.21.0
-fastapi==0.104.1
-faster-whisper==0.9.0
-filelock==3.13.1
-flatbuffers==23.5.26
-fsspec==2023.10.0
-h11==0.14.0
-httptools==0.6.1
-huggingface-hub==0.17.3
-humanfriendly==10.0
-idna==3.4
-mpmath==1.3.0
-numpy==1.26.2
-onnxruntime==1.16.2
-OpenCC==1.1.7
-packaging==23.2
-prometheus-client==0.18.0
-prometheus-fastapi-instrumentator==6.1.0
-protobuf==4.25.0
-pycparser==2.21
-pydantic==2.5.0
-pydantic_core==2.14.1
-pydub==0.25.1
-python-dotenv==1.0.0
-python-multipart==0.0.6
-PyYAML==6.0.1
-requests==2.31.0
-sniffio==1.3.0
-sounddevice==0.4.6
-starlette==0.27.0
-sympy==1.12
-tokenizers==0.14.1
-tqdm==4.66.1
-typing_extensions==4.8.0
-urllib3==2.1.0
-uvicorn==0.24.0.post1
-uvloop==0.19.0
-watchfiles==0.21.0
-websockets==12.0
-whisper-ctranslate2==0.3.2
+annotated-types==0.7.0
+anyio==4.4.0
+av==12.3.0
+certifi==2024.8.30
+cffi==1.17.1
+charset-normalizer==3.3.2
+click==8.1.7
+coloredlogs==15.0.1
+ctranslate2==4.4.0
+exceptiongroup==1.2.2
+fastapi==0.114.1
+faster-whisper @ git+https://github.com/heimoshuiyu/faster-whisper@28a4d11a736d8cdeb4655ee5d7e4b4e7ae5ec8e0
+filelock==3.16.0
+flatbuffers==24.3.25
+fsspec==2024.9.0
+h11==0.14.0
+httptools==0.6.1
+huggingface-hub==0.24.6
+humanfriendly==10.0
+idna==3.8
+mpmath==1.3.0
+numpy==2.1.1
+onnxruntime==1.19.2
+OpenCC==1.1.9
+packaging==24.1
+prometheus-fastapi-instrumentator==7.0.0
+prometheus_client==0.20.0
+protobuf==5.28.0
+pycparser==2.22
+pydantic==2.9.1
+pydantic_core==2.23.3
+pydub==0.25.1
+python-dotenv==1.0.1
+python-multipart==0.0.9
+PyYAML==6.0.2
+requests==2.32.3
+sniffio==1.3.1
+sounddevice==0.5.0
+starlette==0.38.5
+sympy==1.13.2
+tokenizers==0.20.0
+tqdm==4.66.5
+typing_extensions==4.12.2
+urllib3==2.2.2
+uvicorn==0.30.6
+uvloop==0.20.0
+watchfiles==0.24.0
+websockets==13.0.1
+whisper-ctranslate2==0.4.5

10 start-docker.sh Executable file

@@ -0,0 +1,10 @@
#!/bin/bash
docker run -d --name whisper-fastapi \
--restart unless-stopped \
--name whisper-fastapi \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--gpus all \
-p 5000:5000 \
docker.io/heimoshuiyu/whisper-fastapi:latest \
--model large-v2
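
Once the container is running, the Prometheus endpoint registered in whisper_fastapi.py gives a quick liveness check; a sketch, assuming the port mapping above:

curl http://localhost:5000/konele/metrics  # endpoint exposed via Instrumentator in whisper_fastapi.py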

11 start-podman.sh Executable file

@@ -0,0 +1,11 @@
#!/bin/bash
podman run -d --name whisper-fastapi \
--restart unless-stopped \
--name whisper-fastapi \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--device nvidia.com/gpu=all --security-opt=label=disable \
--gpus all \
-p 5000:5000 \
docker.io/heimoshuiyu/whisper-fastapi:latest \
--model large-v2

whisper_fastapi.py

@@ -1,4 +1,3 @@
from faster_whisper import vad
import tqdm
import json
from fastapi.responses import StreamingResponse
@@ -8,7 +7,7 @@ import io
import hashlib
import argparse
import uvicorn
from typing import Annotated, Any, BinaryIO, Literal, Generator
from typing import Annotated, Any, BinaryIO, Literal, Generator, Tuple, Iterable
from fastapi import (
File,
HTTPException,
@@ -22,32 +21,35 @@ from fastapi import (
from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
from src.whisper_ctranslate2.writers import format_timestamp
from faster_whisper.transcribe import Segment, TranscriptionInfo
import opencc
from prometheus_fastapi_instrumentator import Instrumentator
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0", type=str)
parser.add_argument("--port", default=5000, type=int)
parser.add_argument("--model", default="large-v2", type=str)
parser.add_argument("--model", default="large-v3", type=str)
parser.add_argument("--device", default="auto", type=str)
parser.add_argument("--cache_dir", default=None, type=str)
parser.add_argument("--local_files_only", default=False, type=bool)
parser.add_argument("--threads", default=4, type=int)
args = parser.parse_args()
app = FastAPI()
# Instrument your app with default metrics and expose the metrics
Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics")
ccc = opencc.OpenCC("t2s.json")
print("Loading model...")
print(f"Loading model to device {args.device}...")
transcriber = Transcribe(
model_path=args.model,
device=args.device,
device_index=0,
compute_type="default",
threads=1,
threads=args.threads,
cache_directory=args.cache_dir,
local_files_only=False,
local_files_only=args.local_files_only,
)
print("Model loaded!")
print(f"Model loaded to device {transcriber.model.model.device}")
# allow all cors
@@ -76,7 +78,7 @@ def tsv_writer(generator: Generator[dict[str, Any], Any, None]):
for i, segment in enumerate(generator):
start_time = str(round(1000 * segment["start"]))
end_time = str(round(1000 * segment["end"]))
text = segment["text"]
text = segment["text"].strip()
yield f"{start_time}\t{end_time}\t{text}\n"
@@ -88,7 +90,7 @@ def srt_writer(generator: Generator[dict[str, Any], Any, None]):
end_time = format_timestamp(
segment["end"], decimal_marker=",", always_include_hours=True
)
text = segment["text"]
text = segment["text"].strip()
yield f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
@@ -97,15 +99,16 @@ def vtt_writer(generator: Generator[dict[str, Any], Any, None]):
for i, segment in enumerate(generator):
start_time = format_timestamp(segment["start"])
end_time = format_timestamp(segment["end"])
text = segment["text"]
text = segment["text"].strip()
yield f"{start_time} --> {end_time}\n{text}\n\n"
def build_json_result(
generator: Generator[dict[str, Any], Any, None]
generator: Iterable[Segment],
info: dict,
) -> dict[str, Any]:
segments = [i for i in generator]
return {
return info | {
"text": "\n".join(i["text"] for i in segments),
"segments": segments,
}
@@ -117,50 +120,47 @@ def stream_builder(
vad_filter: bool,
language: str | None,
initial_prompt: str = "",
):
repetition_penalty: float = 1.0,
) -> Tuple[Iterable[dict], dict]:
segments, info = transcriber.model.transcribe(
audio=audio,
language=language,
task=task,
beam_size=5,
best_of=5,
patience=1.0,
length_penalty=-1.0,
repetition_penalty=1.0,
no_repeat_ngram_size=0,
temperature=[0.0, 1.0 + 1e-6, 0.2],
compression_ratio_threshold=2.4,
log_prob_threshold=-1.0,
no_speech_threshold=0.6,
condition_on_previous_text=True,
prompt_reset_on_temperature=False,
initial_prompt=initial_prompt,
suppress_blank=False,
suppress_tokens=[],
word_timestamps=True,
prepend_punctuations="\"'“¿([{-",
append_punctuations="\"'.。,!?::”)]}、",
vad_filter=vad_filter,
vad_parameters=None,
repetition_penalty=repetition_penalty,
)
print(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
last_pos = 0
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
for segment in segments:
start, end, text = segment.start, segment.end, segment.text
pbar.update(end - last_pos)
last_pos = end
data = segment._asdict()
data["total"] = info.duration
data["text"] = ccc.convert(data["text"])
yield data
def wrap():
last_pos = 0
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
for segment in segments:
start, end, text = segment.start, segment.end, segment.text
pbar.update(end - last_pos)
last_pos = end
data = segment._asdict()
if data.get('words') is not None:
data["words"] = [i._asdict() for i in data["words"]]
if info.language == "zh":
data["text"] = ccc.convert(data["text"])
yield data
info_dict = info._asdict()
if info_dict['transcription_options'] is not None:
info_dict['transcription_options'] = info_dict['transcription_options']._asdict()
if info_dict['vad_options'] is not None:
info_dict['vad_options'] = info_dict['vad_options']._asdict()
return wrap(), info_dict
@app.websocket("/k6nele/status")
@app.websocket("/konele/status")
@app.websocket("/v1/k6nele/status")
@app.websocket("/v1/konele/status")
async def konele_status(
websocket: WebSocket,
):
@@ -171,6 +171,8 @@ async def konele_status(
@app.websocket("/k6nele/ws")
@app.websocket("/konele/ws")
@app.websocket("/v1/k6nele/ws")
@app.websocket("/v1/konele/ws")
async def konele_ws(
websocket: WebSocket,
task: Literal["transcribe", "translate"] = "transcribe",
@@ -184,15 +186,11 @@ async def konele_ws(
# convert lang code format (eg. en-US to en)
lang = lang.split("-")[0]
print("WebSocket client connected, lang is", lang)
print("content-type is", content_type)
data = b""
while True:
try:
data += await websocket.receive_bytes()
print("Received data:", len(data), data[-10:])
if data[-3:] == b"EOS":
print("End of speech")
break
except:
break
@@ -215,17 +213,16 @@ async def konele_ws(
file_obj.seek(0)
generator = stream_builder(
generator, info = stream_builder(
audio=file_obj,
task=task,
vad_filter=vad_filter,
language=None if lang == "und" else lang,
initial_prompt=initial_prompt,
)
result = build_json_result(generator)
result = build_json_result(generator, info)
text = result.get("text", "")
print("result", text)
await websocket.send_json(
{
@@ -240,6 +237,8 @@ async def konele_ws(
@app.post("/k6nele/post")
@app.post("/konele/post")
@app.post("/v1/k6nele/post")
@app.post("/v1/konele/post")
async def translateapi(
request: Request,
task: Literal["transcribe", "translate"] = "transcribe",
@@ -248,14 +247,12 @@ async def translateapi(
vad_filter: bool = False,
):
content_type = request.headers.get("Content-Type", "")
print("downloading request file", content_type)
# convert lang code format (eg. en-US to en)
lang = lang.split("-")[0]
splited = [i.strip() for i in content_type.split(",") if "=" in i]
info = {k: v for k, v in (i.split("=") for i in splited)}
print(info)
channels = int(info.get("channels", "1"))
rate = int(info.get("rate", "16000"))
@@ -279,17 +276,16 @@ async def translateapi(
file_obj.seek(0)
generator = stream_builder(
generator, info = stream_builder(
audio=file_obj,
task=task,
vad_filter=vad_filter,
language=None if lang == "und" else lang,
initial_prompt=initial_prompt,
)
result = build_json_result(generator)
result = build_json_result(generator, info)
text = result.get("text", "")
print("result", text)
return {
"status": 0,
@@ -306,6 +302,7 @@ async def transcription(
task: str = Form("transcribe"),
language: str = Form("und"),
vad_filter: bool = Form(False),
repetition_penalty: float = Form(1.0),
):
"""Transcription endpoint
@@ -313,11 +310,12 @@ async def transcription(
"""
# timestamp as filename, keep original extension
generator = stream_builder(
generator, info = stream_builder(
audio=io.BytesIO(file.file.read()),
task=task,
vad_filter=vad_filter,
language=None if language == "und" else language,
repetition_penalty=repetition_penalty,
)
# special function for streaming response (OpenAI API does not have this)
@@ -327,7 +325,7 @@ async def transcription(
media_type="text/event-stream",
)
elif response_format == "json":
return build_json_result(generator)
return build_json_result(generator, info)
elif response_format == "text":
return StreamingResponse(text_writer(generator), media_type="text/plain")
elif response_format == "tsv":
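
A hedged sketch of exercising the new konele POST route added in this change, assuming the server from the start scripts is listening on port 5000 and the request body is raw 16 kHz mono audio (the channels/rate values are parsed out of the Content-Type header as shown above; parameter names follow the handler):

curl -X POST 'http://localhost:5000/v1/konele/post?task=transcribe&lang=en-US&vad_filter=false' \
  -H 'Content-Type: audio/x-raw, channels=1, rate=16000' \
  --data-binary @speech.raw  # speech.raw is a placeholder file name

The new repetition_penalty form field on the transcription endpoint can be supplied the same way, e.g. -F file=@speech.wav -F repetition_penalty=1.2 -F response_format=json; the route path for that endpoint lies outside the hunks shown here.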