Compare commits

7 Commits

SHA1 Message Date
460ad77a2f fix gpt refine prompt 2024-12-06 23:34:14 +08:00
890da4f4ac add env OPENAI_LLM_MODEL 2024-12-06 18:01:22 +08:00
4784bd53a2 add gpt refine 2024-12-06 17:53:04 +08:00
bd2c6b95cf update faster-whisper 2024-11-28 18:52:00 +08:00
0e46bd91d4 format code 2024-11-21 22:45:02 +08:00
99272b230f Upgrade Dependency 2024-11-21 22:44:49 +08:00
3c01a76405 Convert Traditional Chinese to Simplified Chinese 2024-11-21 22:44:27 +08:00
6 changed files with 137 additions and 38 deletions

View File

@@ -1 +1,2 @@
/venv
/.git

.gitignore vendored
View File

@@ -1 +1,2 @@
/venv
/.git

View File

@@ -10,6 +10,21 @@ Whisper-FastAPI is a very simple Python FastAPI interface for konele and OpenAI
- **Audio Transcriptions**: The `/v1/audio/transcriptions` endpoint lets users upload an audio file and receive a transcription in response, with an optional `response_format` parameter. The `response_format` can be 'json', 'text', 'tsv', 'srt', or 'vtt' (see the example after this list).
- **Simplified Chinese**: Traditional Chinese is automatically converted to Simplified Chinese for konele results using the `opencc` library.
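For example, a minimal client call might look like this (a sketch using the `requests` library; the host, port, and file name are placeholders, with the port matching the server's `--port` default of 5000):

```python
import requests

# Upload an audio file and request an SRT transcript.
# "http://localhost:5000" assumes the server's default --port; adjust to your deployment.
with open("audio.mp4", "rb") as f:
    response = requests.post(
        "http://localhost:5000/v1/audio/transcriptions",
        files={"file": f},
        data={"response_format": "srt"},
    )
print(response.text)
```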
## GPT Refine Result
You can optionally use an OpenAI GPT model to post-process transcription results. You can also provide context so that GPT can correct the text based on it.
Set the environment variables `OPENAI_BASE_URL=https://api.openai.com/v1` and `OPENAI_API_KEY=your-sk` to enable this feature.
When the client sends a request with `gpt_refine=True`, this feature will be activated. Specifically:
- For `/v1/audio/transcriptions`, submit with `curl <api_url> -F file=@audio.mp4 -F gpt_refine=True`.
- For `/v1/konele/ws` and `/v1/konele/post`, use the URL format `/v1/konele/ws/gpt_refine`.
The default model is `gpt-4o-mini`; you can override it with the environment variable `OPENAI_LLM_MODEL`.
You can easily edit the LLM prompt in the code to better fit your workflow. It's just a few lines of code, so give it a try.
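For example, a minimal sketch of a refined request (it assumes the server runs locally on the default port and that the environment variables above were set before the server started):

```python
import requests

# gpt_refine applies to the "text" response format; the server forwards
# the finished transcript to the configured GPT model before responding.
with open("audio.mp4", "rb") as f:
    response = requests.post(
        "http://localhost:5000/v1/audio/transcriptions",
        files={"file": f},
        data={"gpt_refine": "True", "response_format": "text"},
    )
print(response.text)  # the GPT-refined transcript
```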
## Usage
### Konele Voice Typing
@@ -19,7 +34,7 @@ For konele voice typing, you can use either the websocket endpoint or the POST m
- **Websocket**: Connect to the websocket at `/konele/ws` (or `/v1/konele/ws`) and send audio data. The server will respond with the transcription or translation.
- **POST Method**: Send a POST request to `/konele/post` (or `/v1/konele/post`) with the audio data in the body. The server will respond with the transcription or translation.
You can also use the demo I have created to quickly test the effect at <https://yongyuancv.cn/v1/konele/ws> and <https://yongyuancv.cn/v1/konele/post>
You can also use the demo I have created to quickly test the effect at <https://yongyuancv.cn/v1/konele/post>
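For a quick local test of the POST endpoint, something like the following should work (a sketch; the host, audio file, and query parameters are placeholders, and any konele-specific headers are omitted):

```python
import requests

# Send raw audio bytes in the request body. The server replies with
# konele-style JSON: {"status": 0, "hypotheses": [{"utterance": ...}], "id": ...}
with open("audio.wav", "rb") as f:
    response = requests.post(
        "http://localhost:5000/v1/konele/post",
        params={"task": "transcribe", "lang": "und"},
        data=f.read(),
    )
print(response.json()["hypotheses"][0]["utterance"])
```

Appending `/gpt_refine` to the path (e.g. `/v1/konele/post/gpt_refine`) returns the GPT-refined text instead.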
### OpenAI Whisper Service

View File

@@ -4,5 +4,6 @@ uvicorn[standard]
whisper_ctranslate2
opencc
prometheus-fastapi-instrumentator
git+https://github.com/SYSTRAN/faster-whisper@be9fb36ed356b9e299b125de6ee91862e0ac9038
git+https://github.com/heimoshuiyu/faster-whisper@a759f5f48f5ef5b79461a6461966eafe9df088a9
pydub
aiohttp

View File

@@ -1,6 +1,11 @@
aiohappyeyeballs==2.4.4
aiohttp==3.11.10
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.6.2.post1
av==13.1.0
anyio==4.7.0
async-timeout==5.0.1
attrs==24.2.0
av==14.0.0
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.4.0
@@ -8,42 +13,46 @@ click==8.1.7
coloredlogs==15.0.1
ctranslate2==4.5.0
exceptiongroup==1.2.2
fastapi==0.115.5
faster-whisper @ git+https://github.com/SYSTRAN/faster-whisper@be9fb36ed356b9e299b125de6ee91862e0ac9038
fastapi==0.115.6
faster-whisper @ git+https://github.com/heimoshuiyu/faster-whisper@a759f5f48f5ef5b79461a6461966eafe9df088a9
filelock==3.16.1
flatbuffers==24.3.25
frozenlist==1.5.0
fsspec==2024.10.0
h11==0.14.0
httptools==0.6.4
huggingface-hub==0.26.2
huggingface-hub==0.26.3
humanfriendly==10.0
idna==3.10
mpmath==1.3.0
multidict==6.1.0
numpy==2.1.3
onnxruntime==1.20.0
onnxruntime==1.20.1
OpenCC==1.1.9
packaging==24.2
prometheus-fastapi-instrumentator==7.0.0
prometheus_client==0.21.0
protobuf==5.28.3
prometheus_client==0.21.1
propcache==0.2.1
protobuf==5.29.1
pycparser==2.22
pydantic==2.9.2
pydantic_core==2.23.4
pydantic==2.10.3
pydantic_core==2.27.1
pydub==0.25.1
python-dotenv==1.0.1
python-multipart==0.0.17
python-multipart==0.0.19
PyYAML==6.0.2
requests==2.32.3
sniffio==1.3.1
sounddevice==0.5.1
starlette==0.41.2
starlette==0.41.3
sympy==1.13.3
tokenizers==0.20.3
tqdm==4.67.0
tokenizers==0.21.0
tqdm==4.67.1
typing_extensions==4.12.2
urllib3==2.2.3
uvicorn==0.32.0
uvicorn==0.32.1
uvloop==0.21.0
watchfiles==0.24.0
watchfiles==1.0.0
websockets==14.1
whisper-ctranslate2==0.4.7
whisper-ctranslate2==0.5.0
yarl==1.18.3

View File

@@ -1,8 +1,10 @@
import aiohttp
import os
import sys
import dataclasses
import faster_whisper
import tqdm
import json
from fastapi.responses import StreamingResponse
from fastapi.responses import PlainTextResponse, StreamingResponse
import wave
import pydub
import io
@@ -21,12 +23,19 @@ from fastapi import (
WebSocket,
)
from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
from src.whisper_ctranslate2.writers import format_timestamp
from faster_whisper.transcribe import Segment, TranscriptionInfo
import opencc
from prometheus_fastapi_instrumentator import Instrumentator
# redirect print to stderr
_print = print
def print(*args, **kwargs):
_print(*args, file=sys.stderr, **kwargs)
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0", type=str)
parser.add_argument("--port", default=5000, type=int)
@@ -61,6 +70,51 @@ app.add_middleware(
)
async def gpt_refine_text(
ge: Generator[Segment, None, None], info: TranscriptionInfo, context: str
) -> str:
text = build_json_result(ge, info).text.strip()
model = os.environ.get("OPENAI_LLM_MODEL", "gpt-4o-mini")
if not text:
return ""
body: dict = {
"model": model,
"temperature": 0.1,
"stream": False,
"messages": [
{
"role": "system",
"content": f"""
You are an audio transcription text refiner. You may refer to the context to correct the transcription text.
Your task is to correct the transcribed text by removing redundant and repetitive words, resolving any contradictions, and fixing punctuation errors.
Keep my spoken language as it is, and do not change my speaking style. Only fix the text.
Respond directly with the text.
""".strip(),
},
{
"role": "user",
"content": f"""
context: {context}
---
transcription: {text}
""".strip(),
},
],
}
print(f"Refining text length: {len(text)} with {model}")
print(body)
async with aiohttp.ClientSession() as session:
async with session.post(
os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
+ "/chat/completions",
json=body,
headers={
"Authorization": f'Bearer {os.environ["OPENAI_API_KEY"]}',
},
) as response:
return (await response.json())["choices"][0]["message"]["content"]
def stream_writer(generator: Generator[Segment, Any, None]):
for segment in generator:
yield "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n"
@@ -95,7 +149,7 @@ def srt_writer(generator: Generator[Segment, Any, None]):
def vtt_writer(generator: Generator[Segment, Any, None]):
yield "WEBVTT\n\n"
for i, segment in enumerate(generator):
for _, segment in enumerate(generator):
start_time = format_timestamp(segment.start)
end_time = format_timestamp(segment.end)
text = segment.text.strip()
@@ -107,15 +161,16 @@ class JsonResult(TranscriptionInfo):
segments: list[Segment]
text: str
def build_json_result(
generator: Iterable[Segment],
info: TranscriptionInfo,
info: TranscriptionInfo,
) -> JsonResult:
segments = [i for i in generator]
return JsonResult(
text="\n".join(i.text for i in segments),
segments=segments,
**dataclasses.asdict(info)
**dataclasses.asdict(info),
)
@@ -140,14 +195,12 @@ def stream_builder(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
def wrap():
last_pos = 0
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
for segment in segments:
start, end, text = segment.start, segment.end, segment.text
pbar.update(end - last_pos)
last_pos = end
yield segment
for segment in segments:
if info.language == "zh":
segment.text = ccc.convert(segment.text)
yield segment
return wrap(), info
@@ -166,8 +219,12 @@ async def konele_status(
@app.websocket("/k6nele/ws")
@app.websocket("/konele/ws")
@app.websocket("/konele/ws/gpt_refine")
@app.websocket("/k6nele/ws/gpt_refine")
@app.websocket("/v1/k6nele/ws")
@app.websocket("/v1/konele/ws")
@app.websocket("/v1/konele/ws/gpt_refine")
@app.websocket("/v1/k6nele/ws/gpt_refine")
async def konele_ws(
websocket: WebSocket,
task: Literal["transcribe", "translate"] = "transcribe",
@@ -215,13 +272,17 @@ async def konele_ws(
language=None if lang == "und" else lang,
initial_prompt=initial_prompt,
)
result = build_json_result(generator, info)
if websocket.url.path.endswith("gpt_refine"):
result = await gpt_refine_text(generator, info, initial_prompt)
else:
result = build_json_result(generator, info).text
await websocket.send_json(
{
"status": 0,
"segment": 0,
"result": {"hypotheses": [{"transcript": result.text}], "final": True},
"result": {"hypotheses": [{"transcript": result}], "final": True},
"id": md5,
}
)
@@ -230,8 +291,12 @@ async def konele_ws(
@app.post("/k6nele/post")
@app.post("/konele/post")
@app.post("/k6nele/post/gpt_refine")
@app.post("/konele/post/gpt_refine")
@app.post("/v1/k6nele/post")
@app.post("/v1/konele/post")
@app.post("/v1/k6nele/post/gpt_refine")
@app.post("/v1/konele/post/gpt_refine")
async def translateapi(
request: Request,
task: Literal["transcribe", "translate"] = "transcribe",
@@ -276,11 +341,15 @@ async def translateapi(
language=None if lang == "und" else lang,
initial_prompt=initial_prompt,
)
result = build_json_result(generator, info)
if request.url.path.endswith("gpt_refine"):
result = await gpt_refine_text(generator, info, initial_prompt)
else:
result = build_json_result(generator, info).text
return {
"status": 0,
"hypotheses": [{"utterance": result.text}],
"hypotheses": [{"utterance": result}],
"id": md5,
}
@@ -296,6 +365,7 @@ async def transcription(
language: str = Form("und"),
vad_filter: bool = Form(False),
repetition_penalty: float = Form(1.0),
gpt_refine: bool = Form(False),
):
"""Transcription endpoint
@@ -303,9 +373,9 @@ async def transcription(
"""
if not task:
if request.url.path == '/v1/audio/transcriptions':
if request.url.path == "/v1/audio/transcriptions":
task = "transcribe"
elif request.url.path == '/v1/audio/translations':
elif request.url.path == "/v1/audio/translations":
task = "translate"
else:
raise HTTPException(400, "task parameter is required")
@@ -329,6 +399,8 @@ async def transcription(
elif response_format == "json":
return build_json_result(generator, info)
elif response_format == "text":
if gpt_refine:
return PlainTextResponse(await gpt_refine_text(generator, info, prompt))
return StreamingResponse(text_writer(generator), media_type="text/plain")
elif response_format == "tsv":
return StreamingResponse(tsv_writer(generator), media_type="text/plain")