Compare commits

...

23 Commits

Author SHA1 Message Date
72a8c736e3 revert faster-whisper to v1.0.3 2024-09-12 01:29:01 +08:00
b4fb0f217b strip text on tsv and srt output 2024-09-04 18:03:01 +08:00
1a5dbc65e0 update faster-whisper to heimoshuiyu(prompt) patched version 2024-09-04 18:02:44 +08:00
ea8fc74ed2 add start-podman.sh 2024-09-04 17:45:59 +08:00
c6948654a4 add .dockerignore 2024-08-08 18:13:42 +08:00
ffefb2f09e Add new WebSocket and POST endpoints 2024-08-08 18:12:13 +08:00
1c8a685e9e fix: typo 2024-08-08 17:47:31 +08:00
ed1e51fefa add docker 2024-08-08 17:24:48 +08:00
042800721d Update model loading message 2024-08-08 17:20:52 +08:00
f71ef945db Remove print statements and unnecessary code 2024-08-04 16:59:16 +08:00
1c93201250 fix: build_json_result 2024-08-04 16:57:22 +08:00
2ecdc4e607 Fix text conversion for Chinese language 2024-08-04 16:51:17 +08:00
204ccb8f3d Revert "disable zh-hans to zh-cn convertion"
This reverts commit 6decabefae.
2024-08-04 16:39:20 +08:00
d86ed9be69 Refactor build_json_result and stream_builder functions 2024-07-11 22:51:12 +08:00
e8ae8bf9c5 Add repetition_penalty parameter to transcription endpoint 2024-07-11 21:48:46 +08:00
931b578899 Add repetition_penalty parameter to stream_builder function 2024-07-11 21:45:38 +08:00
4ed1c695fe fix "words" in json response 2024-07-11 21:14:07 +08:00
8b766d0ce1 Update model and add new arguments 2024-07-10 22:14:47 +08:00
b22fea55ac upgrade dependencies 2024-07-10 20:51:04 +08:00
6decabefae disable zh-hans to zh-cn convertion 2024-07-10 19:03:16 +08:00
bc5fdec819 expose transcript info from stream_builder 2024-07-10 18:49:09 +08:00
47f9e7e873 delete transcribe params to use faster-whisper's default options 2024-07-10 18:32:37 +08:00
f7b5e8dc69 type: import vad 2024-01-14 01:30:19 +08:00
7 changed files with 144 additions and 103 deletions

1
.dockerignore Normal file
View File

@@ -0,0 +1 @@
/venv

19
Dockerfile Normal file
View File

@@ -0,0 +1,19 @@
# GPU runtime image: CUDA 12 + cuDNN 8 on Ubuntu 22.04 (required by ctranslate2/faster-whisper)
FROM docker.io/nvidia/cuda:12.0.0-cudnn8-runtime-ubuntu22.04
# Install ffmpeg (audio decoding), Python 3 + pip, and git (for VCS pip requirements)
# in a single layer, then clean apt caches to keep the image small
RUN apt-get update && \
apt-get install -y ffmpeg python3 python3-pip git && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
WORKDIR /app
# Copy requirements first so the dependency layer is cached across source-only changes
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt
COPY . .
# FastAPI/uvicorn server port
EXPOSE 5000
# Launch whisper_fastapi.py; extra `docker run` args are appended as its CLI flags
ENTRYPOINT ["python3", "whisper_fastapi.py"]

View File

@@ -4,4 +4,5 @@ uvicorn[standard]
whisper_ctranslate2 whisper_ctranslate2
opencc opencc
prometheus-fastapi-instrumentator prometheus-fastapi-instrumentator
git+https://github.com/heimoshuiyu/faster-whisper@prompt
pydub pydub

View File

