From 2b53dee6b697ba1364d7ba1b1c1ce3a6a2dc05df Mon Sep 17 00:00:00 2001 From: Ewald Enzinger Date: Sat, 8 Apr 2023 10:02:36 +0200 Subject: [PATCH] Expose download location in WhisperModel constructor (#126) This increases compatibility with OpenAI Whisper's whisper.load_model() and is useful for downstream integrations --- faster_whisper/transcribe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index b077d8b..6d31271 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -72,6 +72,7 @@ class WhisperModel: compute_type: str = "default", cpu_threads: int = 0, num_workers: int = 1, + download_root: Optional[str] = None, ): """Initializes the Whisper model. @@ -93,13 +94,15 @@ class WhisperModel: having multiple workers enables true parallelism when running the model (concurrent calls to self.model.generate() will run in parallel). This can improve the global throughput at the cost of increased memory usage. + download_root: Directory where the model should be saved. If not set, the model + is saved in the standard Hugging Face cache directory. """ self.logger = get_logger() if os.path.isdir(model_size_or_path): model_path = model_size_or_path else: - model_path = download_model(model_size_or_path) + model_path = download_model(model_size_or_path, download_root) self.model = ctranslate2.models.Whisper( model_path,