add gpt refine

This commit is contained in:
2024-12-06 17:53:04 +08:00
parent bd2c6b95cf
commit 4784bd53a2
6 changed files with 106 additions and 18 deletions

View File

@@ -1 +1,2 @@
/venv /venv
/.git

1
.gitignore vendored
View File

@@ -1 +1,2 @@
/venv /venv
/.git

View File

@@ -10,6 +10,19 @@ Whisper-FastAPI is a very simple Python FastAPI interface for konele and OpenAI
- **Audio Transcriptions**: The `/v1/audio/transcriptions` endpoint allows users to upload an audio file and receive transcription in response, with an optional `response_type` parameter. The `response_type` can be 'json', 'text', 'tsv', 'srt', and 'vtt'. - **Audio Transcriptions**: The `/v1/audio/transcriptions` endpoint allows users to upload an audio file and receive transcription in response, with an optional `response_type` parameter. The `response_type` can be 'json', 'text', 'tsv', 'srt', and 'vtt'.
- **Simplified Chinese**: The traditional Chinese will be automatically convert to simplified Chinese for konele using `opencc` library. - **Simplified Chinese**: The traditional Chinese will be automatically convert to simplified Chinese for konele using `opencc` library.
## GPT Refine Result
You can choose to use the OpenAI GPT model for post-processing transcription results. You can also provide context to GPT to allow it to modify the text based on your context.
Set the environment variables `OPENAI_BASE_URL=https://api.openai.com/v1` and `OPENAI_API_KEY=your-sk` to enable this feature.
When the client sends a request with `gpt_refine=True`, this feature will be activated. Specifically:
- For `/v1/audio/transcriptions`, submit using `curl <api_url> -F file=@audio.mp4 -F gpt_refine=True`.
- For `/v1/konele/ws` and `/v1/konele/post`, use the URL format `/v1/konele/ws/gpt_refine`.
The default model is `gpt-4o-mini`. You can easily edit the code to change the model or the LLM's prompt to better fit your workflow. It's just a few lines of code. Give it a try, it's very simple!
## Usage ## Usage
### Konele Voice Typing ### Konele Voice Typing
@@ -19,7 +32,7 @@ For konele voice typing, you can use either the websocket endpoint or the POST m
- **Websocket**: Connect to the websocket at `/konele/ws` (or `/v1/konele/ws`) and send audio data. The server will respond with the transcription or translation. - **Websocket**: Connect to the websocket at `/konele/ws` (or `/v1/konele/ws`) and send audio data. The server will respond with the transcription or translation.
- **POST Method**: Send a POST request to `/konele/post` (or `/v1/konele/post`) with the audio data in the body. The server will respond with the transcription or translation. - **POST Method**: Send a POST request to `/konele/post` (or `/v1/konele/post`) with the audio data in the body. The server will respond with the transcription or translation.
You can also use the demo I have created to quickly test the effect at <https://yongyuancv.cn/v1/konele/ws> and <https://yongyuancv.cn/v1/konele/post> You can also use the demo I have created to quickly test the effect at <https://yongyuancv.cn/v1/konele/post>
### OpenAI Whisper Service ### OpenAI Whisper Service

View File

@@ -6,3 +6,4 @@ opencc
prometheus-fastapi-instrumentator prometheus-fastapi-instrumentator
git+https://github.com/heimoshuiyu/faster-whisper@a759f5f48f5ef5b79461a6461966eafe9df088a9 git+https://github.com/heimoshuiyu/faster-whisper@a759f5f48f5ef5b79461a6461966eafe9df088a9
pydub pydub
aiohttp

View File

