Expose download location in WhisperModel constructor (#126)

This increases compatibility with OpenAI Whisper's `whisper.load_model()` and is useful for downstream integrations.
This commit is contained in:
Ewald Enzinger
2023-04-08 10:02:36 +02:00
committed by GitHub
parent 06d24056e9
commit 2b53dee6b6

View File

@@ -72,6 +72,7 @@ class WhisperModel:
compute_type: str = "default",
cpu_threads: int = 0,
num_workers: int = 1,
download_root: Optional[str] = None,
):
"""Initializes the Whisper model.
@@ -93,13 +94,15 @@ class WhisperModel:
having multiple workers enables true parallelism when running the model
(concurrent calls to self.model.generate() will run in parallel).
This can improve the global throughput at the cost of increased memory usage.
download_root: Directory where the model should be saved. If not set, the model
is saved in the standard Hugging Face cache directory.
"""
self.logger = get_logger()
if os.path.isdir(model_size_or_path):
model_path = model_size_or_path
else:
model_path = download_model(model_size_or_path)
model_path = download_model(model_size_or_path, download_root)
self.model = ctranslate2.models.Whisper(
model_path,