Automatically download converted models from the Hugging Face Hub (#70)

* Automatically download converted models from the Hugging Face Hub

* Remove unused import

* Remove non needed requirements in dev mode

* Remove extra index URL when pip install in CI

* Allow downloading to a specific directory

* Update docstring

* Add argument to disable the progress bars

* Fix typo in docstring
Guillaume Klein authored on 2023-03-24 10:55:55 +01:00, committed by GitHub
parent 523ae2180f
commit de7682a2f0
10 changed files with 105 additions and 53 deletions
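Taken together, the diffs below enable the workflow sketched here. This is a minimal sketch assuming the package from this commit is installed and a local `audio.mp3` exists; `whisper-tiny-ct2` is a hypothetical directory name:

```python
from faster_whisper import WhisperModel, download_model

# New behavior: passing a size downloads the converted model from the
# Hugging Face Hub (https://huggingface.co/guillaumekln) and caches it.
model = WhisperModel("tiny", device="cpu", compute_type="int8")
segments, info = model.transcribe("audio.mp3", beam_size=5)

# The download can also be triggered explicitly, to a chosen directory
# and without progress bars (both options are added by this commit).
model_dir = download_model(
    "tiny",
    output_dir="whisper-tiny-ct2",
    show_progress_bars=False,
)
model = WhisperModel(model_dir)
```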

.github/workflows/ci.yml

@@ -25,7 +25,7 @@ jobs:
       - name: Install module
         run: |
           pip install wheel
-          pip install .[dev] --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e .[dev]

      - name: Check code format with Black
        run: |
@@ -55,11 +55,11 @@ jobs:
       - name: Install module
         run: |
           pip install wheel
-          pip install .[dev] --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e .[dev]

      - name: Run pytest
        run: |
-          pytest -v tests/test.py
+          pytest -v tests/

  build-and-push-package:

README.md

@@ -44,12 +44,6 @@ The module can be installed from [PyPI](https://pypi.org/project/faster-whisper/):
 pip install faster-whisper
 ```

-The model conversion script requires the modules `transformers` and `torch` which can be installed with the `[conversion]` extra requirement:
-
-```bash
-pip install faster-whisper[conversion]
-```
-
 **Other installation methods:**

 ```bash
@@ -70,35 +64,20 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed
 ## Usage

-### Model conversion
-
-A Whisper model should be first converted into the CTranslate2 format. We provide a script to download and convert models from the [Hugging Face model repository](https://huggingface.co/models?sort=downloads&search=whisper).
-
-For example the command below converts the "large-v2" Whisper model and saves the weights in FP16:
-
-```bash
-ct2-transformers-converter --model openai/whisper-large-v2 --output_dir whisper-large-v2-ct2 \
-    --copy_files tokenizer.json --quantization float16
-```
-
-If the option `--copy_files tokenizer.json` is not used, the tokenizer configuration is automatically downloaded when the model is loaded later.
-
-Models can also be converted from the code. See the [conversion API](https://opennmt.net/CTranslate2/python/ctranslate2.converters.TransformersConverter.html).
-
 ### Transcription

 ```python
 from faster_whisper import WhisperModel

-model_path = "whisper-large-v2-ct2/"
+model_size = "large-v2"

 # Run on GPU with FP16
-model = WhisperModel(model_path, device="cuda", compute_type="float16")
+model = WhisperModel(model_size, device="cuda", compute_type="float16")

 # or run on GPU with INT8
-# model = WhisperModel(model_path, device="cuda", compute_type="int8_float16")
+# model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
 # or run on CPU with INT8
-# model = WhisperModel(model_path, device="cpu", compute_type="int8")
+# model = WhisperModel(model_size, device="cpu", compute_type="int8")

 segments, info = model.transcribe("audio.mp3", beam_size=5)
@@ -120,6 +99,26 @@ for segment in segments:
 See more model and transcription options in the [`WhisperModel`](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py) class implementation.

+## Model conversion
+
+When loading a model from its size such as `WhisperModel("large-v2")`, the corresponding CTranslate2 model is automatically downloaded from the [Hugging Face Hub](https://huggingface.co/guillaumekln).
+
+We also provide a script to convert any Whisper model compatible with the Transformers library. These can be the original OpenAI models or user fine-tuned models.
+
+For example, the command below converts the [original "large-v2" Whisper model](https://huggingface.co/openai/whisper-large-v2) and saves the weights in FP16:
+
+```bash
+pip install transformers[torch]>=4.23
+ct2-transformers-converter --model openai/whisper-large-v2 --output_dir whisper-large-v2-ct2 \
+    --copy_files tokenizer.json --quantization float16
+```
+
+* The option `--model` accepts a model name on the Hub or a path to a model directory.
+* If the option `--copy_files tokenizer.json` is not used, the tokenizer configuration is automatically downloaded when the model is loaded later.
+
+Models can also be converted from the code. See the [conversion API](https://opennmt.net/CTranslate2/python/ctranslate2.converters.TransformersConverter.html).
+
 ## Comparing performance against other implementations

 If you are comparing the performance against other Whisper implementations, you should make sure to run the comparison with similar settings. In particular:
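For the "converted from the code" route that the new README section points to, a minimal sketch; the call pattern mirrors the converter helper this commit removes from `tests/conftest.py` further down, so the argument names come from that code rather than from the README:

```python
import ctranslate2

# Python equivalent of the ct2-transformers-converter command above.
ctranslate2.converters.TransformersConverter(
    "openai/whisper-large-v2",
    copy_files=["tokenizer.json"],  # optional, see the note above
    load_as_float16=True,
).convert("whisper-large-v2-ct2", quantization="float16")
```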

faster_whisper/__init__.py

@@ -1,9 +1,10 @@
 from faster_whisper.audio import decode_audio
 from faster_whisper.transcribe import WhisperModel
-from faster_whisper.utils import format_timestamp
+from faster_whisper.utils import download_model, format_timestamp

 __all__ = [
     "decode_audio",
     "WhisperModel",
+    "download_model",
     "format_timestamp",
 ]

faster_whisper/transcribe.py

@@ -11,6 +11,7 @@ import tokenizers
 from faster_whisper.audio import decode_audio
 from faster_whisper.feature_extractor import FeatureExtractor
 from faster_whisper.tokenizer import Tokenizer
+from faster_whisper.utils import download_model


 class Word(NamedTuple):
@@ -57,7 +58,7 @@ class TranscriptionOptions(NamedTuple):
 class WhisperModel:
     def __init__(
         self,
-        model_path: str,
+        model_size_or_path: str,
         device: str = "auto",
         device_index: Union[int, List[int]] = 0,
         compute_type: str = "default",
@@ -67,7 +68,9 @@ class WhisperModel:
         """Initializes the Whisper model.

         Args:
-          model_path: Path to the converted model.
+          model_size_or_path: Size of the model to use (e.g. "large-v2", "small", "tiny.en", etc.)
+            or a path to a converted model directory. When a size is configured, the converted
+            model is downloaded from the Hugging Face Hub.
           device: Device to use for computation ("cpu", "cuda", "auto").
           device_index: Device ID to use.
             The model can also be loaded on multiple GPUs by passing a list of IDs
@@ -82,6 +85,11 @@ class WhisperModel:
             (concurrent calls to self.model.generate() will run in parallel).
             This can improve the global throughput at the cost of increased memory usage.
         """
+        if os.path.isdir(model_size_or_path):
+            model_path = model_size_or_path
+        else:
+            model_path = download_model(model_size_or_path)
+
         self.model = ctranslate2.models.Whisper(
             model_path,
             device=device,
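Both call styles now go through the same constructor; a small sketch of the two paths (the local directory is a hypothetical example and must already contain a converted model):

```python
from faster_whisper import WhisperModel

# A size: resolved via download_model(), fetched from the Hub on first use.
model = WhisperModel("tiny.en")

# A directory: os.path.isdir() is true, so it is loaded directly, no download.
model = WhisperModel("whisper-large-v2-ct2/")
```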

faster_whisper/utils.py

@@ -1,3 +1,42 @@
+from typing import Optional
+
+import huggingface_hub
+from tqdm.auto import tqdm
+
+
+def download_model(
+    size: str,
+    output_dir: Optional[str] = None,
+    show_progress_bars: bool = True,
+):
+    """Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
+
+    The model is downloaded from https://huggingface.co/guillaumekln.
+
+    Args:
+      size: Size of the model to download (tiny, tiny.en, base, base.en, small, small.en,
+        medium, medium.en, or large-v2).
+      output_dir: Directory where the model should be saved. If not set, the model is saved in
+        the standard Hugging Face cache directory.
+      show_progress_bars: Show the tqdm progress bars during the download.
+
+    Returns:
+      The path to the downloaded model.
+    """
+    repo_id = "guillaumekln/faster-whisper-%s" % size
+    kwargs = {}
+
+    if output_dir is not None:
+        kwargs["local_dir"] = output_dir
+        kwargs["local_dir_use_symlinks"] = False
+
+    if not show_progress_bars:
+        kwargs["tqdm_class"] = disabled_tqdm
+
+    return huggingface_hub.snapshot_download(repo_id, **kwargs)
+
+
 def format_timestamp(
     seconds: float,
     always_include_hours: bool = False,
@@ -19,3 +58,9 @@ def format_timestamp(
     return (
         f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
     )
+
+
+class disabled_tqdm(tqdm):
+    def __init__(self, *args, **kwargs):
+        kwargs["disable"] = True
+        super().__init__(*args, **kwargs)
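A short usage sketch of the new helper, exercising both options added in this commit (the target directory name is a hypothetical example):

```python
from faster_whisper import download_model

# Download into the default Hugging Face cache, with progress bars.
cached_dir = download_model("base.en")

# Download real files (no symlinks) into a chosen directory, quietly.
local_dir = download_model(
    "tiny",
    output_dir="models/faster-whisper-tiny",  # hypothetical path
    show_progress_bars=False,
)
print(local_dir)  # == "models/faster-whisper-tiny"
```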

requirements.txt

@@ -1,3 +1,4 @@
 av==10.*
 ctranslate2>=3.10,<4
+huggingface_hub>=0.13
 tokenizers==0.13.*

setup.py

@@ -48,8 +48,7 @@ setup(
     install_requires=install_requires,
     extras_require={
         "conversion": conversion_requires,
-        "dev": conversion_requires
-        + [
+        "dev": [
             "black==23.*",
             "flake8==6.*",
             "isort==5.*",

tests/conftest.py

@@ -1,6 +1,5 @@
 import os

-import ctranslate2
 import pytest
@@ -12,20 +11,3 @@ def data_dir():
 @pytest.fixture
 def jfk_path(data_dir):
     return os.path.join(data_dir, "jfk.flac")
-
-
-@pytest.fixture(scope="session")
-def tiny_model_dir(tmp_path_factory):
-    model_path = str(tmp_path_factory.mktemp("data") / "model")
-    convert_model("tiny", model_path)
-    return model_path
-
-
-def convert_model(size, output_dir):
-    name = "openai/whisper-%s" % size
-    ctranslate2.converters.TransformersConverter(
-        name,
-        copy_files=["tokenizer.json"],
-        load_as_float16=True,
-    ).convert(output_dir, quantization="float16")

tests/test_transcribe.py

@@ -1,8 +1,8 @@
 from faster_whisper import WhisperModel


-def test_transcribe(tiny_model_dir, jfk_path):
-    model = WhisperModel(tiny_model_dir)
+def test_transcribe(jfk_path):
+    model = WhisperModel("tiny")
     segments, info = model.transcribe(jfk_path, word_timestamps=True)

     assert info.language == "en"

tests/test_utils.py (new file)

@@ -0,0 +1,17 @@
+import os
+
+from faster_whisper import download_model
+
+
+def test_download_model(tmpdir):
+    output_dir = str(tmpdir.join("model"))
+
+    model_dir = download_model("tiny", output_dir=output_dir)
+
+    assert model_dir == output_dir
+    assert os.path.isdir(model_dir)
+    assert not os.path.islink(model_dir)
+
+    for filename in os.listdir(model_dir):
+        path = os.path.join(model_dir, filename)
+        assert not os.path.islink(path)