add: support for flac

This commit is contained in:
2023-11-15 17:27:50 +08:00
parent c3838dcb3f
commit e403a514ff

View File

@@ -1,10 +1,11 @@
import wave import wave
import pydub
import io import io
import hashlib import hashlib
import argparse import argparse
import uvicorn import uvicorn
from typing import Any, Literal from typing import Annotated, Any, Literal
from fastapi import File, UploadFile, Form, FastAPI, Request, WebSocket, Response from fastapi import File, Query, UploadFile, Form, FastAPI, Request, WebSocket, Response
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe, TranscriptionOptions
from src.whisper_ctranslate2.writers import format_timestamp from src.whisper_ctranslate2.writers import format_timestamp
@@ -126,9 +127,15 @@ async def konele_ws(
task: Literal["transcribe", "translate"] = "transcribe", task: Literal["transcribe", "translate"] = "transcribe",
lang: str = "und", lang: str = "und",
initial_prompt: str = "", initial_prompt: str = "",
content_type: Annotated[str, Query(alias="content-type")] = "audio/x-raw",
): ):
await websocket.accept() await websocket.accept()
# reduce the language tag to its primary subtag (e.g. en-US -> en)
lang = lang.split("-")[0]
print("WebSocket client connected, lang is", lang) print("WebSocket client connected, lang is", lang)
print("content-type is", content_type)
data = b"" data = b""
while True: while True:
try: try:
@@ -145,11 +152,17 @@ async def konele_ws(
# create fake file for wave.open # create fake file for wave.open
file_obj = io.BytesIO() file_obj = io.BytesIO()
if content_type.startswith("audio/x-flac"):
pydub.AudioSegment.from_file(io.BytesIO(data), format="flac").export(
file_obj, format="wav"
)
else:
buffer = wave.open(file_obj, "wb") buffer = wave.open(file_obj, "wb")
buffer.setnchannels(1) buffer.setnchannels(1)
buffer.setsampwidth(2) buffer.setsampwidth(2)
buffer.setframerate(16000) buffer.setframerate(16000)
buffer.writeframes(data) buffer.writeframes(data)
file_obj.seek(0) file_obj.seek(0)
options = get_options(initial_prompt=initial_prompt) options = get_options(initial_prompt=initial_prompt)
@@ -187,6 +200,10 @@ async def translateapi(
): ):
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
print("downloading request file", content_type) print("downloading request file", content_type)
# reduce the language tag to its primary subtag (e.g. en-US -> en)
lang = lang.split("-")[0]
splited = [i.strip() for i in content_type.split(",") if "=" in i] splited = [i.strip() for i in content_type.split(",") if "=" in i]
info = {k: v for k, v in (i.split("=") for i in splited)} info = {k: v for k, v in (i.split("=") for i in splited)}
print(info) print(info)
@@ -200,11 +217,17 @@ async def translateapi(
# create fake file for wave.open # create fake file for wave.open
file_obj = io.BytesIO() file_obj = io.BytesIO()
if content_type.startswith("audio/x-flac"):
pydub.AudioSegment.from_file(io.BytesIO(body), format="flac").export(
file_obj, format="wav"
)
else:
buffer = wave.open(file_obj, "wb") buffer = wave.open(file_obj, "wb")
buffer.setnchannels(channels) buffer.setnchannels(channels)
buffer.setsampwidth(2) buffer.setsampwidth(2)
buffer.setframerate(rate) buffer.setframerate(rate)
buffer.writeframes(body) buffer.writeframes(body)
file_obj.seek(0) file_obj.seek(0)
options = get_options(initial_prompt=initial_prompt) options = get_options(initial_prompt=initial_prompt)