Files
whisper-fastapi/whisper_fastapi.py

251 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import wave
import io
import hashlib
import argparse
import uvicorn
from typing import Any
from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket, Response
from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions
from src.whisper_ctranslate2.writers import format_timestamp
import opencc
from prometheus_fastapi_instrumentator import Instrumentator
# Command-line configuration.
# NOTE: parse_args() runs at import time, so this module reads sys.argv as soon
# as it is imported.
parser = argparse.ArgumentParser()
# Host and port are handed to uvicorn.run() at the end of the file.
parser.add_argument("--host", default="0.0.0.0", type=str)
parser.add_argument("--port", default=5000, type=int)
# Forwarded to Transcribe() as model_path.
parser.add_argument("--model", default="large-v2", type=str)
# Forwarded to Transcribe() as cache_directory; the default None is passed through unchanged.
parser.add_argument("--cache_dir", default=None, type=str)
args = parser.parse_args()
# ASGI application object; every route in this file is registered on it.
app = FastAPI()
# Instrument the app with the instrumentator's default metrics and expose them
# at /konele/metrics.
Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics")
# OpenCC converter built from the "t2s.json" config. The Konele handlers pass
# recognised text through ccc.convert() before replying.
# NOTE(review): "t2s" is presumably Traditional -> Simplified Chinese — confirm.
ccc = opencc.OpenCC("t2s.json")
# The model is loaded once, at import time, and the resulting transcriber is
# shared by all request handlers.
print("Loading model...")
transcriber = Transcribe(
model_path=args.model,
device="auto",
device_index=0,
compute_type="default",
threads=1,
cache_directory=args.cache_dir,
local_files_only=False,
)
print("Model loaded!")
# Allow cross-origin requests from any origin, with any method and any header.
# NOTE(review): wildcard origins combined with allow_credentials=True is not
# permitted for credentialed requests by the CORS spec — confirm whether
# credentials are really needed here.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def generate_tsv(result: dict[str, list[Any]]):
    """Render transcription segments as a tab-separated document.

    Args:
        result: Mapping whose ``"segments"`` entry is a list of mappings with
            numeric ``"start"`` / ``"end"`` values and a ``"text"`` string.

    Returns:
        A string made of a ``start``/``end``/``text`` header line followed by
        one line per segment. Start and end are multiplied by 1000 and rounded
        to integers (milliseconds, assuming the inputs are seconds). Every
        line, including the last, ends with a newline.
    """
    rows = ["start\tend\ttext"]
    for segment in result["segments"]:
        start_ms = round(1000 * segment["start"])
        end_ms = round(1000 * segment["end"])
        # A literal tab inside the text would introduce a spurious column and
        # break TSV parsers, so it is flattened to a space.
        text = segment["text"].replace("\t", " ")
        rows.append(f"{start_ms}\t{end_ms}\t{text}")
    return "\n".join(rows) + "\n"
def generate_srt(result: dict[str, list[Any]]):
    """Render transcription segments as a SubRip (.srt) document.

    Args:
        result: Mapping whose ``"segments"`` entry is a list of mappings with
            ``"start"``, ``"end"`` and ``"text"`` entries.

    Returns:
        A string with one cue per segment: a 1-based cue number, a
        ``start --> end`` timing line, the text, and a blank line. An empty
        segment list yields an empty string.
    """
    cues = []
    for index, segment in enumerate(result["segments"], start=1):
        # SubRip timing lines must always carry the hours field and use a
        # comma before the milliseconds (HH:MM:SS,mmm).
        # NOTE(review): relies on format_timestamp accepting the
        # always_include_hours / decimal_marker keywords, as the upstream
        # Whisper helper does — confirm against src/whisper_ctranslate2/writers.py.
        start_time = format_timestamp(
            segment["start"], always_include_hours=True, decimal_marker=","
        )
        end_time = format_timestamp(
            segment["end"], always_include_hours=True, decimal_marker=","
        )
        text = segment["text"]
        cues.append(f"{index}\n{start_time} --> {end_time}\n{text}\n\n")
    return "".join(cues)
def generate_vtt(result: dict[str, list[Any]]):
    """Render transcription segments as a WebVTT document.

    Args:
        result: Mapping whose ``"segments"`` entry is a list of mappings with
            ``"start"``, ``"end"`` and ``"text"`` entries.

    Returns:
        A string that starts with the ``WEBVTT`` header and a blank line,
        followed by one cue per segment (timing line, text, blank line).
    """
    cues = [
        "{} --> {}\n{}\n\n".format(
            format_timestamp(cue["start"]),
            format_timestamp(cue["end"]),
            cue["text"],
        )
        for cue in result["segments"]
    ]
    return "WEBVTT\n\n" + "".join(cues)