@@ -1,6 +1,11 @@
aiohappyeyeballs==2.4.4
aiohttp==3.11.10
aiosignal==1.3.1
annotated-types==0.7.0 annotated-types==0.7.0
anyio==4.6.2.post1 anyio==4.7.0
av==13.1.0 async-timeout==5.0.1
attrs==24.2.0
av==14.0.0
certifi==2024.8.30 certifi==2024.8.30
cffi==1.17.1 cffi==1.17.1
charset-normalizer==3.4.0 charset-normalizer==3.4.0
@@ -8,42 +13,46 @@ click==8.1.7
coloredlogs==15.0.1 coloredlogs==15.0.1
ctranslate2==4.5.0 ctranslate2==4.5.0
exceptiongroup==1.2.2 exceptiongroup==1.2.2
fastapi==0.115.5 fastapi==0.115.6
faster-whisper @ git+https://github.com/heimoshuiyu/faster-whisper@a759f5f48f5ef5b79461a6461966eafe9df088a9 faster-whisper @ git+https://github.com/heimoshuiyu/faster-whisper@a759f5f48f5ef5b79461a6461966eafe9df088a9
filelock==3.16.1 filelock==3.16.1
flatbuffers==24.3.25 flatbuffers==24.3.25
frozenlist==1.5.0
fsspec==2024.10.0 fsspec==2024.10.0
h11==0.14.0 h11==0.14.0
httptools==0.6.4 httptools==0.6.4
huggingface-hub==0.26.2 huggingface-hub==0.26.3
humanfriendly==10.0 humanfriendly==10.0
idna==3.10 idna==3.10
mpmath==1.3.0 mpmath==1.3.0
multidict==6.1.0
numpy==2.1.3 numpy==2.1.3
onnxruntime==1.20.1 onnxruntime==1.20.1
OpenCC==1.1.9 OpenCC==1.1.9
packaging==24.2 packaging==24.2
prometheus-fastapi-instrumentator==7.0.0 prometheus-fastapi-instrumentator==7.0.0
prometheus_client==0.21.0 prometheus_client==0.21.1
protobuf==5.28.3 propcache==0.2.1
protobuf==5.29.1
pycparser==2.22 pycparser==2.22
pydantic==2.10.1 pydantic==2.10.3
pydantic_core==2.27.1 pydantic_core==2.27.1
pydub==0.25.1 pydub==0.25.1
python-dotenv==1.0.1 python-dotenv==1.0.1
python-multipart==0.0.17 python-multipart==0.0.19
PyYAML==6.0.2 PyYAML==6.0.2
requests==2.32.3 requests==2.32.3
sniffio==1.3.1 sniffio==1.3.1
sounddevice==0.5.1 sounddevice==0.5.1
starlette==0.41.3 starlette==0.41.3
sympy==1.13.3 sympy==1.13.3
tokenizers==0.20.3 tokenizers==0.21.0
tqdm==4.67.0 tqdm==4.67.1
typing_extensions==4.12.2 typing_extensions==4.12.2
urllib3==2.2.3 urllib3==2.2.3
uvicorn==0.32.1 uvicorn==0.32.1
uvloop==0.21.0 uvloop==0.21.0
watchfiles==0.24.0 watchfiles==1.0.0
websockets==14.1 websockets==14.1
whisper-ctranslate2==0.4.8 whisper-ctranslate2==0.5.0
yarl==1.18.3

View File

@@ -1,8 +1,10 @@
import aiohttp
import os
import sys import sys
import dataclasses import dataclasses
import faster_whisper import faster_whisper
import json import json
from fastapi.responses import StreamingResponse from fastapi.responses import PlainTextResponse, StreamingResponse
import wave import wave
import pydub import pydub
import io import io
@@ -28,9 +30,12 @@ from prometheus_fastapi_instrumentator import Instrumentator
# redirect print to stderr # redirect print to stderr
_print = print _print = print
def print(*args, **kwargs): def print(*args, **kwargs):
_print(*args, file=sys.stderr, **kwargs) _print(*args, file=sys.stderr, **kwargs)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0", type=str) parser.add_argument("--host", default="0.0.0.0", type=str)
parser.add_argument("--port", default=5000, type=int) parser.add_argument("--port", default=5000, type=int)
@@ -65,6 +70,45 @@ app.add_middleware(
) )
async def gpt_refine_text(
    ge: Generator[Segment, None, None], info: TranscriptionInfo, context: str
) -> str:
    """Refine a transcription using an OpenAI chat model.

    Consumes the segment generator, flattens it to plain text via
    ``build_json_result``, and asks the chat-completions endpoint
    (model ``gpt-4o-mini``) to polish it, optionally guided by
    *context* (the caller's initial prompt).

    Environment:
        OPENAI_BASE_URL: API root; defaults to ``https://api.openai.com/v1``.
        OPENAI_API_KEY: required — a missing key raises ``KeyError``.

    Returns:
        The refined text, or ``""`` when the transcription is empty.

    Raises:
        aiohttp.ClientResponseError: if the API replies with a non-2xx status.
    """
    text = build_json_result(ge, info).text.strip()
    if not text:
        # Nothing to refine — skip the API round-trip entirely.
        return ""
    async with aiohttp.ClientSession() as session:
        async with session.post(
            os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
            + "/chat/completions",
            json={
                "model": "gpt-4o-mini",
                # Low temperature: we want faithful cleanup, not rewriting.
                "temperature": 0.1,
                "stream": False,
                "messages": [
                    {
                        "role": "system",
                        "content": """
You are an audio transcription text refiner.
You may refer to the context to refine the transcription text.
""".strip(),
                    },
                    {
                        "role": "user",
                        "content": f"""
context: {context}
---
transcription: {text}
""".strip(),
                    },
                ],
            },
            headers={
                "Authorization": f'Bearer {os.environ["OPENAI_API_KEY"]}',
            },
        ) as response:
            # Fail loudly on HTTP errors instead of raising a confusing
            # KeyError('choices') on the error-body JSON below.
            response.raise_for_status()
            return (await response.json())["choices"][0]["message"]["content"]
