Allow specifying local_files_only to prevent checking the Internet every time (#166)

Jordi Mas authored on 2023-04-20 14:26:06 +02:00 (committed by GitHub)
parent 3adcc12d0f
commit 358d373691
2 changed files with 14 additions and 3 deletions


@@ -73,6 +73,7 @@ class WhisperModel:
         cpu_threads: int = 0,
         num_workers: int = 1,
         download_root: Optional[str] = None,
+        local_files_only: Optional[bool] = False,
     ):
         """Initializes the Whisper model.

@@ -96,13 +97,17 @@ class WhisperModel:
             This can improve the global throughput at the cost of increased memory usage.
           download_root: Directory where the model should be saved. If not set, the model
             is saved in the standard Hugging Face cache directory.
+          local_files_only: If True, avoid downloading the file and return the path to the
+            local cached file if it exists.
         """
         self.logger = get_logger()

         if os.path.isdir(model_size_or_path):
             model_path = model_size_or_path
         else:
-            model_path = download_model(model_size_or_path, download_root)
+            model_path = download_model(
+                model_size_or_path, download_root, local_files_only
+            )

         self.model = ctranslate2.models.Whisper(
             model_path,
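
A minimal usage sketch of the new argument, assuming WhisperModel is importable from the faster_whisper package top level as usual and that "small" is one of the published model sizes:

    from faster_whisper import WhisperModel

    # Reuse a model that is already in the local Hugging Face cache.
    # With local_files_only=True no download is attempted; if the model has
    # never been fetched, the underlying Hub call is expected to fail
    # instead of reaching out to the Internet.
    model = WhisperModel("small", local_files_only=True)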


@@ -31,7 +31,11 @@ def get_logger():
     return logging.getLogger("faster_whisper")


-def download_model(size: str, output_dir: Optional[str] = None):
+def download_model(
+    size: str,
+    output_dir: Optional[str] = None,
+    local_files_only: Optional[bool] = False,
+):
     """Downloads a CTranslate2 Whisper model from the Hugging Face Hub.

     The model is downloaded from https://huggingface.co/guillaumekln.
@@ -41,6 +45,8 @@ def download_model(size: str, output_dir: Optional[str] = None):
         medium, medium.en, large-v1, or large-v2).
       output_dir: Directory where the model should be saved. If not set, the model is saved in
         the standard Hugging Face cache directory.
+      local_files_only: If True, avoid downloading the file and return the path to the local
+        cached file if it exists.

     Returns:
       The path to the downloaded model.
@@ -55,7 +61,7 @@ def download_model(size: str, output_dir: Optional[str] = None):
     repo_id = "guillaumekln/faster-whisper-%s" % size

     kwargs = {}
-
+    kwargs["local_files_only"] = local_files_only
     if output_dir is not None:
         kwargs["local_dir"] = output_dir
         kwargs["local_dir_use_symlinks"] = False
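
A hypothetical sketch of calling the download helper directly, assuming download_model stays importable from faster_whisper.utils and that the kwargs built above are forwarded to the Hugging Face Hub download call:

    from faster_whisper.utils import download_model

    # Resolve the CTranslate2 model directory from the local cache only;
    # with local_files_only=True nothing is fetched from the Internet.
    model_path = download_model("medium", local_files_only=True)
    print("Cached model found at:", model_path)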