Add "large" alias for "large-v2" model (#453)

This commit is contained in:
Guillaume Klein
2023-09-04 11:54:42 +02:00
committed by GitHub
parent f0ff12965a
commit 1e6eb967c9
2 changed files with 21 additions and 22 deletions

View File

@@ -92,7 +92,7 @@ class WhisperModel:
Args:
model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
small, small.en, medium, medium.en, large-v1, or large-v2), a path to a converted
small, small.en, medium, medium.en, large-v1, large-v2, or large), a path to a converted
model directory, or a CTranslate2-converted Whisper model ID from the Hugging Face Hub.
When a size or a model ID is configured, the converted model is downloaded
from the Hugging Face Hub.

View File

@@ -9,18 +9,19 @@ import requests
from tqdm.auto import tqdm
_MODELS = (
"tiny.en",
"tiny",
"base.en",
"base",
"small.en",
"small",
"medium.en",
"medium",
"large-v1",
"large-v2",
)
_MODELS = {
"tiny.en": "guillaumekln/faster-whisper-tiny.en",
"tiny": "guillaumekln/faster-whisper-tiny",
"base.en": "guillaumekln/faster-whisper-base.en",
"base": "guillaumekln/faster-whisper-base",
"small.en": "guillaumekln/faster-whisper-small.en",
"small": "guillaumekln/faster-whisper-small",
"medium.en": "guillaumekln/faster-whisper-medium.en",
"medium": "guillaumekln/faster-whisper-medium",
"large-v1": "guillaumekln/faster-whisper-large-v1",
"large-v2": "guillaumekln/faster-whisper-large-v2",
"large": "guillaumekln/faster-whisper-large-v2",
}
def get_assets_path():
@@ -41,12 +42,11 @@ def download_model(
):
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
The model is downloaded from https://huggingface.co/guillaumekln.
Args:
size_or_id: Size of the model to download (tiny, tiny.en, base, base.en, small, small.en,
medium, medium.en, large-v1, or large-v2), or a CTranslate2-converted model ID
from the Hugging Face Hub (e.g. guillaumekln/faster-whisper-large-v2).
size_or_id: Size of the model to download from https://huggingface.co/guillaumekln
(tiny, tiny.en, base, base.en, small, small.en medium, medium.en, large-v1, large-v2,
large), or a CTranslate2-converted model ID from the Hugging Face Hub
(e.g. guillaumekln/faster-whisper-large-v2).
output_dir: Directory where the model should be saved. If not set, the model is saved in
the cache directory.
local_files_only: If True, avoid downloading the file and return the path to the local
@@ -62,14 +62,13 @@ def download_model(
if re.match(r".*/.*", size_or_id):
repo_id = size_or_id
else:
if size_or_id not in _MODELS:
repo_id = _MODELS.get(size_or_id)
if repo_id is None:
raise ValueError(
"Invalid model size '%s', expected one of: %s"
% (size_or_id, ", ".join(_MODELS))
% (size_or_id, ", ".join(_MODELS.keys()))
)
repo_id = "guillaumekln/faster-whisper-%s" % size_or_id
allow_patterns = [
"config.json",
"model.bin",