From 046f4017d09e31281915f06bac4369916a674e71 Mon Sep 17 00:00:00 2001
From: heimoshuiyu
Date: Tue, 17 Oct 2023 20:52:18 +0800
Subject: [PATCH] init

---
 .gitignore               |   1 +
 README.md                |  52 +++++++++
 requirements.txt         |   5 +
 requirements_version.txt |  39 +++++++
 whisper_fastapi.py       | 247 +++++++++++++++++++++++++++++++++++++++
 5 files changed, 344 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 README.md
 create mode 100644 requirements.txt
 create mode 100644 requirements_version.txt
 create mode 100644 whisper_fastapi.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f9606a3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+/venv
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..66c52bc
--- /dev/null
+++ b/README.md
@@ -0,0 +1,52 @@
+# Whisper-FastAPI
+
+Whisper-FastAPI is a very simple Python FastAPI interface for konele and OpenAI services. It is based on the `faster-whisper` project and provides a konele-compatible API through which translations and transcriptions can be obtained over WebSocket or POST requests.
+
+## Features
+
+- **Translation and Transcription**: The application provides a konele-compatible API through which translations and transcriptions can be obtained over WebSocket or POST requests.
+- **Language Support**: If the target language is English, the application translates any source language to English.
+- **WebSocket and POST Method Support**: The project supports a WebSocket endpoint (`/konele/ws`) and a POST endpoint (`/konele/post`).
+- **Audio Transcriptions**: The `/v1/audio/transcriptions` endpoint lets users upload an audio file and receive the transcription in response. The optional `response_type` parameter can be 'json', 'text', 'tsv', 'srt', or 'vtt'.
+- **Simplified Chinese**: Traditional Chinese output is automatically converted to Simplified Chinese for konele using the `opencc` library.
+
+## Usage
+
+### Konele Voice Typing
+
+For konele voice typing, you can use either the WebSocket endpoint or the POST endpoint.
+
+- **WebSocket**: Connect to the WebSocket at `/konele/ws` and send audio data. The server responds with the transcription or translation.
+- **POST Method**: Send a POST request to `/konele/post` with the audio data in the body. The server responds with the transcription or translation.
+
+You can also use the demo I have created to quickly test the effect at and
+
+### OpenAI Whisper Service
+
+To use the service that matches the structure of the OpenAI Whisper service, send a POST request to `/v1/audio/transcriptions` with an audio file. The server responds with the transcription in the format specified by the `response_type` parameter.
+
+You can also use the demo I have created to quickly test the effect at
+
+My demo runs the large-v2 model on an RTX 3060.
+
+## Getting Started
+
+To run the application, you need to have Python installed on your machine. You can then clone the repository and install the required dependencies:
+
+```bash
+git clone https://github.com/heimoshuiyu/whisper-fastapi.git
+cd whisper-fastapi
+pip install -r requirements.txt
+```
+
+You can then run the application using the following command (the model will be downloaded from Hugging Face if it is not found in the cache directory):
+
+```bash
+python whisper_fastapi.py --host 0.0.0.0 --port 5000 --model large-v2
+```
+
+This will start the application on `http://0.0.0.0:5000`.
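+
+To sanity-check the OpenAI-style endpoint, you can call it from Python. This is a minimal client sketch, assuming the server is running locally on port 5000, the `requests` package is installed, and `audio.wav` is a placeholder for your own audio file:
+
+```python
+import requests
+
+# Upload an audio file and ask for plain-text output.
+resp = requests.post(
+    "http://127.0.0.1:5000/v1/audio/transcriptions",
+    files={"file": open("audio.wav", "rb")},
+    data={"response_type": "text"},
+)
+print(resp.text)
+```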
+
+## Limitation
+
+Because inference runs synchronously, this API can only handle one request at a time.
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..77e1bec
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+python-multipart
+fastapi
+uvicorn
+whisper_ctranslate2
+opencc
diff --git a/requirements_version.txt b/requirements_version.txt
new file mode 100644
index 0000000..3174828
--- /dev/null
+++ b/requirements_version.txt
@@ -0,0 +1,39 @@
+annotated-types==0.6.0
+anyio==3.7.1
+av==10.0.0
+certifi==2023.7.22
+cffi==1.16.0
+charset-normalizer==3.3.0
+click==8.1.7
+coloredlogs==15.0.1
+ctranslate2==3.20.0
+fastapi==0.103.2
+faster-whisper==0.9.0
+filelock==3.12.4
+flatbuffers==23.5.26
+fsspec==2023.9.2
+h11==0.14.0
+huggingface-hub==0.17.3
+humanfriendly==10.0
+idna==3.4
+mpmath==1.3.0
+numpy==1.26.1
+onnxruntime==1.16.1
+OpenCC==1.1.7
+packaging==23.2
+protobuf==4.24.4
+pycparser==2.21
+pydantic==2.4.2
+pydantic_core==2.10.1
+PyYAML==6.0.1
+requests==2.31.0
+sniffio==1.3.0
+sounddevice==0.4.6
+starlette==0.27.0
+sympy==1.12
+tokenizers==0.14.1
+tqdm==4.66.1
+typing_extensions==4.8.0
+urllib3==2.0.6
+uvicorn==0.23.2
+whisper-ctranslate2==0.3.2
diff --git a/whisper_fastapi.py b/whisper_fastapi.py
new file mode 100644
index 0000000..16f1355
--- /dev/null
+++ b/whisper_fastapi.py
@@ -0,0 +1,247 @@
+import wave
+import io
+import hashlib
+import argparse
+import uvicorn
+from typing import Any
+from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket
+from fastapi.middleware.cors import CORSMiddleware
+from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions
+from src.whisper_ctranslate2.writers import format_timestamp
+import opencc
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--host", default="0.0.0.0", type=str)
+parser.add_argument("--port", default=5000, type=int)
+parser.add_argument("--model", default="large-v2", type=str)
+parser.add_argument("--cache_dir", default=None, type=str)
+args = parser.parse_args()
+app = FastAPI()
+ccc = opencc.OpenCC("t2s.json")  # Traditional-to-Simplified Chinese converter
+
+print("Loading model...")
+transcriber = Transcribe(
+    model_path=args.model,
+    device="auto",
+    device_index=0,
+    compute_type="default",
+    threads=1,
+    cache_directory=args.cache_dir,
+    local_files_only=False,
+)
+print("Model loaded!")
+
+
+# Allow all CORS origins
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+def generate_tsv(result: dict[str, list[Any]]):
+    tsv = "start\tend\ttext\n"
+    for segment in result["segments"]:
+        start_time = str(round(1000 * segment["start"]))
+        end_time = str(round(1000 * segment["end"]))
+        text = segment["text"]
+        tsv += f"{start_time}\t{end_time}\t{text}\n"
+    return tsv
+
+
+def generate_srt(result: dict[str, list[Any]]):
+    srt = ""
+    for i, segment in enumerate(result["segments"], start=1):
+        start_time = format_timestamp(segment["start"])
+        end_time = format_timestamp(segment["end"])
+        text = segment["text"]
+        srt += f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
+    return srt
+
+
+def generate_vtt(result: dict[str, list[Any]]):
+    vtt = "WEBVTT\n\n"
+    for segment in result["segments"]:
+        start_time = format_timestamp(segment["start"])
+        end_time = format_timestamp(segment["end"])
+        text = segment["text"]
+        vtt += f"{start_time} --> {end_time}\n{text}\n\n"
+    return vtt
+
+
+def get_options(*, initial_prompt=""):
+    options = TranscriptionOptions(
+        beam_size=5,
+        best_of=5,
+        patience=1.0,
+        length_penalty=1.0,
+        log_prob_threshold=-1.0,
+        no_speech_threshold=0.6,
+        compression_ratio_threshold=2.4,
+        condition_on_previous_text=True,
+        # Fallback schedule: retry decoding at higher temperatures on failure
+        temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
+        suppress_tokens=[-1],
+        word_timestamps=True,
+        print_colors=False,
+        prepend_punctuations="\"'“¿([{-",
+        append_punctuations="\"'.。,,!!??::”)]}、",
+        vad_filter=False,
+        vad_threshold=None,
+        vad_min_speech_duration_ms=None,
+        vad_max_speech_duration_s=None,
+        vad_min_silence_duration_ms=None,
+        initial_prompt=initial_prompt,
+        repetition_penalty=1.0,
+        no_repeat_ngram_size=0,
+        prompt_reset_on_temperature=False,
+        suppress_blank=False,
+    )
+    return options
+
+
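+# konele WebSocket clients stream raw PCM audio (decoded below as 16-bit mono
+# at 16 kHz) and terminate the stream with the ASCII bytes b"EOS". A minimal
+# client sketch under those assumptions, where "hello.raw" is a hypothetical
+# raw-PCM file and the `websockets` package is assumed to be installed:
+#
+#   import asyncio
+#   import websockets
+#
+#   async def main():
+#       async with websockets.connect("ws://127.0.0.1:5000/konele/ws") as ws:
+#           await ws.send(open("hello.raw", "rb").read())
+#           await ws.send(b"EOS")
+#           print(await ws.recv())  # JSON containing the transcript
+#
+#   asyncio.run(main())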
+@app.websocket("/konele/ws")
+async def konele_ws(
+    websocket: WebSocket,
+    lang: str = "und",
+):
+    await websocket.accept()
+    print("WebSocket client connected, lang is", lang)
+    data = b""
+    while True:
+        try:
+            data += await websocket.receive_bytes()
+            print("Received data:", len(data), data[-10:])
+            if data[-3:] == b"EOS":
+                print("End of speech")
+                break
+        except Exception:
+            # Client disconnected without sending the EOS marker
+            break
+
+    # Strip the trailing b"EOS" marker so it is not decoded as audio samples
+    if data[-3:] == b"EOS":
+        data = data[:-3]
+
+    md5 = hashlib.md5(data).hexdigest()
+
+    # Wrap the raw PCM bytes in an in-memory WAV container for the transcriber
+    file_obj = io.BytesIO()
+
+    buffer = wave.open(file_obj, "wb")
+    buffer.setnchannels(1)
+    buffer.setsampwidth(2)
+    buffer.setframerate(16000)
+    buffer.writeframes(data)
+    buffer.close()
+    file_obj.seek(0)
+
+    options = get_options()
+
+    result = transcriber.inference(
+        audio=file_obj,
+        # Enter translate mode if the target language is English
+        task="translate" if lang == "en-US" else "transcribe",
+        language=None,  # type: ignore
+        verbose=False,
+        live=False,
+        options=options,
+    )
+    text = result.get("text", "")
+    text = ccc.convert(text)
+    print("result", text)
+
+    await websocket.send_json(
+        {
+            "status": 0,
+            "segment": 0,
+            "result": {"hypotheses": [{"transcript": text}], "final": True},
+            "id": md5,
+        }
+    )
+    await websocket.close()
+
+
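+# The konele POST endpoint reads audio parameters from the Content-Type
+# header: every comma-separated item containing "=" is parsed as a key=value
+# pair, and items without "=" are ignored. A hypothetical curl invocation
+# (assuming "hello.raw" holds raw 16-bit PCM):
+#
+#   curl -X POST "http://127.0.0.1:5000/konele/post?lang=en-US" \
+#        -H "Content-Type: audio/x-raw, rate=16000, channels=1" \
+#        --data-binary @hello.raw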
+@app.post("/konele/post")
+async def translateapi(
+    request: Request,
+    lang: str = "und",
+):
+    content_type = request.headers.get("Content-Type", "")
+    print("Receiving request body, Content-Type:", content_type)
+    # Parse key=value pairs (e.g. rate, channels) out of the Content-Type header
+    fields = [i.strip() for i in content_type.split(",") if "=" in i]
+    info = {k: v for k, v in (i.split("=", 1) for i in fields)}
+    print(info)
+
+    channels = int(info.get("channels", "1"))
+    rate = int(info.get("rate", "16000"))
+
+    body = await request.body()
+    md5 = hashlib.md5(body).hexdigest()
+
+    # Wrap the raw PCM bytes in an in-memory WAV container for the transcriber
+    file_obj = io.BytesIO()
+
+    buffer = wave.open(file_obj, "wb")
+    buffer.setnchannels(channels)
+    buffer.setsampwidth(2)
+    buffer.setframerate(rate)
+    buffer.writeframes(body)
+    buffer.close()
+    file_obj.seek(0)
+
+    options = get_options()
+
+    result = transcriber.inference(
+        audio=file_obj,
+        # Enter translate mode if the target language is English
+        task="translate" if lang == "en-US" else "transcribe",
+        language=None,  # type: ignore
+        verbose=False,
+        live=False,
+        options=options,
+    )
+    text = result.get("text", "")
+    text = ccc.convert(text)
+    print("result", text)
+
+    return {
+        "status": 0,
+        "hypotheses": [{"utterance": text}],
+        "id": md5,
+    }
+
+
+@app.post("/v1/audio/transcriptions")
+async def transcription(
+    file: UploadFile = File(...),
+    prompt: str = Form(""),
+    response_type: str = Form("json"),
+):
+    """Transcription endpoint
+
+    Users upload an audio file as multipart/form-data and receive the
+    transcription in the format selected by `response_type`.
+    """
+    options = get_options(initial_prompt=prompt)
+
+    result: Any = transcriber.inference(
+        audio=io.BytesIO(file.file.read()),
+        task="transcribe",
+        language=None,  # type: ignore
+        verbose=False,
+        live=False,
+        options=options,
+    )
+
+    if response_type == "json":
+        return result
+    elif response_type == "text":
+        return result["text"].strip()
+    elif response_type == "tsv":
+        return generate_tsv(result)
+    elif response_type == "srt":
+        return generate_srt(result)
+    elif response_type == "vtt":
+        return generate_vtt(result)
+
+    return {"error": "Invalid response_type"}
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host=args.host, port=args.port)
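+
+# Example: request an SRT transcript from the OpenAI-style endpoint with curl
+# (hypothetical file name; server assumed to be listening locally on port 5000):
+#
+#   curl -X POST "http://127.0.0.1:5000/v1/audio/transcriptions" \
+#        -F "file=@speech.mp3" \
+#        -F "response_type=srt"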