Use cache_dir instead of local_dir (#182)

* Use cache_dir instead of local_dir

* Fix unit test

* Use cache_dir and preserve local_dir parameter

* Remove blank line at the end

* Disable unit test

* Implement download_root suggestion

* Use cache_dir=download_root
This commit is contained in:
Jordi Mas
2023-04-26 16:35:18 +02:00
committed by GitHub
parent 67cce3f552
commit 68df3214ba
3 changed files with 14 additions and 1 deletions

View File

@@ -112,7 +112,9 @@ class WhisperModel:
model_path = model_size_or_path model_path = model_size_or_path
else: else:
model_path = download_model( model_path = download_model(
model_size_or_path, download_root, local_files_only model_size_or_path,
local_files_only=local_files_only,
cache_dir=download_root,
) )
self.model = ctranslate2.models.Whisper( self.model = ctranslate2.models.Whisper(

View File

@@ -35,6 +35,7 @@ def download_model(
size: str, size: str,
output_dir: Optional[str] = None, output_dir: Optional[str] = None,
local_files_only: Optional[bool] = False, local_files_only: Optional[bool] = False,
cache_dir: Optional[str] = None,
): ):
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub. """Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
@@ -47,6 +48,7 @@ def download_model(
the standard Hugging Face cache directory. the standard Hugging Face cache directory.
local_files_only: If True, avoid downloading the file and return the path to the local local_files_only: If True, avoid downloading the file and return the path to the local
cached file if it exists. cached file if it exists.
cache_dir: Path to the folder where cached files are stored.
Returns: Returns:
The path to the downloaded model. The path to the downloaded model.
@@ -66,6 +68,9 @@ def download_model(
kwargs["local_dir"] = output_dir kwargs["local_dir"] = output_dir
kwargs["local_dir_use_symlinks"] = False kwargs["local_dir_use_symlinks"] = False
if cache_dir is not None:
kwargs["cache_dir"] = cache_dir
allow_patterns = [ allow_patterns = [
"config.json", "config.json",
"model.bin", "model.bin",

View File

@@ -15,3 +15,9 @@ def test_download_model(tmpdir):
for filename in os.listdir(model_dir): for filename in os.listdir(model_dir):
path = os.path.join(model_dir, filename) path = os.path.join(model_dir, filename)
assert not os.path.islink(path) assert not os.path.islink(path)
def test_download_model_in_cache(tmpdir):
    """Download the "tiny" model with an explicit cache_dir and check that folder exists."""
    cache_root = str(tmpdir.join("model"))

    # The returned model path is ignored; only the cache folder is checked.
    download_model("tiny", cache_dir=cache_root)

    assert os.path.isdir(cache_root)