def get_options(*, initial_prompt=""):
    """Build the TranscriptionOptions shared by every endpoint.

    Args:
        initial_prompt: Keyword-only text forwarded unchanged as the
            ``initial_prompt`` option. Defaults to an empty string.

    Returns:
        A new TranscriptionOptions instance. Every field except
        ``initial_prompt`` is fixed here.
    """
    options = TranscriptionOptions(
        beam_size=5,
        best_of=5,
        patience=1.0,
        length_penalty=1.0,
        log_prob_threshold=-1.0,
        no_speech_threshold=0.6,
        compression_ratio_threshold=2.4,
        condition_on_previous_text=True,
        # Fallback temperatures from 0.0 to 1.0 in steps of 0.2. The previous
        # value [0.0, 1.0 + 1e-6, 0.2] was the (start, stop, step) triple of an
        # arange call written out as a list, which produced the schedule
        # 0.0 -> 1.000001 -> 0.2 instead of an increasing ramp.
        temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        suppress_tokens=[-1],
        word_timestamps=True,
        print_colors=False,
        prepend_punctuations="\"'“¿([{-",
        append_punctuations="\"'.。,!?::”)]}、",
        # Voice-activity-detection filtering is switched off; its tuning knobs
        # are passed as None.
        vad_filter=False,
        vad_threshold=None,
        vad_min_speech_duration_ms=None,
        vad_max_speech_duration_s=None,
        vad_min_silence_duration_ms=None,
        initial_prompt=initial_prompt,
        repetition_penalty=1.0,
        no_repeat_ngram_size=0,
        prompt_reset_on_temperature=False,
        suppress_blank=False,
    )
    return options
@app.websocket("/konele/ws")
async def konele_ws(
    websocket: WebSocket,
    lang: str = "und",
):
    """Konele-style WebSocket speech recognition endpoint.

    The client sends audio in binary frames and finishes by sending data that
    ends with the bytes ``EOS``. The collected bytes are wrapped in a mono,
    16-bit, 16 kHz WAV container, passed to the shared transcriber, and the
    recognised text (after OpenCC conversion) is sent back as a single JSON
    message before the socket is closed.

    Args:
        websocket: The accepted client connection.
        lang: Query parameter. ``"en-US"`` selects the ``translate`` task; any
            other value selects ``transcribe``. The language passed to the
            transcriber is always None.
    """
    await websocket.accept()
    print("WebSocket client connected, lang is", lang)

    # bytearray grows in place; `bytes +=` would re-copy the whole buffer on
    # every frame.
    received = bytearray()
    while True:
        try:
            received += await websocket.receive_bytes()
        except Exception as exc:
            # Best effort, as before: any receive failure (for example the
            # client disconnecting) ends the stream and whatever arrived so
            # far is transcribed. Unlike a bare `except:`, this no longer
            # swallows task cancellation or KeyboardInterrupt.
            print("Stopped receiving audio:", repr(exc))
            break
        print("Received data:", len(received), bytes(received[-10:]))
        if received.endswith(b"EOS"):
            print("End of speech")
            break

    # The id is computed over everything received, marker included, so it is
    # unchanged for existing clients.
    md5 = hashlib.md5(received).hexdigest()
    if received.endswith(b"EOS"):
        # The end-of-stream marker is protocol framing, not audio samples.
        del received[-3:]

    # Wrap the raw samples in an in-memory WAV file for the transcriber.
    # Closing the writer finalises the header; it does not close file_obj
    # because the BytesIO was opened by us, not by the wave module.
    file_obj = io.BytesIO()
    with wave.open(file_obj, "wb") as wav_writer:
        wav_writer.setnchannels(1)
        wav_writer.setsampwidth(2)
        wav_writer.setframerate(16000)
        wav_writer.writeframes(received)
    file_obj.seek(0)

    options = get_options()
    result = transcriber.inference(
        audio=file_obj,
        # Enter translate mode if target language is English
        task="translate" if lang == "en-US" else "transcribe",
        language=None,  # type: ignore
        verbose=False,
        live=False,
        options=options,
    )
    text = result.get("text", "")
    text = ccc.convert(text)
    print("result", text)
    await websocket.send_json(
        {
            "status": 0,
            "segment": 0,
            "result": {"hypotheses": [{"transcript": text}], "final": True},
            "id": md5,
        }
    )
    await websocket.close()
def _parse_content_type_params(content_type: str) -> dict[str, str]:
    """Extract ``key=value`` parameters from a Content-Type header value."""
    params: dict[str, str] = {}
    # Parameters may be separated by "," or ";".
    for part in content_type.replace(";", ",").split(","):
        # partition() splits on the first "=" only, so a value that itself
        # contains "=" no longer breaks the parsing.
        key, sep, value = part.partition("=")
        if not sep:
            continue
        value = value.strip()
        # Drop a leading type tag such as "(int)" from values like "(int)16000".
        if value.startswith("(") and ")" in value:
            value = value.partition(")")[2].strip()
        params[key.strip()] = value
    return params