@@ -1,48 +1,49 @@
annotated-types==0.6.0 annotated-types==0.7.0
anyio==3.7.1 anyio==4.4.0
av==10.0.0 av==12.3.0
certifi==2023.7.22 certifi==2024.8.30
cffi==1.16.0 cffi==1.17.1
charset-normalizer==3.3.2 charset-normalizer==3.3.2
click==8.1.7 click==8.1.7
coloredlogs==15.0.1 coloredlogs==15.0.1
ctranslate2==3.21.0 ctranslate2==4.4.0
fastapi==0.104.1 exceptiongroup==1.2.2
faster-whisper==0.9.0 fastapi==0.114.1
filelock==3.13.1 faster-whisper @ git+https://github.com/heimoshuiyu/faster-whisper@28a4d11a736d8cdeb4655ee5d7e4b4e7ae5ec8e0
flatbuffers==23.5.26 filelock==3.16.0
fsspec==2023.10.0 flatbuffers==24.3.25
h11==0.14.0 fsspec==2024.9.0
httptools==0.6.1 h11==0.14.0
huggingface-hub==0.17.3 httptools==0.6.1
humanfriendly==10.0 huggingface-hub==0.24.6
idna==3.4 humanfriendly==10.0
mpmath==1.3.0 idna==3.8
numpy==1.26.2 mpmath==1.3.0
onnxruntime==1.16.2 numpy==2.1.1
OpenCC==1.1.7 onnxruntime==1.19.2
packaging==23.2 OpenCC==1.1.9
prometheus-client==0.18.0 packaging==24.1
prometheus-fastapi-instrumentator==6.1.0 prometheus-fastapi-instrumentator==7.0.0
protobuf==4.25.0 prometheus_client==0.20.0
pycparser==2.21 protobuf==5.28.0
pydantic==2.5.0 pycparser==2.22
pydantic_core==2.14.1 pydantic==2.9.1
pydub==0.25.1 pydantic_core==2.23.3
python-dotenv==1.0.0 pydub==0.25.1
python-multipart==0.0.6 python-dotenv==1.0.1
PyYAML==6.0.1 python-multipart==0.0.9
requests==2.31.0 PyYAML==6.0.2
sniffio==1.3.0 requests==2.32.3
sounddevice==0.4.6 sniffio==1.3.1
starlette==0.27.0 sounddevice==0.5.0
sympy==1.12 starlette==0.38.5
tokenizers==0.14.1 sympy==1.13.2
tqdm==4.66.1 tokenizers==0.20.0
typing_extensions==4.8.0 tqdm==4.66.5
urllib3==2.1.0 typing_extensions==4.12.2
uvicorn==0.24.0.post1 urllib3==2.2.2
uvloop==0.19.0 uvicorn==0.30.6
watchfiles==0.21.0 uvloop==0.20.0
websockets==12.0 watchfiles==0.24.0
whisper-ctranslate2==0.3.2 websockets==13.0.1
whisper-ctranslate2==0.4.5

10
start-docker.sh Executable file
View File

@@ -0,0 +1,10 @@
#!/bin/bash
# Start the whisper-fastapi container in the background with GPU access.
# The Hugging Face cache is bind-mounted so downloaded models persist
# across container restarts. Fix: `--name whisper-fastapi` was passed
# twice (once inline, once as a separate flag); the duplicate is removed.
docker run -d --name whisper-fastapi \
--restart unless-stopped \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--gpus all \
-p 5000:5000 \
docker.io/heimoshuiyu/whisper-fastapi:latest \
--model large-v2

11
start-podman.sh Executable file
View File

@@ -0,0 +1,11 @@
#!/bin/bash
# Start the whisper-fastapi container under podman with GPU access.
# GPU is exposed via the CDI device spec (`--device nvidia.com/gpu=all`),
# which is podman's supported path; the Docker-compat `--gpus all` flag is
# dropped because combining it with the CDI device requests the same GPU
# twice. Fix: the duplicated `--name whisper-fastapi` flag is also removed.
podman run -d --name whisper-fastapi \
--restart unless-stopped \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--device nvidia.com/gpu=all --security-opt=label=disable \
-p 5000:5000 \
docker.io/heimoshuiyu/whisper-fastapi:latest \
--model large-v2

View File

