diff --git a/whisper_fastapi.py b/whisper_fastapi.py index c62fe7f..b3a5c0d 100644 --- a/whisper_fastapi.py +++ b/whisper_fastapi.py @@ -16,6 +16,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--host", default="0.0.0.0", type=str) parser.add_argument("--port", default=5000, type=int) parser.add_argument("--model", default="large-v2", type=str) +parser.add_argument("--device", default="auto", type=str) parser.add_argument("--cache_dir", default=None, type=str) args = parser.parse_args() app = FastAPI() @@ -26,7 +27,7 @@ ccc = opencc.OpenCC("t2s.json") print("Loading model...") transcriber = Transcribe( model_path=args.model, - device="auto", + device=args.device, device_index=0, compute_type="default", threads=1, @@ -91,7 +92,7 @@ def get_options(*, initial_prompt=""): compression_ratio_threshold=2.4, condition_on_previous_text=True, temperature=[0.0, 1.0 + 1e-6, 0.2], - suppress_tokens=[-1], + suppress_tokens=[], word_timestamps=True, print_colors=False, prepend_punctuations="\"'“¿([{-",