Compare commits
4 Commits
v1.0.3
...
0faaf0f301
| Author | SHA1 | Date | |
|---|---|---|---|
|
0faaf0f301
|
|||
|
fab1ec9d03
|
|||
|
71bde08b17
|
|||
|
a53a2fb80e
|
@@ -104,7 +104,7 @@ def vtt_writer(generator: Generator[dict[str, Any], Any, None]):
|
|||||||
|
|
||||||
|
|
||||||
def build_json_result(
|
def build_json_result(
|
||||||
generator: Iterable[Segment],
|
generator: Iterable[dict],
|
||||||
info: dict,
|
info: dict,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
segments = [i for i in generator]
|
segments = [i for i in generator]
|
||||||
@@ -121,12 +121,13 @@ def stream_builder(
|
|||||||
language: str | None,
|
language: str | None,
|
||||||
initial_prompt: str = "",
|
initial_prompt: str = "",
|
||||||
repetition_penalty: float = 1.0,
|
repetition_penalty: float = 1.0,
|
||||||
) -> Tuple[Iterable[dict], dict]:
|
) -> Tuple[Generator[dict, None, None], dict]:
|
||||||
segments, info = transcriber.model.transcribe(
|
segments, info = transcriber.model.transcribe(
|
||||||
audio=audio,
|
audio=audio,
|
||||||
language=language,
|
language=language,
|
||||||
task=task,
|
task=task,
|
||||||
initial_prompt=initial_prompt,
|
vad_filter=vad_filter,
|
||||||
|
initial_prompt=initial_prompt if initial_prompt else None,
|
||||||
word_timestamps=True,
|
word_timestamps=True,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
)
|
)
|
||||||
@@ -295,11 +296,13 @@ async def translateapi(
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/v1/audio/transcriptions")
|
@app.post("/v1/audio/transcriptions")
|
||||||
|
@app.post("/v1/audio/translations")
|
||||||
async def transcription(
|
async def transcription(
|
||||||
|
request: Request,
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
prompt: str = Form(""),
|
prompt: str = Form(""),
|
||||||
response_format: str = Form("json"),
|
response_format: str = Form("json"),
|
||||||
task: str = Form("transcribe"),
|
task: str = Form(""),
|
||||||
language: str = Form("und"),
|
language: str = Form("und"),
|
||||||
vad_filter: bool = Form(False),
|
vad_filter: bool = Form(False),
|
||||||
repetition_penalty: float = Form(1.0),
|
repetition_penalty: float = Form(1.0),
|
||||||
@@ -309,11 +312,20 @@ async def transcription(
|
|||||||
User upload audio file in multipart/form-data format and receive transcription in response
|
User upload audio file in multipart/form-data format and receive transcription in response
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not task:
|
||||||
|
if request.url.path == '/v1/audio/transcriptions':
|
||||||
|
task = "transcribe"
|
||||||
|
elif request.url.path == '/v1/audio/translations':
|
||||||
|
task = "translate"
|
||||||
|
else:
|
||||||
|
raise HTTPException(400, "task parameter is required")
|
||||||
|
|
||||||
# timestamp as filename, keep original extension
|
# timestamp as filename, keep original extension
|
||||||
generator, info = stream_builder(
|
generator, info = stream_builder(
|
||||||
audio=io.BytesIO(file.file.read()),
|
audio=io.BytesIO(file.file.read()),
|
||||||
task=task,
|
task=task,
|
||||||
vad_filter=vad_filter,
|
vad_filter=vad_filter,
|
||||||
|
initial_prompt=prompt,
|
||||||
language=None if language == "und" else language,
|
language=None if language == "und" else language,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user