diff --git a/.dockerignore b/.dockerignore
index f9606a3..5bb1b25 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1 +1,2 @@
 /venv
+/.git
diff --git a/.gitignore b/.gitignore
index f9606a3..5bb1b25 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
 /venv
+/.git
diff --git a/README.md b/README.md
index 0c9543c..7732114 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,19 @@ Whisper-FastAPI is a very simple Python FastAPI interface for konele and OpenAI
 - **Audio Transcriptions**: The `/v1/audio/transcriptions` endpoint allows users to upload an audio file and receive a transcription in response, with an optional `response_format` parameter. The `response_format` can be 'json', 'text', 'tsv', 'srt', or 'vtt'.
 - **Simplified Chinese**: Traditional Chinese is automatically converted to Simplified Chinese for konele using the `opencc` library.
+
+## GPT Refine Result
+
+You can optionally use an OpenAI GPT model to post-process transcription results. You can also provide context so that GPT can adjust the text accordingly.
+
+Set the environment variables `OPENAI_BASE_URL=https://api.openai.com/v1` and `OPENAI_API_KEY=your-sk` to enable this feature.
+
+This feature is activated when the client sends a request with `gpt_refine=True`. Specifically:
+
+- For `/v1/audio/transcriptions`, submit with `curl -F file=@audio.mp4 -F gpt_refine=True`.
+- For `/v1/konele/ws` and `/v1/konele/post`, use the URL format `/v1/konele/ws/gpt_refine`.
+
+The default model is `gpt-4o-mini`. You can easily edit the code to change the model or the LLM prompt to better fit your workflow. It's just a few lines of code. Give it a try; it's very simple!
+
 ## Usage
 
 ### Konele Voice Typing
 
@@ -19,7 +32,7 @@ For konele voice typing, you can use either the websocket endpoint or the POST m
 - **Websocket**: Connect to the websocket at `/konele/ws` (or `/v1/konele/ws`) and send audio data. The server will respond with the transcription or translation.
 - **POST Method**: Send a POST request to `/konele/post` (or `/v1/konele/post`) with the audio data in the body. The server will respond with the transcription or translation.
 
-You can also use the demo I have created to quickly test the effect at and
+You can also use the demo I have created to quickly test the effect at
 
 ### OpenAI Whisper Service
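Note (reviewer sketch, not part of the diff): the README section above can be smoke-tested end to end. Below is a minimal sketch, assuming a server running locally on the default port 5000 and an `audio.mp4` file on hand; the `gpt_refine` form field and its plain-text-only behavior come from the `transcription` endpoint changes further down.

```python
# Hypothetical smoke test for the new gpt_refine flag (assumes a local server on port 5000).
import requests

with open("audio.mp4", "rb") as f:
    resp = requests.post(
        "http://localhost:5000/v1/audio/transcriptions",
        files={"file": f},
        # gpt_refine is only consulted for the plain-text response format (see the diff below)
        data={"response_format": "text", "gpt_refine": "True"},
    )
resp.raise_for_status()
print(resp.text)  # the GPT-refined transcript
```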
diff --git a/requirements.txt b/requirements.txt
index c441ebb..fd0a698 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ opencc
 prometheus-fastapi-instrumentator
 git+https://github.com/heimoshuiyu/faster-whisper@a759f5f48f5ef5b79461a6461966eafe9df088a9
 pydub
+aiohttp
diff --git a/requirements_version.txt b/requirements_version.txt
index ea5d485..823f819 100644
--- a/requirements_version.txt
+++ b/requirements_version.txt
@@ -1,6 +1,11 @@
+aiohappyeyeballs==2.4.4
+aiohttp==3.11.10
+aiosignal==1.3.1
 annotated-types==0.7.0
-anyio==4.6.2.post1
-av==13.1.0
+anyio==4.7.0
+async-timeout==5.0.1
+attrs==24.2.0
+av==14.0.0
 certifi==2024.8.30
 cffi==1.17.1
 charset-normalizer==3.4.0
@@ -8,42 +13,46 @@ click==8.1.7
 coloredlogs==15.0.1
 ctranslate2==4.5.0
 exceptiongroup==1.2.2
-fastapi==0.115.5
+fastapi==0.115.6
 faster-whisper @ git+https://github.com/heimoshuiyu/faster-whisper@a759f5f48f5ef5b79461a6461966eafe9df088a9
 filelock==3.16.1
 flatbuffers==24.3.25
+frozenlist==1.5.0
 fsspec==2024.10.0
 h11==0.14.0
 httptools==0.6.4
-huggingface-hub==0.26.2
+huggingface-hub==0.26.3
 humanfriendly==10.0
 idna==3.10
 mpmath==1.3.0
+multidict==6.1.0
 numpy==2.1.3
 onnxruntime==1.20.1
 OpenCC==1.1.9
 packaging==24.2
 prometheus-fastapi-instrumentator==7.0.0
-prometheus_client==0.21.0
-protobuf==5.28.3
+prometheus_client==0.21.1
+propcache==0.2.1
+protobuf==5.29.1
 pycparser==2.22
-pydantic==2.10.1
+pydantic==2.10.3
 pydantic_core==2.27.1
 pydub==0.25.1
 python-dotenv==1.0.1
-python-multipart==0.0.17
+python-multipart==0.0.19
 PyYAML==6.0.2
 requests==2.32.3
 sniffio==1.3.1
 sounddevice==0.5.1
 starlette==0.41.3
 sympy==1.13.3
-tokenizers==0.20.3
-tqdm==4.67.0
+tokenizers==0.21.0
+tqdm==4.67.1
 typing_extensions==4.12.2
 urllib3==2.2.3
 uvicorn==0.32.1
 uvloop==0.21.0
-watchfiles==0.24.0
+watchfiles==1.0.0
 websockets==14.1
-whisper-ctranslate2==0.4.8
+whisper-ctranslate2==0.5.0
+yarl==1.18.3
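The application changes below wire the refine step into the konele routes by adding `/gpt_refine` path variants. As a rough illustration, here is a hypothetical client for the new POST variant (a sketch, not part of the diff; the localhost URL and WAV input are assumptions, while the response shape matches the `translateapi` return value in the diff):

```python
# Hypothetical konele POST client (assumes a local server and an audio file it can decode).
import requests

with open("audio.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:5000/v1/konele/post/gpt_refine",
        data=f.read(),  # raw audio bytes in the request body
    )
resp.raise_for_status()
# Response shape per the diff: {"status": 0, "hypotheses": [{"utterance": ...}], "id": ...}
print(resp.json()["hypotheses"][0]["utterance"])
```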
+ """.strip(), + }, + { + "role": "user", + "content": f""" +context: {context} +--- +transcription: {text} + """.strip(), + }, + ], + }, + headers={ + "Authorization": f'Bearer {os.environ["OPENAI_API_KEY"]}', + }, + ) as response: + return (await response.json())["choices"][0]["message"]["content"] + + def stream_writer(generator: Generator[Segment, Any, None]): for segment in generator: yield "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n" @@ -169,8 +213,12 @@ async def konele_status( @app.websocket("/k6nele/ws") @app.websocket("/konele/ws") +@app.websocket("/konele/ws/gpt_refine") +@app.websocket("/k6nele/ws/gpt_refine") @app.websocket("/v1/k6nele/ws") @app.websocket("/v1/konele/ws") +@app.websocket("/v1/konele/ws/gpt_refine") +@app.websocket("/v1/k6nele/ws/gpt_refine") async def konele_ws( websocket: WebSocket, task: Literal["transcribe", "translate"] = "transcribe", @@ -218,13 +266,17 @@ async def konele_ws( language=None if lang == "und" else lang, initial_prompt=initial_prompt, ) - result = build_json_result(generator, info) + + if websocket.url.path.endswith("gpt_refine"): + result = await gpt_refine_text(generator, info, initial_prompt) + else: + result = build_json_result(generator, info).text await websocket.send_json( { "status": 0, "segment": 0, - "result": {"hypotheses": [{"transcript": result.text}], "final": True}, + "result": {"hypotheses": [{"transcript": result}], "final": True}, "id": md5, } ) @@ -233,8 +285,12 @@ async def konele_ws( @app.post("/k6nele/post") @app.post("/konele/post") +@app.post("/k6nele/post/gpt_refine") +@app.post("/konele/post/gpt_refine") @app.post("/v1/k6nele/post") @app.post("/v1/konele/post") +@app.post("/v1/k6nele/post/gpt_refine") +@app.post("/v1/konele/post/gpt_refine") async def translateapi( request: Request, task: Literal["transcribe", "translate"] = "transcribe", @@ -279,11 +335,15 @@ async def translateapi( language=None if lang == "und" else lang, initial_prompt=initial_prompt, ) - result = build_json_result(generator, info) + + if request.url.path.endswith("gpt_refine"): + result = await gpt_refine_text(generator, info, initial_prompt) + else: + result = build_json_result(generator, info).text return { "status": 0, - "hypotheses": [{"utterance": result.text}], + "hypotheses": [{"utterance": result}], "id": md5, } @@ -299,6 +359,7 @@ async def transcription( language: str = Form("und"), vad_filter: bool = Form(False), repetition_penalty: float = Form(1.0), + gpt_refine: bool = Form(False), ): """Transcription endpoint @@ -332,6 +393,8 @@ async def transcription( elif response_format == "json": return build_json_result(generator, info) elif response_format == "text": + if gpt_refine: + return PlainTextResponse(await gpt_refine_text(generator, info, prompt)) return StreamingResponse(text_writer(generator), media_type="text/plain") elif response_format == "tsv": return StreamingResponse(tsv_writer(generator), media_type="text/plain")