commit 046f4017d0
2023-10-17 20:52:18 +08:00
5 changed files with 344 additions and 0 deletions

whisper_fastapi.py

@@ -0,0 +1,247 @@
import wave
import io
import hashlib
import argparse
import uvicorn
from typing import Any
from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions
from src.whisper_ctranslate2.writers import format_timestamp
import opencc
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0", type=str)
parser.add_argument("--port", default=5000, type=int)
parser.add_argument("--model", default="large-v2", type=str)
parser.add_argument("--cache_dir", default=None, type=str)
args = parser.parse_args()
app = FastAPI()
ccc = opencc.OpenCC("t2s.json")
print("Loading model...")
transcriber = Transcribe(
    model_path=args.model,
    device="auto",
    device_index=0,
    compute_type="default",
    threads=1,
    cache_directory=args.cache_dir,
    local_files_only=False,
)
print("Model loaded!")
# Allow all CORS origins so browser clients can call the API directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def generate_tsv(result: dict[str, list[Any]]) -> str:
    """Render segments as tab-separated start/end millisecond offsets and text."""
    tsv = "start\tend\ttext\n"
    for segment in result["segments"]:
        start_time = str(round(1000 * segment["start"]))
        end_time = str(round(1000 * segment["end"]))
        text = segment["text"]
        tsv += f"{start_time}\t{end_time}\t{text}\n"
    return tsv
def generate_srt(result: dict[str, list[Any]]) -> str:
    """Render segments as SubRip (SRT) cues with 1-based indices."""
    srt = ""
    for i, segment in enumerate(result["segments"], start=1):
        start_time = format_timestamp(segment["start"])
        end_time = format_timestamp(segment["end"])
        text = segment["text"]
        srt += f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
    return srt
def generate_vtt(result: dict[str, list[Any]]) -> str:
    """Render segments as a WebVTT document."""
    vtt = "WEBVTT\n\n"
    for segment in result["segments"]:
        start_time = format_timestamp(segment["start"])
        end_time = format_timestamp(segment["end"])
        text = segment["text"]
        vtt += f"{start_time} --> {end_time}\n{text}\n\n"
    return vtt
def get_options(*, initial_prompt: str = "") -> TranscriptionOptions:
    """Build the default TranscriptionOptions, optionally seeded with an initial prompt."""
    options = TranscriptionOptions(
        beam_size=5,
        best_of=5,
        patience=1.0,
        length_penalty=1.0,
        log_prob_threshold=-1.0,
        no_speech_threshold=0.6,
        compression_ratio_threshold=2.4,
        condition_on_previous_text=True,
        temperature=[0.0, 1.0 + 1e-6, 0.2],
        suppress_tokens=[-1],
        word_timestamps=True,
        print_colors=False,
        prepend_punctuations="\"'“¿([{-",
        append_punctuations="\"'.。,!?::”)]}、",
        vad_filter=False,
        vad_threshold=None,
        vad_min_speech_duration_ms=None,
        vad_max_speech_duration_s=None,
        vad_min_silence_duration_ms=None,
        initial_prompt=initial_prompt,
        repetition_penalty=1.0,
        no_repeat_ngram_size=0,
        prompt_reset_on_temperature=False,
        suppress_blank=False,
    )
    return options
@app.websocket("/konele/ws")
async def konele_ws(
    websocket: WebSocket,
    lang: str = "und",
):
    """Konele-style websocket endpoint: collect raw PCM chunks, answer with a transcript."""
    await websocket.accept()
    print("WebSocket client connected, lang is", lang)

    # Accumulate raw 16 kHz mono 16-bit PCM until the client signals end of speech.
    data = b""
    while True:
        try:
            data += await websocket.receive_bytes()
            print("Received data:", len(data), data[-10:])
            if data[-3:] == b"EOS":
                print("End of speech")
                break
        except Exception:
            # Client disconnected or sent a non-binary frame; stop reading.
            break
    md5 = hashlib.md5(data).hexdigest()

    # Wrap the raw PCM in an in-memory WAV container so the transcriber can read it.
    file_obj = io.BytesIO()
    buffer = wave.open(file_obj, "wb")
    buffer.setnchannels(1)
    buffer.setsampwidth(2)
    buffer.setframerate(16000)
    buffer.writeframes(data)
    file_obj.seek(0)

    options = get_options()
    result = transcriber.inference(
        audio=file_obj,
        # Enter translate mode if the target language is English
        task="translate" if lang == "en-US" else "transcribe",
        language=None,  # type: ignore
        verbose=False,
        live=False,
        options=options,
    )
    text = result.get("text", "")
    text = ccc.convert(text)  # Convert Traditional Chinese output to Simplified
    print("result", text)

    await websocket.send_json(
        {
            "status": 0,
            "segment": 0,
            "result": {"hypotheses": [{"transcript": text}], "final": True},
            "id": md5,
        }
    )
    await websocket.close()
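# A minimal client sketch for the websocket endpoint above (assumptions: the
# server runs locally on port 5000, the third-party `websockets` package is
# installed, and `speech.pcm` holds raw 16 kHz mono 16-bit PCM, which is what
# the handler wraps into a WAV header):
#
#   import asyncio
#   import websockets
#
#   async def stream(pcm: bytes) -> None:
#       async with websockets.connect("ws://localhost:5000/konele/ws?lang=und") as ws:
#           await ws.send(pcm)      # one or more binary chunks
#           await ws.send(b"EOS")   # end-of-speech marker checked by the handler
#           print(await ws.recv())  # JSON reply containing the transcript
#
#   asyncio.run(stream(open("speech.pcm", "rb").read()))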
@app.post("/konele/post")
async def translateapi(
request: Request,
lang: str = "und",
):
content_type = request.headers.get("Content-Type", "")
print("downloading request file", content_type)
splited = [i.strip() for i in content_type.split(",") if "=" in i]
info = {k: v for k, v in (i.split("=") for i in splited)}
print(info)
channels = int(info.get("channels", "1"))
rate = int(info.get("rate", "16000"))
body = await request.body()
md5 = hashlib.md5(body).hexdigest()
# create fake file for wave.open
file_obj = io.BytesIO()
buffer = wave.open(file_obj, "wb")
buffer.setnchannels(channels)
buffer.setsampwidth(2)
buffer.setframerate(rate)
buffer.writeframes(body)
file_obj.seek(0)
options = get_options()
result = transcriber.inference(
audio=file_obj,
# Enter translate mode if target language is English
task="translate" if lang == "en-US" else "transcribe",
language=None, # type: ignore
verbose=False,
live=False,
options=options,
)
text = result.get("text", "")
text = ccc.convert(text)
print("result", text)
return {
"status": 0,
"hypotheses": [{"utterance": text}],
"id": md5,
}
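# A minimal client sketch for the POST endpoint above (assumptions: local server
# on port 5000, the third-party `requests` package, and `speech.pcm` holding raw
# 16 kHz mono 16-bit PCM; rate and channels are passed the way the handler
# parses them, as key=value pairs in the Content-Type header):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:5000/konele/post?lang=und",
#       data=open("speech.pcm", "rb").read(),
#       headers={"Content-Type": "audio/x-raw, rate=16000, channels=1"},
#   )
#   print(resp.json()["hypotheses"][0]["utterance"])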
@app.post("/v1/audio/transcriptions")
async def transcription(
file: UploadFile = File(...),
prompt: str = Form(""),
response_type: str = Form("json"),
):
"""Transcription endpoint
User upload audio file in multipart/form-data format and receive transcription in response
"""
# timestamp as filename, keep original extension
options = get_options(initial_prompt=prompt)
result: Any = transcriber.inference(
audio=io.BytesIO(file.file.read()),
task="transcribe",
language=None, # type: ignore
verbose=False,
live=False,
options=options,
)
if response_type == "json":
return result
elif response_type == "text":
return result["text"].strip()
elif response_type == "tsv":
return generate_tsv(result)
elif response_type == "srt":
return generate_srt(result)
elif response_type == "vtt":
return generate_vtt(result)
return {"error": "Invalid response_type"}
if __name__ == "__main__":
    uvicorn.run(app, host=args.host, port=args.port)
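# Example launch (assumption: run from the repository root so that the
# `src.whisper_ctranslate2` package is importable):
#
#   python whisper_fastapi.py --model large-v2 --host 0.0.0.0 --port 5000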