def _int_param(params: dict[str, str], name: str, default: int) -> int:
    """Return ``params[name]`` as an int, or ``default`` if missing or invalid."""
    try:
        return int(params[name])
    except (KeyError, ValueError):
        return default


@app.post("/konele/post")
async def translateapi(
    request: Request,
    lang: str = "und",
):
    """Konele-style HTTP speech recognition endpoint.

    The request body is treated as raw audio samples. The channel count and
    sample rate are read from ``channels=`` / ``rate=`` parameters of the
    Content-Type header, defaulting to 1 channel at 16000 Hz; the sample width
    is fixed at 2 bytes. The samples are wrapped in an in-memory WAV file and
    passed to the shared transcriber.

    Args:
        request: Incoming request; its Content-Type header and body are read.
        lang: Query parameter. ``"en-US"`` selects the ``translate`` task; any
            other value selects ``transcribe``. The language passed to the
            transcriber is always None.

    Returns:
        A dict with ``status`` 0, a single-element ``hypotheses`` list holding
        the recognised text (after OpenCC conversion), and ``id`` set to the
        MD5 hex digest of the request body.
    """
    content_type = request.headers.get("Content-Type", "")
    print("downloading request file", content_type)
    info = _parse_content_type_params(content_type)
    print(info)
    channels = _int_param(info, "channels", 1)
    rate = _int_param(info, "rate", 16000)

    body = await request.body()
    md5 = hashlib.md5(body).hexdigest()

    # Wrap the raw samples in an in-memory WAV file for the transcriber.
    # Closing the writer finalises the header; it does not close file_obj
    # because the BytesIO was opened by us, not by the wave module.
    file_obj = io.BytesIO()
    with wave.open(file_obj, "wb") as wav_writer:
        wav_writer.setnchannels(channels)
        wav_writer.setsampwidth(2)
        wav_writer.setframerate(rate)
        wav_writer.writeframes(body)
    file_obj.seek(0)

    options = get_options()
    result = transcriber.inference(
        audio=file_obj,
        # Enter translate mode if target language is English
        task="translate" if lang == "en-US" else "transcribe",
        language=None,  # type: ignore
        verbose=False,
        live=False,
        options=options,
    )
    text = result.get("text", "")
    text = ccc.convert(text)
    print("result", text)
    return {
        "status": 0,
        "hypotheses": [{"utterance": text}],
        "id": md5,
    }
@app.post("/v1/audio/transcriptions")
async def transcription(
    file: UploadFile = File(...),
    prompt: str = Form(""),
    response_format: str = Form("json"),
):
    """Transcription endpoint.

    The client uploads an audio file as multipart/form-data and receives the
    transcription in the requested format.

    Args:
        file: Uploaded audio file.
        prompt: Text forwarded to the transcriber options as the initial prompt.
        response_format: One of ``json``, ``text``, ``tsv``, ``srt`` or ``vtt``.

    Returns:
        ``json``: the transcriber's result, serialised by FastAPI.
        ``text``: the segment texts joined by newlines, as plain text.
        ``tsv`` / ``srt``: the rendered document, as plain text.
        ``vtt``: the rendered document, as ``text/vtt``.
        Any other value: HTTP 400 with ``{"error": "Invalid response_format"}``.
    """
    # Reject an unsupported format before spending time on inference.
    if response_format not in ("json", "text", "tsv", "srt", "vtt"):
        return Response(
            content='{"error": "Invalid response_format"}',
            status_code=400,
            media_type="application/json",
        )

    options = get_options(initial_prompt=prompt)
    result: Any = transcriber.inference(
        # UploadFile.read() is the async read; it avoids a blocking file read
        # inside the coroutine.
        audio=io.BytesIO(await file.read()),
        task="transcribe",
        language=None,  # type: ignore
        verbose=False,
        live=False,
        options=options,
    )

    if response_format == "json":
        # Returning the result lets FastAPI serialise it. Passing it as
        # Response(content=...) fails because content must be str or bytes.
        return result
    if response_format == "text":
        return Response(
            content="\n".join(s["text"] for s in result["segments"]),
            media_type="text/plain",
        )
    if response_format == "tsv":
        return Response(content=generate_tsv(result), media_type="text/plain")
    if response_format == "srt":
        return Response(content=generate_srt(result), media_type="text/plain")
    # Only "vtt" remains. Wrapping it in a Response stops FastAPI from
    # JSON-encoding the document as a quoted string.
    return Response(content=generate_vtt(result), media_type="text/vtt")
# Start the server on the host/port taken from the command line.
# NOTE: this call is not guarded by `if __name__ == "__main__"`, so it also
# runs when the module is imported.
uvicorn.run(app, host=args.host, port=args.port)