Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
460ad77a2f
|
|||
|
890da4f4ac
|
|||
|
4784bd53a2
|
|||
|
bd2c6b95cf
|
|||
|
0e46bd91d4
|
|||
|
99272b230f
|
|||
|
3c01a76405
|
|||
|
3401c59c4b
|
|||
|
76b32bc9c4
|
|||
|
4a5ba38f5e
|
|||
|
8ae81a124d
|
|||
|
0faaf0f301
|
|||
|
fab1ec9d03
|
|||
|
71bde08b17
|
|||
|
a53a2fb80e
|
@@ -1 +1,2 @@
|
|||||||
/venv
|
/venv
|
||||||
|
/.git
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1 +1,2 @@
|
|||||||
/venv
|
/venv
|
||||||
|
/.git
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
FROM docker.io/nvidia/cuda:12.0.0-cudnn8-runtime-ubuntu22.04
|
FROM nvidia/cuda:12.3.2-cudnn9-runtime-ubuntu22.04
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y ffmpeg python3 python3-pip git && \
|
apt-get install -y ffmpeg python3 python3-pip git && \
|
||||||
|
|||||||
23
README.md
23
README.md
@@ -5,21 +5,36 @@ Whisper-FastAPI is a very simple Python FastAPI interface for konele and OpenAI
|
|||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **Translation and Transcription**: The application provides an API for konele service, where translations and transcriptions can be obtained by connecting over websockets or POST requests.
|
- **Translation and Transcription**: The application provides an API for konele service, where translations and transcriptions can be obtained by connecting over websockets or POST requests.
|
||||||
- **Language Support**: If the target language is English, then the application will translate any source language to English.
|
- **Language Support**: If no language is specified, the language will be automatically recognized from the first 30 seconds.
|
||||||
- **Websocket and POST Method Support**: The project supports a websocket (`/konele/ws`) and a POST method to `/konele/post`.
|
- **Websocket and POST Method Support**: The project supports a websocket (`/konele/ws`) and a POST method to `/konele/post`.
|
||||||
- **Audio Transcriptions**: The `/v1/audio/transcriptions` endpoint allows users to upload an audio file and receive transcription in response, with an optional `response_type` parameter. The `response_type` can be 'json', 'text', 'tsv', 'srt', and 'vtt'.
|
- **Audio Transcriptions**: The `/v1/audio/transcriptions` endpoint allows users to upload an audio file and receive transcription in response, with an optional `response_type` parameter. The `response_type` can be 'json', 'text', 'tsv', 'srt', and 'vtt'.
|
||||||
- **Simplified Chinese**: The traditional Chinese will be automatically convert to simplified Chinese for konele using `opencc` library.
|
- **Simplified Chinese**: The traditional Chinese will be automatically convert to simplified Chinese for konele using `opencc` library.
|
||||||
|
|
||||||
|
## GPT Refine Result
|
||||||
|
|
||||||
|
You can choose to use the OpenAI GPT model for post-processing transcription results. You can also provide context to GPT to allow it to modify the text based on your context.
|
||||||
|
|
||||||
|
Set the environment variables `OPENAI_BASE_URL=https://api.openai.com/v1` and `OPENAI_API_KEY=your-sk` to enable this feature.
|
||||||
|
|
||||||
|
When the client sends a request with `gpt_refine=True`, this feature will be activated. Specifically:
|
||||||
|
|
||||||
|
- For `/v1/audio/transcriptions`, submit using `curl <api_url> -F file=audio.mp4 -F gpt_refine=True`.
|
||||||
|
- For `/v1/konele/ws` and `/v1/konele/post`, use the URL format `/v1/konele/ws/gpt_refine`.
|
||||||
|
|
||||||
|
The default model is `gpt-4o-mini` set by environment variable `OPENAI_LLM_MODEL`.
|
||||||
|
|
||||||
|
You can easily edit the code LLM's prompt to better fit your workflow. It's just a few lines of code. Give it a try, it's very simple!
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
### Konele Voice Typing
|
### Konele Voice Typing
|
||||||
|
|
||||||
For konele voice typing, you can use either the websocket endpoint or the POST method endpoint.
|
For konele voice typing, you can use either the websocket endpoint or the POST method endpoint.
|
||||||
|
|
||||||
- **Websocket**: Connect to the websocket at `/konele/ws` and send audio data. The server will respond with the transcription or translation.
|
- **Websocket**: Connect to the websocket at `/konele/ws` (or `/v1/konele/ws`) and send audio data. The server will respond with the transcription or translation.
|
||||||
- **POST Method**: Send a POST request to `/konele/post` with the audio data in the body. The server will respond with the transcription or translation.
|
- **POST Method**: Send a POST request to `/konele/post` (or `/v1/konele/post`) with the audio data in the body. The server will respond with the transcription or translation.
|
||||||
|
|
||||||
You can also use the demo I have created to quickly test the effect at <https://yongyuancv.cn/konele/ws> and <https://yongyuancv.cn/konele/post>
|
You can also use the demo I have created to quickly test the effect at <https://yongyuancv.cn/v1/konele/post>
|
||||||
|
|
||||||
### OpenAI Whisper Service
|
### OpenAI Whisper Service
|
||||||
|
|
||||||
|
|||||||
@@ -4,5 +4,6 @@ uvicorn[standard]
|
|||||||
whisper_ctranslate2
|
whisper_ctranslate2
|
||||||
opencc
|
opencc
|
||||||
prometheus-fastapi-instrumentator
|
prometheus-fastapi-instrumentator
|
||||||
git+https://github.com/heimoshuiyu/faster-whisper@prompt
|
git+https://github.com/heimoshuiyu/faster-whisper@a759f5f48f5ef5b79461a6461966eafe9df088a9
|
||||||
pydub
|
pydub
|
||||||
|
aiohttp
|
||||||
|
|||||||
@@ -1,49 +1,58 @@
|
|||||||
|
aiohappyeyeballs==2.4.4
|
||||||
|
aiohttp==3.11.10
|
||||||
|
aiosignal==1.3.1
|
||||||
annotated-types==0.7.0
|
annotated-types==0.7.0
|
||||||
anyio==4.4.0
|
anyio==4.7.0
|
||||||
av==12.3.0
|
async-timeout==5.0.1
|
||||||
|
attrs==24.2.0
|
||||||
|
av==14.0.0
|
||||||
certifi==2024.8.30
|
certifi==2024.8.30
|
||||||
cffi==1.17.1
|
cffi==1.17.1
|
||||||
charset-normalizer==3.3.2
|
charset-normalizer==3.4.0
|
||||||
click==8.1.7
|
click==8.1.7
|
||||||
coloredlogs==15.0.1
|
coloredlogs==15.0.1
|
||||||
ctranslate2==4.4.0
|
ctranslate2==4.5.0
|
||||||
exceptiongroup==1.2.2
|
exceptiongroup==1.2.2
|
||||||
fastapi==0.114.1
|
fastapi==0.115.6
|
||||||
faster-whisper @ git+https://github.com/heimoshuiyu/faster-whisper@28a4d11a736d8cdeb4655ee5d7e4b4e7ae5ec8e0
|
faster-whisper @ git+https://github.com/heimoshuiyu/faster-whisper@a759f5f48f5ef5b79461a6461966eafe9df088a9
|
||||||
filelock==3.16.0
|
filelock==3.16.1
|
||||||
flatbuffers==24.3.25
|
flatbuffers==24.3.25
|
||||||
fsspec==2024.9.0
|
frozenlist==1.5.0
|
||||||
|
fsspec==2024.10.0
|
||||||
h11==0.14.0
|
h11==0.14.0
|
||||||
httptools==0.6.1
|
httptools==0.6.4
|
||||||
huggingface-hub==0.24.6
|
huggingface-hub==0.26.3
|
||||||
humanfriendly==10.0
|
humanfriendly==10.0
|
||||||
idna==3.8
|
idna==3.10
|
||||||
mpmath==1.3.0
|
mpmath==1.3.0
|
||||||
numpy==2.1.1
|
multidict==6.1.0
|
||||||
onnxruntime==1.19.2
|
numpy==2.1.3
|
||||||
|
onnxruntime==1.20.1
|
||||||
OpenCC==1.1.9
|
OpenCC==1.1.9
|
||||||
packaging==24.1
|
packaging==24.2
|
||||||
prometheus-fastapi-instrumentator==7.0.0
|
prometheus-fastapi-instrumentator==7.0.0
|
||||||
prometheus_client==0.20.0
|
prometheus_client==0.21.1
|
||||||
protobuf==5.28.0
|
propcache==0.2.1
|
||||||
|
protobuf==5.29.1
|
||||||
pycparser==2.22
|
pycparser==2.22
|
||||||
pydantic==2.9.1
|
pydantic==2.10.3
|
||||||
pydantic_core==2.23.3
|
pydantic_core==2.27.1
|
||||||
pydub==0.25.1
|
pydub==0.25.1
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
python-multipart==0.0.9
|
python-multipart==0.0.19
|
||||||
PyYAML==6.0.2
|
PyYAML==6.0.2
|
||||||
requests==2.32.3
|
requests==2.32.3
|
||||||
sniffio==1.3.1
|
sniffio==1.3.1
|
||||||
sounddevice==0.5.0
|
sounddevice==0.5.1
|
||||||
starlette==0.38.5
|
starlette==0.41.3
|
||||||
sympy==1.13.2
|
sympy==1.13.3
|
||||||
tokenizers==0.20.0
|
tokenizers==0.21.0
|
||||||
tqdm==4.66.5
|
tqdm==4.67.1
|
||||||
typing_extensions==4.12.2
|
typing_extensions==4.12.2
|
||||||
urllib3==2.2.2
|
urllib3==2.2.3
|
||||||
uvicorn==0.30.6
|
uvicorn==0.32.1
|
||||||
uvloop==0.20.0
|
uvloop==0.21.0
|
||||||
watchfiles==0.24.0
|
watchfiles==1.0.0
|
||||||
websockets==13.0.1
|
websockets==14.1
|
||||||
whisper-ctranslate2==0.4.5
|
whisper-ctranslate2==0.5.0
|
||||||
|
yarl==1.18.3
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
import tqdm
|
import aiohttp
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import dataclasses
|
||||||
|
import faster_whisper
|
||||||
import json
|
import json
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import PlainTextResponse, StreamingResponse
|
||||||
import wave
|
import wave
|
||||||
import pydub
|
import pydub
|
||||||
import io
|
import io
|
||||||
import hashlib
|
import hashlib
|
||||||
import argparse
|
import argparse
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from typing import Annotated, Any, BinaryIO, Literal, Generator, Tuple, Iterable
|
from typing import Annotated, Any, BinaryIO, Literal, Generator, Tuple, Iterable, Union
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
File,
|
File,
|
||||||
HTTPException,
|
HTTPException,
|
||||||
@@ -19,12 +23,19 @@ from fastapi import (
|
|||||||
WebSocket,
|
WebSocket,
|
||||||
)
|
)
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
|
|
||||||
from src.whisper_ctranslate2.writers import format_timestamp
|
from src.whisper_ctranslate2.writers import format_timestamp
|
||||||
from faster_whisper.transcribe import Segment, TranscriptionInfo
|
from faster_whisper.transcribe import Segment, TranscriptionInfo
|
||||||
import opencc
|
import opencc
|
||||||
from prometheus_fastapi_instrumentator import Instrumentator
|
from prometheus_fastapi_instrumentator import Instrumentator
|
||||||
|
|
||||||
|
# redirect print to stderr
|
||||||
|
_print = print
|
||||||
|
|
||||||
|
|
||||||
|
def print(*args, **kwargs):
|
||||||
|
_print(*args, file=sys.stderr, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--host", default="0.0.0.0", type=str)
|
parser.add_argument("--host", default="0.0.0.0", type=str)
|
||||||
parser.add_argument("--port", default=5000, type=int)
|
parser.add_argument("--port", default=5000, type=int)
|
||||||
@@ -40,16 +51,13 @@ Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics")
|
|||||||
ccc = opencc.OpenCC("t2s.json")
|
ccc = opencc.OpenCC("t2s.json")
|
||||||
|
|
||||||
print(f"Loading model to device {args.device}...")
|
print(f"Loading model to device {args.device}...")
|
||||||
transcriber = Transcribe(
|
model = faster_whisper.WhisperModel(
|
||||||
model_path=args.model,
|
model_size_or_path=args.model,
|
||||||
device=args.device,
|
device=args.device,
|
||||||
device_index=0,
|
cpu_threads=args.threads,
|
||||||
compute_type="default",
|
|
||||||
threads=args.threads,
|
|
||||||
cache_directory=args.cache_dir,
|
|
||||||
local_files_only=args.local_files_only,
|
local_files_only=args.local_files_only,
|
||||||
)
|
)
|
||||||
print(f"Model loaded to device {transcriber.model.model.device}")
|
print(f"Model loaded to device {model.model.device}")
|
||||||
|
|
||||||
|
|
||||||
# allow all cors
|
# allow all cors
|
||||||
@@ -62,56 +70,108 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def stream_writer(generator: Generator[dict[str, Any], Any, None]):
|
async def gpt_refine_text(
|
||||||
|
ge: Generator[Segment, None, None], info: TranscriptionInfo, context: str
|
||||||
|
) -> str:
|
||||||
|
text = build_json_result(ge, info).text.strip()
|
||||||
|
model = os.environ.get("OPENAI_LLM_MODEL", "gpt-4o-mini")
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
body: dict = {
|
||||||
|
"model": model,
|
||||||
|
"temperature": 0.1,
|
||||||
|
"stream": False,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": f"""
|
||||||
|
You are a audio transcription text refiner. You may refer to the context to correct the transcription text.
|
||||||
|
Your task is to correct the transcribed text by removing redundant and repetitive words, resolving any contradictions, and fixing punctuation errors.
|
||||||
|
Keep my spoken language as it is, and do not change my speaking style. Only fix the text.
|
||||||
|
Response directly with the text.
|
||||||
|
""".strip(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"""
|
||||||
|
context: {context}
|
||||||
|
---
|
||||||
|
transcription: {text}
|
||||||
|
""".strip(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
print(f"Refining text length: {len(text)} with {model}")
|
||||||
|
print(body)
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||||
|
+ "/chat/completions",
|
||||||
|
json=body,
|
||||||
|
headers={
|
||||||
|
"Authorization": f'Bearer {os.environ["OPENAI_API_KEY"]}',
|
||||||
|
},
|
||||||
|
) as response:
|
||||||
|
return (await response.json())["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def stream_writer(generator: Generator[Segment, Any, None]):
|
||||||
for segment in generator:
|
for segment in generator:
|
||||||
yield "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n"
|
yield "data: " + json.dumps(segment, ensure_ascii=False) + "\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
def text_writer(generator: Generator[dict[str, Any], Any, None]):
|
def text_writer(generator: Generator[Segment, Any, None]):
|
||||||
for segment in generator:
|
for segment in generator:
|
||||||
yield segment["text"].strip() + "\n"
|
yield segment.text.strip() + "\n"
|
||||||
|
|
||||||
|
|
||||||
def tsv_writer(generator: Generator[dict[str, Any], Any, None]):
|
def tsv_writer(generator: Generator[Segment, Any, None]):
|
||||||
yield "start\tend\ttext\n"
|
yield "start\tend\ttext\n"
|
||||||
for i, segment in enumerate(generator):
|
for i, segment in enumerate(generator):
|
||||||
start_time = str(round(1000 * segment["start"]))
|
start_time = str(round(1000 * segment.start))
|
||||||
end_time = str(round(1000 * segment["end"]))
|
end_time = str(round(1000 * segment.end))
|
||||||
text = segment["text"].strip()
|
text = segment.text.strip()
|
||||||
yield f"{start_time}\t{end_time}\t{text}\n"
|
yield f"{start_time}\t{end_time}\t{text}\n"
|
||||||
|
|
||||||
|
|
||||||
def srt_writer(generator: Generator[dict[str, Any], Any, None]):
|
def srt_writer(generator: Generator[Segment, Any, None]):
|
||||||
for i, segment in enumerate(generator):
|
for i, segment in enumerate(generator):
|
||||||
start_time = format_timestamp(
|
start_time = format_timestamp(
|
||||||
segment["start"], decimal_marker=",", always_include_hours=True
|
segment.start, decimal_marker=",", always_include_hours=True
|
||||||
)
|
)
|
||||||
end_time = format_timestamp(
|
end_time = format_timestamp(
|
||||||
segment["end"], decimal_marker=",", always_include_hours=True
|
segment.end, decimal_marker=",", always_include_hours=True
|
||||||
)
|
)
|
||||||
text = segment["text"].strip()
|
text = segment.text.strip()
|
||||||
yield f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
|
yield f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
|
||||||
|
|
||||||
|
|
||||||
def vtt_writer(generator: Generator[dict[str, Any], Any, None]):
|
def vtt_writer(generator: Generator[Segment, Any, None]):
|
||||||
yield "WEBVTT\n\n"
|
yield "WEBVTT\n\n"
|
||||||
for i, segment in enumerate(generator):
|
for _, segment in enumerate(generator):
|
||||||
start_time = format_timestamp(segment["start"])
|
start_time = format_timestamp(segment.start)
|
||||||
end_time = format_timestamp(segment["end"])
|
end_time = format_timestamp(segment.end)
|
||||||
text = segment["text"].strip()
|
text = segment.text.strip()
|
||||||
yield f"{start_time} --> {end_time}\n{text}\n\n"
|
yield f"{start_time} --> {end_time}\n{text}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class JsonResult(TranscriptionInfo):
|
||||||
|
segments: list[Segment]
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
def build_json_result(
|
def build_json_result(
|
||||||
generator: Iterable[Segment],
|
generator: Iterable[Segment],
|
||||||
info: dict,
|
info: TranscriptionInfo,
|
||||||
) -> dict[str, Any]:
|
) -> JsonResult:
|
||||||
segments = [i for i in generator]
|
segments = [i for i in generator]
|
||||||
return info | {
|
return JsonResult(
|
||||||
"text": "\n".join(i["text"] for i in segments),
|
text="\n".join(i.text for i in segments),
|
||||||
"segments": segments,
|
segments=segments,
|
||||||
}
|
**dataclasses.asdict(info),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def stream_builder(
|
def stream_builder(
|
||||||
@@ -121,12 +181,13 @@ def stream_builder(
|
|||||||
language: str | None,
|
language: str | None,
|
||||||
initial_prompt: str = "",
|
initial_prompt: str = "",
|
||||||
repetition_penalty: float = 1.0,
|
repetition_penalty: float = 1.0,
|
||||||
) -> Tuple[Iterable[dict], dict]:
|
) -> Tuple[Generator[Segment, None, None], TranscriptionInfo]:
|
||||||
segments, info = transcriber.model.transcribe(
|
segments, info = model.transcribe(
|
||||||
audio=audio,
|
audio=audio,
|
||||||
language=language,
|
language=language,
|
||||||
task=task,
|
task=task,
|
||||||
initial_prompt=initial_prompt,
|
vad_filter=vad_filter,
|
||||||
|
initial_prompt=initial_prompt if initial_prompt else None,
|
||||||
word_timestamps=True,
|
word_timestamps=True,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
)
|
)
|
||||||
@@ -134,27 +195,14 @@ def stream_builder(
|
|||||||
"Detected language '%s' with probability %f"
|
"Detected language '%s' with probability %f"
|
||||||
% (info.language, info.language_probability)
|
% (info.language, info.language_probability)
|
||||||
)
|
)
|
||||||
def wrap():
|
|
||||||
last_pos = 0
|
|
||||||
with tqdm.tqdm(total=info.duration, unit="seconds", disable=True) as pbar:
|
|
||||||
for segment in segments:
|
|
||||||
start, end, text = segment.start, segment.end, segment.text
|
|
||||||
pbar.update(end - last_pos)
|
|
||||||
last_pos = end
|
|
||||||
data = segment._asdict()
|
|
||||||
if data.get('words') is not None:
|
|
||||||
data["words"] = [i._asdict() for i in data["words"]]
|
|
||||||
if info.language == "zh":
|
|
||||||
data["text"] = ccc.convert(data["text"])
|
|
||||||
yield data
|
|
||||||
|
|
||||||
info_dict = info._asdict()
|
def wrap():
|
||||||
if info_dict['transcription_options'] is not None:
|
for segment in segments:
|
||||||
info_dict['transcription_options'] = info_dict['transcription_options']._asdict()
|
if info.language == "zh":
|
||||||
if info_dict['vad_options'] is not None:
|
segment.text = ccc.convert(segment.text)
|
||||||
info_dict['vad_options'] = info_dict['vad_options']._asdict()
|
yield segment
|
||||||
|
|
||||||
return wrap(), info_dict
|
return wrap(), info
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/k6nele/status")
|
@app.websocket("/k6nele/status")
|
||||||
@@ -171,8 +219,12 @@ async def konele_status(
|
|||||||
|
|
||||||
@app.websocket("/k6nele/ws")
|
@app.websocket("/k6nele/ws")
|
||||||
@app.websocket("/konele/ws")
|
@app.websocket("/konele/ws")
|
||||||
|
@app.websocket("/konele/ws/gpt_refine")
|
||||||
|
@app.websocket("/k6nele/ws/gpt_refine")
|
||||||
@app.websocket("/v1/k6nele/ws")
|
@app.websocket("/v1/k6nele/ws")
|
||||||
@app.websocket("/v1/konele/ws")
|
@app.websocket("/v1/konele/ws")
|
||||||
|
@app.websocket("/v1/konele/ws/gpt_refine")
|
||||||
|
@app.websocket("/v1/k6nele/ws/gpt_refine")
|
||||||
async def konele_ws(
|
async def konele_ws(
|
||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
task: Literal["transcribe", "translate"] = "transcribe",
|
task: Literal["transcribe", "translate"] = "transcribe",
|
||||||
@@ -220,15 +272,17 @@ async def konele_ws(
|
|||||||
language=None if lang == "und" else lang,
|
language=None if lang == "und" else lang,
|
||||||
initial_prompt=initial_prompt,
|
initial_prompt=initial_prompt,
|
||||||
)
|
)
|
||||||
result = build_json_result(generator, info)
|
|
||||||
|
|
||||||
text = result.get("text", "")
|
if websocket.url.path.endswith("gpt_refine"):
|
||||||
|
result = await gpt_refine_text(generator, info, initial_prompt)
|
||||||
|
else:
|
||||||
|
result = build_json_result(generator, info).text
|
||||||
|
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
"status": 0,
|
"status": 0,
|
||||||
"segment": 0,
|
"segment": 0,
|
||||||
"result": {"hypotheses": [{"transcript": text}], "final": True},
|
"result": {"hypotheses": [{"transcript": result}], "final": True},
|
||||||
"id": md5,
|
"id": md5,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -237,8 +291,12 @@ async def konele_ws(
|
|||||||
|
|
||||||
@app.post("/k6nele/post")
|
@app.post("/k6nele/post")
|
||||||
@app.post("/konele/post")
|
@app.post("/konele/post")
|
||||||
|
@app.post("/k6nele/post/gpt_refine")
|
||||||
|
@app.post("/konele/post/gpt_refine")
|
||||||
@app.post("/v1/k6nele/post")
|
@app.post("/v1/k6nele/post")
|
||||||
@app.post("/v1/konele/post")
|
@app.post("/v1/konele/post")
|
||||||
|
@app.post("/v1/k6nele/post/gpt_refine")
|
||||||
|
@app.post("/v1/konele/post/gpt_refine")
|
||||||
async def translateapi(
|
async def translateapi(
|
||||||
request: Request,
|
request: Request,
|
||||||
task: Literal["transcribe", "translate"] = "transcribe",
|
task: Literal["transcribe", "translate"] = "transcribe",
|
||||||
@@ -283,37 +341,51 @@ async def translateapi(
|
|||||||
language=None if lang == "und" else lang,
|
language=None if lang == "und" else lang,
|
||||||
initial_prompt=initial_prompt,
|
initial_prompt=initial_prompt,
|
||||||
)
|
)
|
||||||
result = build_json_result(generator, info)
|
|
||||||
|
|
||||||
text = result.get("text", "")
|
if request.url.path.endswith("gpt_refine"):
|
||||||
|
result = await gpt_refine_text(generator, info, initial_prompt)
|
||||||
|
else:
|
||||||
|
result = build_json_result(generator, info).text
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": 0,
|
"status": 0,
|
||||||
"hypotheses": [{"utterance": text}],
|
"hypotheses": [{"utterance": result}],
|
||||||
"id": md5,
|
"id": md5,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/audio/transcriptions")
|
@app.post("/v1/audio/transcriptions", response_model=Union[JsonResult, str])
|
||||||
|
@app.post("/v1/audio/translations", response_model=Union[JsonResult, str])
|
||||||
async def transcription(
|
async def transcription(
|
||||||
|
request: Request,
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
prompt: str = Form(""),
|
prompt: str = Form(""),
|
||||||
response_format: str = Form("json"),
|
response_format: str = Form("json"),
|
||||||
task: str = Form("transcribe"),
|
task: str = Form(""),
|
||||||
language: str = Form("und"),
|
language: str = Form("und"),
|
||||||
vad_filter: bool = Form(False),
|
vad_filter: bool = Form(False),
|
||||||
repetition_penalty: float = Form(1.0),
|
repetition_penalty: float = Form(1.0),
|
||||||
|
gpt_refine: bool = Form(False),
|
||||||
):
|
):
|
||||||
"""Transcription endpoint
|
"""Transcription endpoint
|
||||||
|
|
||||||
User upload audio file in multipart/form-data format and receive transcription in response
|
User upload audio file in multipart/form-data format and receive transcription in response
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not task:
|
||||||
|
if request.url.path == "/v1/audio/transcriptions":
|
||||||
|
task = "transcribe"
|
||||||
|
elif request.url.path == "/v1/audio/translations":
|
||||||
|
task = "translate"
|
||||||
|
else:
|
||||||
|
raise HTTPException(400, "task parameter is required")
|
||||||
|
|
||||||
# timestamp as filename, keep original extension
|
# timestamp as filename, keep original extension
|
||||||
generator, info = stream_builder(
|
generator, info = stream_builder(
|
||||||
audio=io.BytesIO(file.file.read()),
|
audio=io.BytesIO(file.file.read()),
|
||||||
task=task,
|
task=task,
|
||||||
vad_filter=vad_filter,
|
vad_filter=vad_filter,
|
||||||
|
initial_prompt=prompt,
|
||||||
language=None if language == "und" else language,
|
language=None if language == "und" else language,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
)
|
)
|
||||||
@@ -327,6 +399,8 @@ async def transcription(
|
|||||||
elif response_format == "json":
|
elif response_format == "json":
|
||||||
return build_json_result(generator, info)
|
return build_json_result(generator, info)
|
||||||
elif response_format == "text":
|
elif response_format == "text":
|
||||||
|
if gpt_refine:
|
||||||
|
return PlainTextResponse(await gpt_refine_text(generator, info, prompt))
|
||||||
return StreamingResponse(text_writer(generator), media_type="text/plain")
|
return StreamingResponse(text_writer(generator), media_type="text/plain")
|
||||||
elif response_format == "tsv":
|
elif response_format == "tsv":
|
||||||
return StreamingResponse(tsv_writer(generator), media_type="text/plain")
|
return StreamingResponse(tsv_writer(generator), media_type="text/plain")
|
||||||
|
|||||||
Reference in New Issue
Block a user