From 62b7431655f2f5f020d8b28f01c6fb3bce6328af Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Wed, 15 Nov 2023 16:19:20 +0800 Subject: [PATCH] add: option of task and lang --- whisper_fastapi.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/whisper_fastapi.py b/whisper_fastapi.py index 29ab0ab..e058ba6 100644 --- a/whisper_fastapi.py +++ b/whisper_fastapi.py @@ -3,7 +3,7 @@ import io import hashlib import argparse import uvicorn -from typing import Any +from typing import Any, Literal from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket, Response from fastapi.middleware.cors import CORSMiddleware from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions @@ -112,6 +112,7 @@ def get_options(*, initial_prompt=""): @app.websocket("/konele/ws") async def konele_ws( websocket: WebSocket, + task: Literal["transcribe", "translate"] = "transcribe", lang: str = "und", ): await websocket.accept() @@ -143,9 +144,8 @@ async def konele_ws( result = transcriber.inference( audio=file_obj, - # Enter translate mode if target language is English - task="translate" if lang == "en-US" else "transcribe", - language=None, # type: ignore + task=task, + language=lang if lang != "und" else None, # type: ignore verbose=False, live=False, options=options, @@ -168,6 +168,7 @@ async def konele_ws( @app.post("/konele/post") async def translateapi( request: Request, + task: Literal["transcribe", "translate"] = "transcribe", lang: str = "und", ): content_type = request.headers.get("Content-Type", "") @@ -196,9 +197,8 @@ async def translateapi( result = transcriber.inference( audio=file_obj, - # Enter translate mode if target language is English - task="translate" if lang == "en-US" else "transcribe", - language=None, # type: ignore + task=task, + language=lang if lang != "und" else None, # type: ignore verbose=False, live=False, options=options,