disable zh-hans to zh-cn convertion

This commit is contained in:
2024-07-10 18:51:29 +08:00
parent bc5fdec819
commit 6decabefae

View File

@@ -22,7 +22,6 @@ from fastapi.middleware.cors import CORSMiddleware
from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe from src.whisper_ctranslate2.whisper_ctranslate2 import Transcribe
from src.whisper_ctranslate2.writers import format_timestamp from src.whisper_ctranslate2.writers import format_timestamp
from faster_whisper.transcribe import Segment, TranscriptionInfo from faster_whisper.transcribe import Segment, TranscriptionInfo
import opencc
from prometheus_fastapi_instrumentator import Instrumentator from prometheus_fastapi_instrumentator import Instrumentator
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@@ -35,7 +34,6 @@ args = parser.parse_args()
app = FastAPI() app = FastAPI()
# Instrument your app with default metrics and expose the metrics # Instrument your app with default metrics and expose the metrics
Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics") Instrumentator().instrument(app).expose(app, endpoint="/konele/metrics")
ccc = opencc.OpenCC("t2s.json")
print("Loading model...") print("Loading model...")
transcriber = Transcribe( transcriber = Transcribe(
@@ -102,12 +100,15 @@ def vtt_writer(generator: Generator[dict[str, Any], Any, None]):
def build_json_result( def build_json_result(
generator: Generator[dict[str, Any], Any, None] generator: Iterable[Segment],
info: TranscriptionInfo,
) -> dict[str, Any]: ) -> dict[str, Any]:
segments = [i for i in generator] segments = [i for i in generator]
return { return {
"text": "\n".join(i["text"] for i in segments), "text": "\n".join(i["text"] for i in segments),
"segments": segments, "segments": segments,
"language": info.language,
"language_probability": info.language_probability,
} }
@@ -137,7 +138,6 @@ def stream_builder(
last_pos = end last_pos = end
data = segment._asdict() data = segment._asdict()
data["total"] = info.duration data["total"] = info.duration
data["text"] = ccc.convert(data["text"])
yield data yield data
return wrap(), info return wrap(), info
@@ -199,14 +199,14 @@ async def konele_ws(
file_obj.seek(0) file_obj.seek(0)
generator = stream_builder( generator, info = stream_builder(
audio=file_obj, audio=file_obj,
task=task, task=task,
vad_filter=vad_filter, vad_filter=vad_filter,
language=None if lang == "und" else lang, language=None if lang == "und" else lang,
initial_prompt=initial_prompt, initial_prompt=initial_prompt,
) )
result = build_json_result(generator) result = build_json_result(generator, info)
text = result.get("text", "") text = result.get("text", "")
print("result", text) print("result", text)
@@ -263,14 +263,14 @@ async def translateapi(
file_obj.seek(0) file_obj.seek(0)
generator = stream_builder( generator, info = stream_builder(
audio=file_obj, audio=file_obj,
task=task, task=task,
vad_filter=vad_filter, vad_filter=vad_filter,
language=None if lang == "und" else lang, language=None if lang == "und" else lang,
initial_prompt=initial_prompt, initial_prompt=initial_prompt,
) )
result = build_json_result(generator) result = build_json_result(generator, info)
text = result.get("text", "") text = result.get("text", "")
print("result", text) print("result", text)
@@ -311,7 +311,7 @@ async def transcription(
media_type="text/event-stream", media_type="text/event-stream",
) )
elif response_format == "json": elif response_format == "json":
return build_json_result(generator) return build_json_result(generator, info)
elif response_format == "text": elif response_format == "text":
return StreamingResponse(text_writer(generator), media_type="text/plain") return StreamingResponse(text_writer(generator), media_type="text/plain")
elif response_format == "tsv": elif response_format == "tsv":