diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index bab6d44..911928f 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -68,9 +68,10 @@ class WhisperModel: """Initializes the Whisper model. Args: - model_size_or_path: Size of the model to use (e.g. "large-v2", "small", "tiny.en", etc.) - or a path to a converted model directory. When a size is configured, the converted - model is downloaded from the Hugging Face Hub. + model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en, + small, small.en, medium, medium.en, or large-v2) or a path to a converted + model directory. When a size is configured, the converted model is downloaded + from the Hugging Face Hub. device: Device to use for computation ("cpu", "cuda", "auto"). device_index: Device ID to use. The model can also be loaded on multiple GPUs by passing a list of IDs diff --git a/faster_whisper/utils.py b/faster_whisper/utils.py index bfd2f17..b97c9a1 100644 --- a/faster_whisper/utils.py +++ b/faster_whisper/utils.py @@ -4,6 +4,18 @@ import huggingface_hub from tqdm.auto import tqdm +_MODELS = ( + "tiny.en", + "tiny", + "base.en", + "base", + "small.en", + "small", + "medium.en", + "medium", + "large-v2", +) + def download_model( size: str, @@ -23,7 +35,15 @@ def download_model( Returns: The path to the downloaded model. + + Raises: + ValueError: if the model size is invalid. """ + if size not in _MODELS: + raise ValueError( + "Invalid model size '%s', expected one of: %s" % (size, ", ".join(_MODELS)) + ) + repo_id = "guillaumekln/faster-whisper-%s" % size kwargs = {}