diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index a54b9c6..4d7f0f6 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -112,7 +112,9 @@ class WhisperModel: model_path = model_size_or_path else: model_path = download_model( - model_size_or_path, download_root, local_files_only + model_size_or_path, + local_files_only=local_files_only, + cache_dir=download_root, ) self.model = ctranslate2.models.Whisper( diff --git a/faster_whisper/utils.py b/faster_whisper/utils.py index 649906a..34a310a 100644 --- a/faster_whisper/utils.py +++ b/faster_whisper/utils.py @@ -35,6 +35,7 @@ def download_model( size: str, output_dir: Optional[str] = None, local_files_only: Optional[bool] = False, + cache_dir: Optional[str] = None, ): """Downloads a CTranslate2 Whisper model from the Hugging Face Hub. @@ -47,6 +48,7 @@ def download_model( the standard Hugging Face cache directory. local_files_only: If True, avoid downloading the file and return the path to the local cached file if it exists. + cache_dir: Path to the folder where cached files are stored. Returns: The path to the downloaded model. @@ -66,6 +68,9 @@ def download_model( kwargs["local_dir"] = output_dir kwargs["local_dir_use_symlinks"] = False + if cache_dir is not None: + kwargs["cache_dir"] = cache_dir + allow_patterns = [ "config.json", "model.bin", diff --git a/tests/test_utils.py b/tests/test_utils.py index 3e981f6..ee404bf 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,3 +15,9 @@ def test_download_model(tmpdir): for filename in os.listdir(model_dir): path = os.path.join(model_dir, filename) assert not os.path.islink(path) + + +def test_download_model_in_cache(tmpdir): + cache_dir = str(tmpdir.join("model")) + download_model("tiny", cache_dir=cache_dir) + assert os.path.isdir(cache_dir)