Compare commits

..

4 Commits

Author SHA1 Message Date
0faaf0f301 support translate endpoint 2024-09-13 16:21:35 +08:00
fab1ec9d03 fix: initial_prompt params 2024-09-13 16:13:19 +08:00
71bde08b17 fix: vad_filter params 2024-09-13 16:10:38 +08:00
a53a2fb80e fix typing hint 2024-09-13 16:09:37 +08:00

View File

@@ -104,7 +104,7 @@ def vtt_writer(generator: Generator[dict[str, Any], Any, None]):
def build_json_result( def build_json_result(
generator: Iterable[Segment], generator: Iterable[dict],
info: dict, info: dict,
) -> dict[str, Any]: ) -> dict[str, Any]:
segments = [i for i in generator] segments = [i for i in generator]
@@ -121,12 +121,13 @@ def stream_builder(
language: str | None, language: str | None,
initial_prompt: str = "", initial_prompt: str = "",
repetition_penalty: float = 1.0, repetition_penalty: float = 1.0,
) -> Tuple[Iterable[dict], dict]: ) -> Tuple[Generator[dict, None, None], dict]:
segments, info = transcriber.model.transcribe( segments, info = transcriber.model.transcribe(
audio=audio, audio=audio,
language=language, language=language,
task=task, task=task,
initial_prompt=initial_prompt, vad_filter=vad_filter,
initial_prompt=initial_prompt if initial_prompt else None,
word_timestamps=True, word_timestamps=True,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
) )
@@ -295,11 +296,13 @@ async def translateapi(
@app.post("/v1/audio/transcriptions") @app.post("/v1/audio/transcriptions")
@app.post("/v1/audio/translations")
async def transcription( async def transcription(
request: Request,
file: UploadFile = File(...), file: UploadFile = File(...),
prompt: str = Form(""), prompt: str = Form(""),
response_format: str = Form("json"), response_format: str = Form("json"),
task: str = Form("transcribe"), task: str = Form(""),
language: str = Form("und"), language: str = Form("und"),
vad_filter: bool = Form(False), vad_filter: bool = Form(False),
repetition_penalty: float = Form(1.0), repetition_penalty: float = Form(1.0),
@@ -309,11 +312,20 @@ async def transcription(
User upload audio file in multipart/form-data format and receive transcription in response User upload audio file in multipart/form-data format and receive transcription in response
""" """
if not task:
if request.url.path == '/v1/audio/transcriptions':
task = "transcribe"
elif request.url.path == '/v1/audio/translations':
task = "translate"
else:
raise HTTPException(400, "task parameter is required")
# timestamp as filename, keep original extension # timestamp as filename, keep original extension
generator, info = stream_builder( generator, info = stream_builder(
audio=io.BytesIO(file.file.read()), audio=io.BytesIO(file.file.read()),
task=task, task=task,
vad_filter=vad_filter, vad_filter=vad_filter,
initial_prompt=prompt,
language=None if language == "und" else language, language=None if language == "und" else language,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
) )