@@ -1,4 +1,3 @@
from faster_whisper import vad
import tqdm import tqdm
import json import json
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
@@ -8,7 +7,7 @@ import io
import hashlib import hashlib
import argparse import argparse
import uvicorn import uvicorn
from typing import Annotated, Any, BinaryIO, Literal, Generator from typing import Annotated, Any, BinaryIO, Literal, Generator, Tuple, Iterable
from fastapi import ( from fastapi import (
File, File,
HTTPException, HTTPException,
@@ -22,32 +21,35 @@ from fastapi import (
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
from src.whisper_ctranslate2.writers import format_timestamp from src.whisper_ctranslate2.writers import format_timestamp
from faster_whisper.transcribe import Segment, TranscriptionInfo
import opencc import opencc
from prometheus_fastapi_instrumentator import Instrumentator from prometheus_fastapi_instrumentator import Instrumentator
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0", type=str) parser.add_argument("--host", default="0.0.0.0", type=str)
parser.add_argument("--port", default=5000, type=int) parser.add_argument("--port", default=5000, type=int)
parser.add_argument("--model", default="large-v2", type=str) parser.add_argument("--model", default="large-v3", type=str)
parser.add_argument("--device", default="auto", type=str) parser.add_argument("--device", default="auto", type=str)
parser.add_argument("--cache_dir", default=None, type=str) parser.add_argument("--cache_dir", default=None, type=str)
parser.add_argument("--local_files_only", default=False, type=bool)
parser.add_argument("--threads", default=4, type=int)
args = parser.parse_args() args = parser.parse_args()
app = FastAPI() app = FastAPI()
# Instrument your app with default metrics and expose the metrics # Instrument your app with default metrics and expose the metrics
Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics") Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics")
ccc = opencc.OpenCC("t2s.json") ccc = opencc.OpenCC("t2s.json")
print("Loading model...") print(f"Loading model to device {args.device}...")
transcriber = Transcribe( transcriber = Transcribe(
model_path=args.model, model_path=args.model,
device=args.device, device=args.device,
device_index=0, device_index=0,
compute_type="default", compute_type="default",
threads=1, threads=args.threads,
cache_directory=args.cache_dir, cache_directory=args.cache_dir,
local_files_only=False, local_files_only=args.local_files_only,
) )
print("Model loaded!") print(f"Model loaded to device {transcriber.model.model.device}")
# allow all cors # allow all cors
@@ -76,7 +78,7 @@ def tsv_writer(generator: Generator[dict[str, Any], Any, None]):
for i, segment in enumerate(generator): for i, segment in enumerate(generator):
start_time = str(round(1000 * segment["start"])) start_time = str(round(1000 * segment["start"]))
end_time = str(round(1000 * segment["end"])) end_time = str(round(1000 * segment["end"]))
text = segment["text"] text = segment["text"].strip()
yield f"{start_time}\t{end_time}\t{text}\n" yield f"{start_time}\t{end_time}\t{text}\n"
@@ -88,7 +90,7 @@ def srt_writer(generator: Generator[dict[str, Any], Any, None]):
end_time = format_timestamp( end_time = format_timestamp(
segment["end"], decimal_marker=",", always_include_hours=True segment["end"], decimal_marker=",", always_include_hours=True
) )
text = segment["text"] text = segment["text"].strip()
yield f"{i}\n{start_time} --> {end_time}\n{text}\n\n" yield f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
@@ -97,15 +99,16 @@ def vtt_writer(generator: Generator[dict[str, Any], Any, None]):
for i, segment in enumerate(generator): for i, segment in enumerate(generator):
start_time = format_timestamp(segment["start"]) start_time = format_timestamp(segment["start"])
end_time = format_timestamp(segment["end"]) end_time = format_timestamp(segment["end"])
text = segment["text"] text = segment["text"].strip()
yield f"{start_time} --> {end_time}\n{text}\n\n" yield f"{start_time} --> {end_time}\n{text}\n\n"
def build_json_result( def build_json_result(
generator: Generator[dict[str, Any], Any, None] generator: Iterable[Segment],
info: dict,
) -> dict[str, Any]: ) -> dict[str, Any]:
segments = [i for i in generator] segments = [i for i in generator]
return { return info | {
"text": "\n".join(i["text"] for i in segments), "text": "\n".join(i["text"] for i in segments),
"segments": segments, "segments": segments,
} }
@@ -117,50 +120,47 @@ def stream_builder(
vad_filter: bool, vad_filter: bool,
language: str | None, language: str | None,
initial_prompt: str = "", initial_prompt: str = "",
): repetition_penalty: float = 1.0,
) -> Tuple[Iterable[dict], dict]:
segments, info = transcriber.model.transcribe( segments, info = transcriber.model.transcribe(
audio=audio, audio=audio,
language=language, language=language,
task=task, task=task,
beam_size=5,
best_of=5,
patience=1.0,
length_penalty=-1.0,
repetition_penalty=1.0,
no_repeat_ngram_size=0,
temperature=[0.0, 1.0 + 1e-6, 0.2],
compression_ratio_threshold=2.4,
log_prob_threshold=-1.0,
no_speech_threshold=0.6,
condition_on_previous_text=True,
prompt_reset_on_temperature=False,
initial_prompt=initial_prompt, initial_prompt=initial_prompt,
suppress_blank=False,
suppress_tokens=[],
word_timestamps=True, word_timestamps=True,
prepend_punctuations="\"'“¿([{-", repetition_penalty=repetition_penalty,
append_punctuations="\"'.。,!?::”)]}、",
vad_filter=vad_filter,
vad_parameters=None,
) )
print( print(
"Detected language '%s' with probability %f" "Detected language '%s' with probability %f"
% (info.language, info.language_probability) % (info.language, info.language_probability)
) )
last_pos = 0 def wrap():
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar: last_pos = 0
for segment in segments: with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
start, end, text = segment.start, segment.end, segment.text for segment in segments:
pbar.update(end - last_pos) start, end, text = segment.start, segment.end, segment.text
last_pos = end pbar.update(end - last_pos)
data = segment._asdict() last_pos = end
data["total"] = info.duration data = segment._asdict()
data["text"] = ccc.convert(data["text"]) if data.get('words') is not None:
yield data data["words"] = [i._asdict() for i in data["words"]]
if info.language == "zh":
data["text"] = ccc.convert(data["text"])
yield data
info_dict = info._asdict()
if info_dict['transcription_options'] is not None:
info_dict['transcription_options'] = info_dict['transcription_options']._asdict()
if info_dict['vad_options'] is not None:
info_dict['vad_options'] = info_dict['vad_options']._asdict()
return wrap(), info_dict
@app.websocket("/k6nele/status") @app.websocket("/k6nele/status")
@app.websocket("/konele/status") @app.websocket("/konele/status")
@app.websocket("/v1/k6nele/status")
@app.websocket("/v1/konele/status")
async def konele_status( async def konele_status(
websocket: WebSocket, websocket: WebSocket,
): ):
@@ -171,6 +171,8 @@ async def konele_status(
@app.websocket("/k6nele/ws") @app.websocket("/k6nele/ws")
@app.websocket("/konele/ws") @app.websocket("/konele/ws")
@app.websocket("/v1/k6nele/ws")
@app.websocket("/v1/konele/ws")
async def konele_ws( async def konele_ws(
websocket: WebSocket, websocket: WebSocket,
task: Literal["transcribe", "translate"] = "transcribe", task: Literal["transcribe", "translate"] = "transcribe",
@@ -184,15 +186,11 @@ async def konele_ws(
# convert lang code format (eg. en-US to en) # convert lang code format (eg. en-US to en)
lang = lang.split("-")[0] lang = lang.split("-")[0]
print("WebSocket client connected, lang is", lang)
print("content-type is", content_type)
data = b"" data = b""
while True: while True:
try: try:
data += await websocket.receive_bytes() data += await websocket.receive_bytes()
print("Received data:", len(data), data[-10:])
if data[-3:] == b"EOS": if data[-3:] == b"EOS":
print("End of speech")
break break
except: except:
break break
@@ -215,17 +213,16 @@ async def konele_ws(
file_obj.seek(0) file_obj.seek(0)
generator = stream_builder( generator, info = stream_builder(
audio=file_obj, audio=file_obj,
task=task, task=task,
vad_filter=vad_filter, vad_filter=vad_filter,
language=None if lang == "und" else lang, language=None if lang == "und" else lang,
initial_prompt=initial_prompt, initial_prompt=initial_prompt,
) )
result = build_json_result(generator) result = build_json_result(generator, info)
text = result.get("text", "") text = result.get("text", "")
print("result", text)
await websocket.send_json( await websocket.send_json(
{ {
@@ -240,6 +237,8 @@ async def konele_ws(
@app.post("/k6nele/post") @app.post("/k6nele/post")
@app.post("/konele/post") @app.post("/konele/post")
@app.post("/v1/k6nele/post")
@app.post("/v1/konele/post")
async def translateapi( async def translateapi(
request: Request, request: Request,
task: Literal["transcribe", "translate"] = "transcribe", task: Literal["transcribe", "translate"] = "transcribe",
@@ -248,14 +247,12 @@ async def translateapi(
vad_filter: bool = False, vad_filter: bool = False,
): ):
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
print("downloading request file", content_type)
# convert lang code format (eg. en-US to en) # convert lang code format (eg. en-US to en)
lang = lang.split("-")[0] lang = lang.split("-")[0]
splited = [i.strip() for i in content_type.split(",") if "=" in i] splited = [i.strip() for i in content_type.split(",") if "=" in i]
info = {k: v for k, v in (i.split("=") for i in splited)} info = {k: v for k, v in (i.split("=") for i in splited)}
print(info)
channels = int(info.get("channels", "1")) channels = int(info.get("channels", "1"))
rate = int(info.get("rate", "16000")) rate = int(info.get("rate", "16000"))
@@ -279,17 +276,16 @@ async def translateapi(
file_obj.seek(0) file_obj.seek(0)
generator = stream_builder( generator, info = stream_builder(
audio=file_obj, audio=file_obj,
task=task, task=task,
vad_filter=vad_filter, vad_filter=vad_filter,
language=None if lang == "und" else lang, language=None if lang == "und" else lang,
initial_prompt=initial_prompt, initial_prompt=initial_prompt,
) )
result = build_json_result(generator) result = build_json_result(generator, info)
text = result.get("text", "") text = result.get("text", "")
print("result", text)
return { return {
"status": 0, "status": 0,
@@ -306,6 +302,7 @@ async def transcription(
task: str = Form("transcribe"), task: str = Form("transcribe"),
language: str = Form("und"), language: str = Form("und"),
vad_filter: bool = Form(False), vad_filter: bool = Form(False),
repetition_penalty: float = Form(1.0),
): ):
"""Transcription endpoint """Transcription endpoint
@@ -313,11 +310,12 @@ async def transcription(
""" """
# timestamp as filename, keep original extension # timestamp as filename, keep original extension
generator = stream_builder( generator, info = stream_builder(
audio=io.BytesIO(file.file.read()), audio=io.BytesIO(file.file.read()),
task=task, task=task,
vad_filter=vad_filter, vad_filter=vad_filter,
language=None if language == "und" else language, language=None if language == "und" else language,
repetition_penalty=repetition_penalty,
) )
# special function for streaming response (OpenAI API does not have this) # special function for streaming response (OpenAI API does not have this)
@@ -327,7 +325,7 @@ async def transcription(
media_type="text/event-stream", media_type="text/event-stream",
) )
elif response_format == "json": elif response_format == "json":
return build_json_result(generator) return build_json_result(generator, info)
elif response_format == "text": elif response_format == "text":
return StreamingResponse(text_writer(generator), media_type="text/plain") return StreamingResponse(text_writer(generator), media_type="text/plain")
elif response_format == "tsv": elif response_format == "tsv":