def stream_writer(generator: Generator[Segment, Any, None]): def stream_writer(generator: Generator[Segment, Any, None]):
for segment in generator: for segment in generator:
yield "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n" yield "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n"
@@ -169,8 +213,12 @@ async def konele_status(
@app.websocket("/k6nele/ws") @app.websocket("/k6nele/ws")
@app.websocket("/konele/ws") @app.websocket("/konele/ws")
@app.websocket("/konele/ws/gpt_refine")
@app.websocket("/k6nele/ws/gpt_refine")
@app.websocket("/v1/k6nele/ws") @app.websocket("/v1/k6nele/ws")
@app.websocket("/v1/konele/ws") @app.websocket("/v1/konele/ws")
@app.websocket("/v1/konele/ws/gpt_refine")
@app.websocket("/v1/k6nele/ws/gpt_refine")
async def konele_ws( async def konele_ws(
websocket: WebSocket, websocket: WebSocket,
task: Literal["transcribe", "translate"] = "transcribe", task: Literal["transcribe", "translate"] = "transcribe",
@@ -218,13 +266,17 @@ async def konele_ws(
language=None if lang == "und" else lang, language=None if lang == "und" else lang,
initial_prompt=initial_prompt, initial_prompt=initial_prompt,
) )
result = build_json_result(generator, info)
if websocket.url.path.endswith("gpt_refine"):
result = await gpt_refine_text(generator, info, initial_prompt)
else:
result = build_json_result(generator, info).text
await websocket.send_json( await websocket.send_json(
{ {
"status": 0, "status": 0,
"segment": 0, "segment": 0,
"result": {"hypotheses": [{"transcript": result.text}], "final": True}, "result": {"hypotheses": [{"transcript": result}], "final": True},
"id": md5, "id": md5,
} }
) )
@@ -233,8 +285,12 @@ async def konele_ws(
@app.post("/k6nele/post") @app.post("/k6nele/post")
@app.post("/konele/post") @app.post("/konele/post")
@app.post("/k6nele/post/gpt_refine")
@app.post("/konele/post/gpt_refine")
@app.post("/v1/k6nele/post") @app.post("/v1/k6nele/post")
@app.post("/v1/konele/post") @app.post("/v1/konele/post")
@app.post("/v1/k6nele/post/gpt_refine")
@app.post("/v1/konele/post/gpt_refine")
async def translateapi( async def translateapi(
request: Request, request: Request,
task: Literal["transcribe", "translate"] = "transcribe", task: Literal["transcribe", "translate"] = "transcribe",
@@ -279,11 +335,15 @@ async def translateapi(
language=None if lang == "und" else lang, language=None if lang == "und" else lang,
initial_prompt=initial_prompt, initial_prompt=initial_prompt,
) )
result = build_json_result(generator, info)
if request.url.path.endswith("gpt_refine"):
result = await gpt_refine_text(generator, info, initial_prompt)
else:
result = build_json_result(generator, info).text
return { return {
"status": 0, "status": 0,
"hypotheses": [{"utterance": result.text}], "hypotheses": [{"utterance": result}],
"id": md5, "id": md5,
} }
@@ -299,6 +359,7 @@ async def transcription(
language: str = Form("und"), language: str = Form("und"),
vad_filter: bool = Form(False), vad_filter: bool = Form(False),
repetition_penalty: float = Form(1.0), repetition_penalty: float = Form(1.0),
gpt_refine: bool = Form(False),
): ):
"""Transcription endpoint """Transcription endpoint
@@ -332,6 +393,8 @@ async def transcription(
elif response_format == "json": elif response_format == "json":
return build_json_result(generator, info) return build_json_result(generator, info)
elif response_format == "text": elif response_format == "text":
if gpt_refine:
return PlainTextResponse(await gpt_refine_text(generator, info, prompt))
return StreamingResponse(text_writer(generator), media_type="text/plain") return StreamingResponse(text_writer(generator), media_type="text/plain")
elif response_format == "tsv": elif response_format == "tsv":
return StreamingResponse(tsv_writer(generator), media_type="text/plain") return StreamingResponse(tsv_writer(generator), media_type="text/plain")