diff --git a/whisper_fastapi.py b/whisper_fastapi.py index 38be547..e4075ae 100644 --- a/whisper_fastapi.py +++ b/whisper_fastapi.py @@ -125,6 +125,7 @@ async def konele_ws( websocket: WebSocket, task: Literal["transcribe", "translate"] = "transcribe", lang: str = "und", + initial_prompt: str = "", ): await websocket.accept() print("WebSocket client connected, lang is", lang) @@ -151,7 +152,7 @@ async def konele_ws( buffer.writeframes(data) file_obj.seek(0) - options = get_options() + options = get_options(initial_prompt=initial_prompt) result = transcriber.inference( audio=file_obj, @@ -182,6 +183,7 @@ async def translateapi( request: Request, task: Literal["transcribe", "translate"] = "transcribe", lang: str = "und", + initial_prompt: str = "", ): content_type = request.headers.get("Content-Type", "") print("downloading request file", content_type) @@ -205,7 +207,7 @@ async def translateapi( buffer.writeframes(body) file_obj.seek(0) - options = get_options() + options = get_options(initial_prompt=initial_prompt) result = transcriber.inference( audio=file_obj,