diff --git a/README.md b/README.md index daee860..54d7445 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,18 @@ ct2-transformers-converter --model openai/whisper-large-v2 --output_dir whisper- Models can also be converted from the code. See the [conversion API](https://opennmt.net/CTranslate2/python/ctranslate2.converters.TransformersConverter.html). +### Load a converted model + +1. Directly load the model from a local directory: +```python +model = faster_whisper.WhisperModel('whisper-large-v2-ct2') +``` + +2. [Upload your model to the Hugging Face Hub](https://huggingface.co/docs/transformers/model_sharing#upload-with-the-web-interface) and load it from its name: +```python +model = faster_whisper.WhisperModel('username/whisper-large-v2-ct2') +``` + ## Comparing performance against other implementations If you are comparing the performance against other Whisper implementations, you should make sure to run the comparison with similar settings. In particular: diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 017d398..cfb2e8a 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -88,8 +88,9 @@ class WhisperModel: Args: model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en, - small, small.en, medium, medium.en, large-v1, or large-v2) or a path to a converted - model directory. When a size is configured, the converted model is downloaded + small, small.en, medium, medium.en, large-v1, or large-v2), a path to a converted + model directory, or a CTranslate2-converted Whisper model ID from the Hugging Face Hub. + When a size or a model ID is configured, the converted model is downloaded from the Hugging Face Hub. device: Device to use for computation ("cpu", "cuda", "auto"). device_index: Device ID to use. diff --git a/faster_whisper/utils.py b/faster_whisper/utils.py index 950b0da..e86b89e 100644 --- a/faster_whisper/utils.py +++ b/faster_whisper/utils.py @@ -1,5 +1,6 @@ import logging import os +import re from typing import Optional @@ -33,7 +34,7 @@ def get_logger(): def download_model( - size: str, + size_or_id: str, output_dir: Optional[str] = None, local_files_only: bool = False, cache_dir: Optional[str] = None, @@ -43,8 +44,9 @@ def download_model( The model is downloaded from https://huggingface.co/guillaumekln. Args: - size: Size of the model to download (tiny, tiny.en, base, base.en, small, small.en, - medium, medium.en, large-v1, or large-v2). + size_or_id: Size of the model to download (tiny, tiny.en, base, base.en, small, small.en, + medium, medium.en, large-v1, or large-v2), or a CTranslate2-converted model ID + from the Hugging Face Hub (e.g. guillaumekln/faster-whisper-large-v2). output_dir: Directory where the model should be saved. If not set, the model is saved in the cache directory. local_files_only: If True, avoid downloading the file and return the path to the local @@ -57,12 +59,16 @@ def download_model( Raises: ValueError: if the model size is invalid. """ - if size not in _MODELS: - raise ValueError( - "Invalid model size '%s', expected one of: %s" % (size, ", ".join(_MODELS)) - ) + if re.match(r".*/.*", size_or_id): + repo_id = size_or_id + else: + if size_or_id not in _MODELS: + raise ValueError( + "Invalid model size '%s', expected one of: %s" + % (size_or_id, ", ".join(_MODELS)) + ) - repo_id = "guillaumekln/faster-whisper-%s" % size + repo_id = "guillaumekln/faster-whisper-%s" % size_or_id allow_patterns = [ "config.json",