add gpt refine

This commit is contained in:
2024-12-06 17:53:04 +08:00
parent bd2c6b95cf
commit 4784bd53a2
6 changed files with 106 additions and 18 deletions

View File

@@ -1,8 +1,10 @@
import aiohttp
import os
import sys
import dataclasses
import faster_whisper
import json
from fastapi.responses import StreamingResponse
from fastapi.responses import PlainTextResponse, StreamingResponse
import wave
import pydub
import io
@@ -28,9 +30,12 @@ from prometheus_fastapi_instrumentator import Instrumentator
# redirect print to stderr
_print = print
def print(*args, **kwargs):
_print(*args, file=sys.stderr, **kwargs)
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0", type=str)
parser.add_argument("--port", default=5000, type=int)
@@ -65,6 +70,45 @@ app.add_middleware(
)
async def gpt_refine_text(
ge: Generator[Segment, None, None], info: TranscriptionInfo, context: str
) -> str:
text = build_json_result(ge, info).text.strip()
if not text:
return ""
async with aiohttp.ClientSession() as session:
async with session.post(
os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
+ "/chat/completions",
json={
"model": "gpt-4o-mini",
"temperature": 0.1,
"stream": False,
"messages": [
{
"role": "system",
"content": f"""
You are a audio transcription text refiner.
You may refeer to the context to refine the transcription text.
""".strip(),
},
{
"role": "user",
"content": f"""
context: {context}
---
transcription: {text}
""".strip(),
},
],
},
headers={
"Authorization": f'Bearer {os.environ["OPENAI_API_KEY"]}',
},
) as response:
return (await response.json())["choices"][0]["message"]["content"]
def stream_writer(generator: Generator[Segment, Any, None]):
for segment in generator:
yield "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n"
@@ -169,8 +213,12 @@ async def konele_status(
@app.websocket("/k6nele/ws")
@app.websocket("/konele/ws")
@app.websocket("/konele/ws/gpt_refine")
@app.websocket("/k6nele/ws/gpt_refine")
@app.websocket("/v1/k6nele/ws")
@app.websocket("/v1/konele/ws")
@app.websocket("/v1/konele/ws/gpt_refine")
@app.websocket("/v1/k6nele/ws/gpt_refine")
async def konele_ws(
websocket: WebSocket,
task: Literal["transcribe", "translate"] = "transcribe",
@@ -218,13 +266,17 @@ async def konele_ws(
language=None if lang == "und" else lang,
initial_prompt=initial_prompt,
)
result = build_json_result(generator, info)
if websocket.url.path.endswith("gpt_refine"):
result = await gpt_refine_text(generator, info, initial_prompt)
else:
result = build_json_result(generator, info).text
await websocket.send_json(
{
"status": 0,
"segment": 0,
"result": {"hypotheses": [{"transcript": result.text}], "final": True},
"result": {"hypotheses": [{"transcript": result}], "final": True},
"id": md5,
}
)
@@ -233,8 +285,12 @@ async def konele_ws(
@app.post("/k6nele/post")
@app.post("/konele/post")
@app.post("/k6nele/post/gpt_refine")
@app.post("/konele/post/gpt_refine")
@app.post("/v1/k6nele/post")
@app.post("/v1/konele/post")
@app.post("/v1/k6nele/post/gpt_refine")
@app.post("/v1/konele/post/gpt_refine")
async def translateapi(
request: Request,
task: Literal["transcribe", "translate"] = "transcribe",
@@ -279,11 +335,15 @@ async def translateapi(
language=None if lang == "und" else lang,
initial_prompt=initial_prompt,
)
result = build_json_result(generator, info)
if request.url.path.endswith("gpt_refine"):
result = await gpt_refine_text(generator, info, initial_prompt)
else:
result = build_json_result(generator, info).text
return {
"status": 0,
"hypotheses": [{"utterance": result.text}],
"hypotheses": [{"utterance": result}],
"id": md5,
}
@@ -299,6 +359,7 @@ async def transcription(
language: str = Form("und"),
vad_filter: bool = Form(False),
repetition_penalty: float = Form(1.0),
gpt_refine: bool = Form(False),
):
"""Transcription endpoint
@@ -332,6 +393,8 @@ async def transcription(
elif response_format == "json":
return build_json_result(generator, info)
elif response_format == "text":
if gpt_refine:
return PlainTextResponse(await gpt_refine_text(generator, info, prompt))
return StreamingResponse(text_writer(generator), media_type="text/plain")
elif response_format == "tsv":
return StreamingResponse(tsv_writer(generator), media_type="text/plain")