Raise an explicit error message if the model size is invalid
This commit is contained in:
@@ -68,9 +68,10 @@ class WhisperModel:
|
|||||||
"""Initializes the Whisper model.
|
"""Initializes the Whisper model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_size_or_path: Size of the model to use (e.g. "large-v2", "small", "tiny.en", etc.)
|
model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
|
||||||
or a path to a converted model directory. When a size is configured, the converted
|
small, small.en, medium, medium.en, or large-v2) or a path to a converted
|
||||||
model is downloaded from the Hugging Face Hub.
|
model directory. When a size is configured, the converted model is downloaded
|
||||||
|
from the Hugging Face Hub.
|
||||||
device: Device to use for computation ("cpu", "cuda", "auto").
|
device: Device to use for computation ("cpu", "cuda", "auto").
|
||||||
device_index: Device ID to use.
|
device_index: Device ID to use.
|
||||||
The model can also be loaded on multiple GPUs by passing a list of IDs
|
The model can also be loaded on multiple GPUs by passing a list of IDs
|
||||||
|
|||||||
@@ -4,6 +4,18 @@ import huggingface_hub
|
|||||||
|
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
_MODELS = (
|
||||||
|
"tiny.en",
|
||||||
|
"tiny",
|
||||||
|
"base.en",
|
||||||
|
"base",
|
||||||
|
"small.en",
|
||||||
|
"small",
|
||||||
|
"medium.en",
|
||||||
|
"medium",
|
||||||
|
"large-v2",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def download_model(
|
def download_model(
|
||||||
size: str,
|
size: str,
|
||||||
@@ -23,7 +35,15 @@ def download_model(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The path to the downloaded model.
|
The path to the downloaded model.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if the model size is invalid.
|
||||||
"""
|
"""
|
||||||
|
if size not in _MODELS:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid model size '%s', expected one of: %s" % (size, ", ".join(_MODELS))
|
||||||
|
)
|
||||||
|
|
||||||
repo_id = "guillaumekln/faster-whisper-%s" % size
|
repo_id = "guillaumekln/faster-whisper-%s" % size
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user