init
1 .gitignore vendored Normal file
@@ -0,0 +1 @@
/venv

52 README.md Normal file
@@ -0,0 +1,52 @@
# Whisper-FastAPI

Whisper-FastAPI is a simple Python FastAPI interface to konele and OpenAI-style services. It is based on the `faster-whisper` project and provides a konele-like API where translations and transcriptions can be obtained over WebSockets or POST requests.

## Features

- **Translation and Transcription**: The application provides an API for the konele service; translations and transcriptions can be obtained over WebSockets or POST requests.
- **Language Support**: If the target language is English, the application translates any source language to English.
- **WebSocket and POST Method Support**: The project supports a WebSocket endpoint (`/konele/ws`) and a POST endpoint (`/konele/post`).
- **Audio Transcriptions**: The `/v1/audio/transcriptions` endpoint lets users upload an audio file and receive a transcription in response. The optional `response_type` parameter can be `json`, `text`, `tsv`, `srt`, or `vtt`.
- **Simplified Chinese**: Traditional Chinese output is automatically converted to Simplified Chinese for konele using the `opencc` library.

## Usage

### Konele Voice Typing

For konele voice typing, you can use either the WebSocket endpoint or the POST endpoint.

- **WebSocket**: Connect to the WebSocket at `/konele/ws` and send audio data. The server responds with the transcription or translation.
- **POST Method**: Send a POST request to `/konele/post` with the audio data in the body. The server responds with the transcription or translation (see the client sketch below).

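As a quick illustration, here is a minimal Python client sketch for both endpoints. It assumes the server is running on `localhost:5000`, that `speech.raw` holds raw 16-bit mono PCM at 16 kHz, and that the third-party `requests` and `websockets` packages are installed; the host and file name are placeholders.

```python
import asyncio

import requests
import websockets  # third-party "websockets" package


def post_example() -> str:
    # The server reads "rate" and "channels" out of the Content-Type header
    # and treats the request body as raw PCM samples.
    with open("speech.raw", "rb") as f:
        resp = requests.post(
            "http://localhost:5000/konele/post",
            params={"lang": "en-US"},  # en-US switches the server to translate mode
            headers={"Content-Type": "audio/x-raw, rate=16000, channels=1"},
            data=f.read(),
        )
    return resp.json()["hypotheses"][0]["utterance"]


async def ws_example() -> str:
    # The server buffers binary frames until the received data ends with b"EOS".
    async with websockets.connect("ws://localhost:5000/konele/ws?lang=en-US") as ws:
        with open("speech.raw", "rb") as f:
            await ws.send(f.read())
        await ws.send(b"EOS")
        return await ws.recv()  # JSON text: result.hypotheses[0].transcript


if __name__ == "__main__":
    print(post_example())
    print(asyncio.run(ws_example()))
```
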
You can also use the demo I have set up to quickly test the endpoints at <https://yongyuancv.cn/konele/ws> and <https://yongyuancv.cn/konele/post>.

### OpenAI Whisper Service

To use the service that mirrors the OpenAI Whisper API, send a POST request to `/v1/audio/transcriptions` with an audio file. The server responds with the transcription in the format specified by the `response_type` parameter, as in the sketch below.

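A minimal upload sketch, again assuming a local server on `localhost:5000` and the `requests` package; `audio.wav` is a placeholder file name:

```python
import requests

# Upload an audio file as multipart/form-data and ask for SRT subtitles.
# response_type may be "json", "text", "tsv", "srt", or "vtt".
with open("audio.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:5000/v1/audio/transcriptions",
        files={"file": f},
        data={"prompt": "", "response_type": "srt"},
    )
print(resp.text)
```
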
You can also use the demo I have set up to quickly test it at <https://yongyuancv.cn/v1/audio/transcriptions>.

My demo runs the large-v2 model on an RTX 3060.

## Getting Started

To run the application, you need Python installed on your machine. Clone the repository and install the required dependencies:

```bash
git clone https://github.com/heimoshuiyu/whisper-fastapi.git
cd whisper-fastapi
pip install -r requirements.txt
```

You can then run the application with the following command (the model is downloaded from Hugging Face if it is not already in the cache directory):

```bash
python whisper_fastapi.py --host 0.0.0.0 --port 5000 --model large-v2
```

This will start the application on `http://<your-ip-address>:5000`.

## Limitation

Because inference is synchronous, this API can only handle one request at a time.

5 requirements.txt Normal file
@@ -0,0 +1,5 @@
python-multipart
fastapi
uvicorn
whisper_ctranslate2
opencc

39 requirements_version.txt Normal file
@@ -0,0 +1,39 @@
annotated-types==0.6.0
anyio==3.7.1
av==10.0.0
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.0
click==8.1.7
coloredlogs==15.0.1
ctranslate2==3.20.0
fastapi==0.103.2
faster-whisper==0.9.0
filelock==3.12.4
flatbuffers==23.5.26
fsspec==2023.9.2
h11==0.14.0
huggingface-hub==0.17.3
humanfriendly==10.0
idna==3.4
mpmath==1.3.0
numpy==1.26.1
onnxruntime==1.16.1
OpenCC==1.1.7
packaging==23.2
protobuf==4.24.4
pycparser==2.21
pydantic==2.4.2
pydantic_core==2.10.1
PyYAML==6.0.1
requests==2.31.0
sniffio==1.3.0
sounddevice==0.4.6
starlette==0.27.0
sympy==1.12
tokenizers==0.14.1
tqdm==4.66.1
typing_extensions==4.8.0
urllib3==2.0.6
uvicorn==0.23.2
whisper-ctranslate2==0.3.2

247 whisper_fastapi.py Normal file
@@ -0,0 +1,247 @@
import wave
import io
import hashlib
import argparse
import uvicorn
from typing import Any
from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions
from src.whisper_ctranslate2.writers import format_timestamp
import opencc

parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0", type=str)
parser.add_argument("--port", default=5000, type=int)
parser.add_argument("--model", default="large-v2", type=str)
parser.add_argument("--cache_dir", default=None, type=str)
args = parser.parse_args()
app = FastAPI()

# converter for Traditional -> Simplified Chinese output
ccc = opencc.OpenCC("t2s.json")

print("Loading model...")
|
||||||
|
transcriber = Transcribe(
|
||||||
|
model_path=args.model,
|
||||||
|
device="auto",
|
||||||
|
device_index=0,
|
||||||
|
compute_type="default",
|
||||||
|
threads=1,
|
||||||
|
cache_directory=args.cache_dir,
|
||||||
|
local_files_only=False,
|
||||||
|
)
|
||||||
|
print("Model loaded!")
|
||||||
|
|
||||||
|
|
||||||
|
# allow all CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def generate_tsv(result: dict[str, list[Any]]):
    # timestamps in milliseconds, one "start\tend\ttext" row per segment
    tsv = "start\tend\ttext\n"
    for segment in result["segments"]:
        start_time = str(round(1000 * segment["start"]))
        end_time = str(round(1000 * segment["end"]))
        text = segment["text"]
        tsv += f"{start_time}\t{end_time}\t{text}\n"
    return tsv


def generate_srt(result: dict[str, list[Any]]):
    srt = ""
    for i, segment in enumerate(result["segments"], start=1):
        start_time = format_timestamp(segment["start"])
        end_time = format_timestamp(segment["end"])
        text = segment["text"]
        srt += f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
    return srt


def generate_vtt(result: dict[str, list[Any]]):
    vtt = "WEBVTT\n\n"
    for segment in result["segments"]:
        start_time = format_timestamp(segment["start"])
        end_time = format_timestamp(segment["end"])
        text = segment["text"]
        vtt += f"{start_time} --> {end_time}\n{text}\n\n"
    return vtt


def get_options(*, initial_prompt=""):
    options = TranscriptionOptions(
        beam_size=5,
        best_of=5,
        patience=1.0,
        length_penalty=1.0,
        log_prob_threshold=-1.0,
        no_speech_threshold=0.6,
        compression_ratio_threshold=2.4,
        condition_on_previous_text=True,
        # fallback temperature schedule: retry decoding at increasing
        # temperatures (0.0 to 1.0 in steps of 0.2) when quality checks fail
        temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        suppress_tokens=[-1],
        word_timestamps=True,
        print_colors=False,
        prepend_punctuations="\"'“¿([{-",
        append_punctuations="\"'.。,,!!??::”)]}、",
        vad_filter=False,
        vad_threshold=None,
        vad_min_speech_duration_ms=None,
        vad_max_speech_duration_s=None,
        vad_min_silence_duration_ms=None,
        initial_prompt=initial_prompt,
        repetition_penalty=1.0,
        no_repeat_ngram_size=0,
        prompt_reset_on_temperature=False,
        suppress_blank=False,
    )
    return options


@app.websocket("/konele/ws")
async def konele_ws(
    websocket: WebSocket,
    lang: str = "und",
):
    await websocket.accept()
    print("WebSocket client connected, lang is", lang)
    data = b""
    while True:
        try:
            data += await websocket.receive_bytes()
            print("Received data:", len(data), data[-10:])
            # konele signals end of speech with a trailing b"EOS" marker
            if data[-3:] == b"EOS":
                print("End of speech")
                break
        except Exception:
            # client disconnected before sending EOS
            break

    md5 = hashlib.md5(data).hexdigest()

    # create an in-memory file for wave.open
    file_obj = io.BytesIO()

    # wrap the raw 16-bit mono 16 kHz PCM in a WAV container
    buffer = wave.open(file_obj, "wb")
    buffer.setnchannels(1)
    buffer.setsampwidth(2)
    buffer.setframerate(16000)
    buffer.writeframes(data)
    file_obj.seek(0)

    options = get_options()

    result = transcriber.inference(
        audio=file_obj,
        # Enter translate mode if the target language is English
        task="translate" if lang == "en-US" else "transcribe",
        language=None,  # type: ignore
        verbose=False,
        live=False,
        options=options,
    )
    text = result.get("text", "")
    # convert Traditional Chinese output to Simplified Chinese
    text = ccc.convert(text)
    print("result", text)

    # konele-style response envelope
    await websocket.send_json(
        {
            "status": 0,
            "segment": 0,
            "result": {"hypotheses": [{"transcript": text}], "final": True},
            "id": md5,
        }
    )
    await websocket.close()


@app.post("/konele/post")
|
||||||
|
async def translateapi(
|
||||||
|
request: Request,
|
||||||
|
lang: str = "und",
|
||||||
|
):
|
||||||
|
content_type = request.headers.get("Content-Type", "")
|
||||||
|
print("downloading request file", content_type)
|
||||||
|
splited = [i.strip() for i in content_type.split(",") if "=" in i]
|
||||||
|
info = {k: v for k, v in (i.split("=") for i in splited)}
|
||||||
|
print(info)
|
||||||
|
|
||||||
|
channels = int(info.get("channels", "1"))
|
||||||
|
rate = int(info.get("rate", "16000"))
|
||||||
|
|
||||||
|
body = await request.body()
|
||||||
|
md5 = hashlib.md5(body).hexdigest()
|
||||||
|
|
||||||
|
    # create an in-memory file for wave.open
    file_obj = io.BytesIO()

    # wrap the raw PCM body in a WAV container using the advertised parameters
    buffer = wave.open(file_obj, "wb")
    buffer.setnchannels(channels)
    buffer.setsampwidth(2)
    buffer.setframerate(rate)
    buffer.writeframes(body)
    file_obj.seek(0)

    options = get_options()

    result = transcriber.inference(
        audio=file_obj,
        # Enter translate mode if the target language is English
        task="translate" if lang == "en-US" else "transcribe",
        language=None,  # type: ignore
        verbose=False,
        live=False,
        options=options,
    )
    text = result.get("text", "")
    # convert Traditional Chinese output to Simplified Chinese
    text = ccc.convert(text)
    print("result", text)

    # konele-style response envelope
    return {
        "status": 0,
        "hypotheses": [{"utterance": text}],
        "id": md5,
    }


@app.post("/v1/audio/transcriptions")
|
||||||
|
async def transcription(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
prompt: str = Form(""),
|
||||||
|
response_type: str = Form("json"),
|
||||||
|
):
|
||||||
|
"""Transcription endpoint
|
||||||
|
|
||||||
|
User upload audio file in multipart/form-data format and receive transcription in response
|
||||||
|
"""
|
||||||
|
|
||||||
|
# timestamp as filename, keep original extension
|
||||||
|
options = get_options(initial_prompt=prompt)
|
||||||
|
|
||||||
|
    result: Any = transcriber.inference(
        audio=io.BytesIO(file.file.read()),
        task="transcribe",
        language=None,  # type: ignore
        verbose=False,
        live=False,
        options=options,
    )

    if response_type == "json":
        return result
    elif response_type == "text":
        return result["text"].strip()
    elif response_type == "tsv":
        return generate_tsv(result)
    elif response_type == "srt":
        return generate_srt(result)
    elif response_type == "vtt":
        return generate_vtt(result)

    return {"error": "Invalid response_type"}


uvicorn.run(app, host=args.host, port=args.port)