Compare commits

..

3 Commits

Author SHA1 Message Date
9f24e2c735 Merge branch 'master' into prompt 2023-06-24 18:03:05 +08:00
9a646b69e6 format code 2023-04-20 02:00:57 +08:00
49af9564ab Ignore repeated prompt 2023-04-20 01:49:10 +08:00
15 changed files with 148 additions and 778 deletions

View File

@@ -7,7 +7,7 @@ Contributions are welcome! Here are some pointers to help you install the librar
We recommend installing the module in editable mode with the `dev` extra requirements: We recommend installing the module in editable mode with the `dev` extra requirements:
```bash ```bash
git clone https://github.com/SYSTRAN/faster-whisper.git git clone https://github.com/guillaumekln/faster-whisper.git
cd faster-whisper/ cd faster-whisper/
pip install -e .[dev] pip install -e .[dev]
``` ```

View File

@@ -1,6 +1,6 @@
MIT License MIT License
Copyright (c) 2023 SYSTRAN Copyright (c) 2023 Guillaume Klein
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal

146
README.md
View File

@@ -1,4 +1,4 @@
[![CI](https://github.com/SYSTRAN/faster-whisper/workflows/CI/badge.svg)](https://github.com/SYSTRAN/faster-whisper/actions?query=workflow%3ACI) [![PyPI version](https://badge.fury.io/py/faster-whisper.svg)](https://badge.fury.io/py/faster-whisper) [![CI](https://github.com/guillaumekln/faster-whisper/workflows/CI/badge.svg)](https://github.com/guillaumekln/faster-whisper/actions?query=workflow%3ACI) [![PyPI version](https://badge.fury.io/py/faster-whisper.svg)](https://badge.fury.io/py/faster-whisper)
# Faster Whisper transcription with CTranslate2 # Faster Whisper transcription with CTranslate2
@@ -8,13 +8,11 @@ This implementation is up to 4 times faster than [openai/whisper](https://github
## Benchmark ## Benchmark
### Whisper
For reference, here's the time and memory usage that are required to transcribe [**13 minutes**](https://www.youtube.com/watch?v=0u7tTptBo9I) of audio using different implementations: For reference, here's the time and memory usage that are required to transcribe [**13 minutes**](https://www.youtube.com/watch?v=0u7tTptBo9I) of audio using different implementations:
* [openai/whisper](https://github.com/openai/whisper)@[6dea21fd](https://github.com/openai/whisper/commit/6dea21fd7f7253bfe450f1e2512a0fe47ee2d258) * [openai/whisper](https://github.com/openai/whisper)@[6dea21fd](https://github.com/openai/whisper/commit/6dea21fd7f7253bfe450f1e2512a0fe47ee2d258)
* [whisper.cpp](https://github.com/ggerganov/whisper.cpp)@[3b010f9](https://github.com/ggerganov/whisper.cpp/commit/3b010f9bed9a6068609e9faf52383aea792b0362) * [whisper.cpp](https://github.com/ggerganov/whisper.cpp)@[3b010f9](https://github.com/ggerganov/whisper.cpp/commit/3b010f9bed9a6068609e9faf52383aea792b0362)
* [faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[cce6b53e](https://github.com/SYSTRAN/faster-whisper/commit/cce6b53e4554f71172dad188c45f10fb100f6e3e) * [faster-whisper](https://github.com/guillaumekln/faster-whisper)@[cce6b53e](https://github.com/guillaumekln/faster-whisper/commit/cce6b53e4554f71172dad188c45f10fb100f6e3e)
### Large-v2 model on GPU ### Large-v2 model on GPU
@@ -38,71 +36,6 @@ For reference, here's the time and memory usage that are required to transcribe
*Executed with 8 threads on a Intel(R) Xeon(R) Gold 6226R.* *Executed with 8 threads on a Intel(R) Xeon(R) Gold 6226R.*
### Distil-whisper
| Implementation | Precision | Beam size | Time | Gigaspeech WER |
| --- | --- | --- | --- | --- |
| distil-whisper/distil-large-v2 | fp16 | 4 |- | 10.36 |
| [faster-distil-large-v2](https://huggingface.co/Systran/faster-distil-whisper-large-v2) | fp16 | 5 | - | 10.28 |
| distil-whisper/distil-medium.en | fp16 | 4 | - | 11.21 |
| [faster-distil-medium.en](https://huggingface.co/Systran/faster-distil-whisper-medium.en) | fp16 | 5 | - | 11.21 |
*Executed with CUDA 11.4 on a NVIDIA 3090.*
<details>
<summary>testing details (click to expand)</summary>
For `distil-whisper/distil-large-v2`, the WER is tested with code sample from [link](https://huggingface.co/distil-whisper/distil-large-v2#evaluation). for `faster-distil-whisper`, the WER is tested with setting:
```python
from faster_whisper import WhisperModel
model_size = "distil-large-v2"
# model_size = "distil-medium.en"
# Run on GPU with FP16
model = WhisperModel(model_size, device="cuda", compute_type="float16")
segments, info = model.transcribe("audio.mp3", beam_size=5, language="en")
```
</details>
## Requirements
* Python 3.8 or greater
Unlike openai-whisper, FFmpeg does **not** need to be installed on the system. The audio is decoded with the Python library [PyAV](https://github.com/PyAV-Org/PyAV) which bundles the FFmpeg libraries in its package.
### GPU
GPU execution requires the following NVIDIA libraries to be installed:
* [cuBLAS for CUDA 11](https://developer.nvidia.com/cublas)
* [cuDNN 8 for CUDA 11](https://developer.nvidia.com/cudnn)
There are multiple ways to install these libraries. The recommended way is described in the official NVIDIA documentation, but we also suggest other installation methods below.
<details>
<summary>Other installation methods (click to expand)</summary>
#### Use Docker
The libraries are installed in this official NVIDIA Docker image: `nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04`.
#### Install with `pip` (Linux only)
On Linux these libraries can be installed with `pip`. Note that `LD_LIBRARY_PATH` must be set before launching Python.
```bash
pip install nvidia-cublas-cu11 nvidia-cudnn-cu11
export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'`
```
#### Download the libraries from Purfview's repository (Windows & Linux)
Purfview's [whisper-standalone-win](https://github.com/Purfview/whisper-standalone-win) provides the required NVIDIA libraries for Windows & Linux in a [single archive](https://github.com/Purfview/whisper-standalone-win/releases/tag/libs). Decompress the archive and place the libraries in a directory included in the `PATH`.
</details>
## Installation ## Installation
The module can be installed from [PyPI](https://pypi.org/project/faster-whisper/): The module can be installed from [PyPI](https://pypi.org/project/faster-whisper/):
@@ -111,31 +44,26 @@ The module can be installed from [PyPI](https://pypi.org/project/faster-whisper/
pip install faster-whisper pip install faster-whisper
``` ```
<details> **Other installation methods:**
<summary>Other installation methods (click to expand)</summary>
### Install the master branch
```bash ```bash
pip install --force-reinstall "faster-whisper @ https://github.com/SYSTRAN/faster-whisper/archive/refs/heads/master.tar.gz" # Install the master branch:
pip install --force-reinstall "faster-whisper @ https://github.com/guillaumekln/faster-whisper/archive/refs/heads/master.tar.gz"
# Install a specific commit:
pip install --force-reinstall "faster-whisper @ https://github.com/guillaumekln/faster-whisper/archive/a4f1cc8f11433e454c3934442b5e1a4ed5e865c3.tar.gz"
``` ```
### Install a specific commit ### GPU support
```bash GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
pip install --force-reinstall "faster-whisper @ https://github.com/SYSTRAN/faster-whisper/archive/a4f1cc8f11433e454c3934442b5e1a4ed5e865c3.tar.gz"
```
</details>
## Usage ## Usage
### Faster-whisper
```python ```python
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
model_size = "large-v3" model_size = "large-v2"
# Run on GPU with FP16 # Run on GPU with FP16
model = WhisperModel(model_size, device="cuda", compute_type="float16") model = WhisperModel(model_size, device="cuda", compute_type="float16")
@@ -159,25 +87,6 @@ for segment in segments:
segments, _ = model.transcribe("audio.mp3") segments, _ = model.transcribe("audio.mp3")
segments = list(segments) # The transcription will actually run here. segments = list(segments) # The transcription will actually run here.
``` ```
### Faster Distil-Whisper
The Distil-Whisper checkpoints are compatible with the Faster-Whisper package. In particular, the latest [distil-large-v3](https://huggingface.co/distil-whisper/distil-large-v3)
checkpoint is intrinsically designed to work with the Faster-Whisper transcription algorithm. The following code snippet
demonstrates how to run inference with distil-large-v3 on a specified audio file:
```python
from faster_whisper import WhisperModel
model_size = "distil-large-v3"
model = WhisperModel(model_size, device="cuda", compute_type="float16")
segments, info = model.transcribe("audio.mp3", beam_size=5, language="en", condition_on_previous_text=False)
for segment in segments:
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
```
For more information about the distil-large-v3 model, refer to the original [model card](https://huggingface.co/distil-whisper/distil-large-v3).
### Word-level timestamps ### Word-level timestamps
@@ -197,7 +106,7 @@ The library integrates the [Silero VAD](https://github.com/snakers4/silero-vad)
segments, _ = model.transcribe("audio.mp3", vad_filter=True) segments, _ = model.transcribe("audio.mp3", vad_filter=True)
``` ```
The default behavior is conservative and only removes silence longer than 2 seconds. See the available VAD parameters and default values in the [source code](https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py). They can be customized with the dictionary argument `vad_parameters`: The default behavior is conservative and only removes silence longer than 2 seconds. See the available VAD parameters and default values in the [source code](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/vad.py). They can be customized with the dictionary argument `vad_parameters`:
```python ```python
segments, _ = model.transcribe( segments, _ = model.transcribe(
@@ -220,38 +129,31 @@ logging.getLogger("faster_whisper").setLevel(logging.DEBUG)
### Going further ### Going further
See more model and transcription options in the [`WhisperModel`](https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/transcribe.py) class implementation. See more model and transcription options in the [`WhisperModel`](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py) class implementation.
## Community integrations ## Community integrations
Here is a non exhaustive list of open-source projects using faster-whisper. Feel free to add your project to the list! Here is a non exhaustive list of open-source projects using faster-whisper. Feel free to add your project to the list!
* [WhisperX](https://github.com/m-bain/whisperX) is an award-winning Python library that offers speaker diarization and accurate word-level timestamps using wav2vec2 alignment
* [whisper-ctranslate2](https://github.com/Softcatala/whisper-ctranslate2) is a command line client based on faster-whisper and compatible with the original client from openai/whisper. * [whisper-ctranslate2](https://github.com/Softcatala/whisper-ctranslate2) is a command line client based on faster-whisper and compatible with the original client from openai/whisper.
* [whisper-diarize](https://github.com/MahmoudAshraf97/whisper-diarization) is a speaker diarization tool that is based on faster-whisper and NVIDIA NeMo. * [whisper-diarize](https://github.com/MahmoudAshraf97/whisper-diarization) is a speaker diarization tool that is based on faster-whisper and NVIDIA NeMo.
* [whisper-standalone-win](https://github.com/Purfview/whisper-standalone-win) Standalone CLI executables of faster-whisper for Windows, Linux & macOS. * [whisper-standalone-win](https://github.com/Purfview/whisper-standalone-win) contains the portable ready to run binaries of faster-whisper for Windows.
* [asr-sd-pipeline](https://github.com/hedrergudene/asr-sd-pipeline) provides a scalable, modular, end to end multi-speaker speech to text solution implemented using AzureML pipelines. * [asr-sd-pipeline](https://github.com/hedrergudene/asr-sd-pipeline) provides a scalable, modular, end to end multi-speaker speech to text solution implemented using AzureML pipelines.
* [Open-Lyrics](https://github.com/zh-plus/Open-Lyrics) is a Python library that transcribes voice files using faster-whisper, and translates/polishes the resulting text into `.lrc` files in the desired language using OpenAI-GPT. * [Open-Lyrics](https://github.com/zh-plus/Open-Lyrics) is a Python library that transcribes voice files using faster-whisper, and translates/polishes the resulting text into `.lrc` files in the desired language using OpenAI-GPT.
* [wscribe](https://github.com/geekodour/wscribe) is a flexible transcript generation tool supporting faster-whisper, it can export word level transcript and the exported transcript then can be edited with [wscribe-editor](https://github.com/geekodour/wscribe-editor)
* [aTrain](https://github.com/BANDAS-Center/aTrain) is a graphical user interface implementation of faster-whisper developed at the BANDAS-Center at the University of Graz for transcription and diarization in Windows ([Windows Store App](https://apps.microsoft.com/detail/atrain/9N15Q44SZNS2)) and Linux.
* [Whisper-Streaming](https://github.com/ufal/whisper_streaming) implements real-time mode for offline Whisper-like speech-to-text models with faster-whisper as the most recommended back-end. It implements a streaming policy with self-adaptive latency based on the actual source complexity, and demonstrates the state of the art.
* [WhisperLive](https://github.com/collabora/WhisperLive) is a nearly-live implementation of OpenAI's Whisper which uses faster-whisper as the backend to transcribe audio in real-time.
* [Faster-Whisper-Transcriber](https://github.com/BBC-Esq/ctranslate2-faster-whisper-transcriber) is a simple but reliable voice transcriber that provides a user-friendly interface.
## Model conversion ## Model conversion
When loading a model from its size such as `WhisperModel("large-v3")`, the corresponding CTranslate2 model is automatically downloaded from the [Hugging Face Hub](https://huggingface.co/Systran). When loading a model from its size such as `WhisperModel("large-v2")`, the correspondig CTranslate2 model is automatically downloaded from the [Hugging Face Hub](https://huggingface.co/guillaumekln).
We also provide a script to convert any Whisper models compatible with the Transformers library. They could be the original OpenAI models or user fine-tuned models. We also provide a script to convert any Whisper models compatible with the Transformers library. They could be the original OpenAI models or user fine-tuned models.
For example the command below converts the [original "large-v3" Whisper model](https://huggingface.co/openai/whisper-large-v3) and saves the weights in FP16: For example the command below converts the [original "large-v2" Whisper model](https://huggingface.co/openai/whisper-large-v2) and saves the weights in FP16:
```bash ```bash
pip install transformers[torch]>=4.23 pip install transformers[torch]>=4.23
ct2-transformers-converter --model openai/whisper-large-v3 --output_dir whisper-large-v3-ct2 ct2-transformers-converter --model openai/whisper-large-v2 --output_dir whisper-large-v2-ct2 \
--copy_files tokenizer.json preprocessor_config.json --quantization float16 --copy_files tokenizer.json --quantization float16
``` ```
* The option `--model` accepts a model name on the Hub or a path to a model directory. * The option `--model` accepts a model name on the Hub or a path to a model directory.
@@ -259,18 +161,6 @@ ct2-transformers-converter --model openai/whisper-large-v3 --output_dir whisper-
Models can also be converted from the code. See the [conversion API](https://opennmt.net/CTranslate2/python/ctranslate2.converters.TransformersConverter.html). Models can also be converted from the code. See the [conversion API](https://opennmt.net/CTranslate2/python/ctranslate2.converters.TransformersConverter.html).
### Load a converted model
1. Directly load the model from a local directory:
```python
model = faster_whisper.WhisperModel("whisper-large-v3-ct2")
```
2. [Upload your model to the Hugging Face Hub](https://huggingface.co/docs/transformers/model_sharing#upload-with-the-web-interface) and load it from its name:
```python
model = faster_whisper.WhisperModel("username/whisper-large-v3-ct2")
```
## Comparing performance against other implementations ## Comparing performance against other implementations
If you are comparing the performance against other Whisper implementations, you should make sure to run the comparison with similar settings. In particular: If you are comparing the performance against other Whisper implementations, you should make sure to run the comparison with similar settings. In particular:

View File

@@ -1,10 +1,9 @@
from faster_whisper.audio import decode_audio from faster_whisper.audio import decode_audio
from faster_whisper.transcribe import WhisperModel from faster_whisper.transcribe import WhisperModel
from faster_whisper.utils import available_models, download_model, format_timestamp from faster_whisper.utils import download_model, format_timestamp
from faster_whisper.version import __version__ from faster_whisper.version import __version__
__all__ = [ __all__ = [
"available_models",
"decode_audio", "decode_audio",
"WhisperModel", "WhisperModel",
"download_model", "download_model",

View File

@@ -6,7 +6,6 @@ system dependencies. FFmpeg does not need to be installed on the system.
However, the API is quite low-level so we need to manipulate audio frames directly. However, the API is quite low-level so we need to manipulate audio frames directly.
""" """
import gc
import io import io
import itertools import itertools
@@ -43,7 +42,7 @@ def decode_audio(
raw_buffer = io.BytesIO() raw_buffer = io.BytesIO()
dtype = None dtype = None
with av.open(input_file, mode="r", metadata_errors="ignore") as container: with av.open(input_file, metadata_errors="ignore") as container:
frames = container.decode(audio=0) frames = container.decode(audio=0)
frames = _ignore_invalid_frames(frames) frames = _ignore_invalid_frames(frames)
frames = _group_frames(frames, 500000) frames = _group_frames(frames, 500000)
@@ -54,11 +53,6 @@ def decode_audio(
dtype = array.dtype dtype = array.dtype
raw_buffer.write(array) raw_buffer.write(array)
# It appears that some objects related to the resampler are not freed
# unless the garbage collector is manually run.
del resampler
gc.collect()
audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype) audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)
# Convert s16 back to f32. # Convert s16 back to f32.
@@ -102,18 +96,3 @@ def _resample_frames(frames, resampler):
# Add None to flush the resampler. # Add None to flush the resampler.
for frame in itertools.chain(frames, [None]): for frame in itertools.chain(frames, [None]):
yield from resampler.resample(frame) yield from resampler.resample(frame)
def pad_or_trim(array, length: int, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)
return array

View File

@@ -142,15 +142,11 @@ class FeatureExtractor:
data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins] data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
return data.T return data.T
def __call__(self, waveform, padding=True, chunk_length=None): def __call__(self, waveform, padding=True):
""" """
Compute the log-Mel spectrogram of the provided audio, gives similar results Compute the log-Mel spectrogram of the provided audio, gives similar results
whisper's original torch implementation with 1e-5 tolerance. whisper's original torch implementation with 1e-5 tolerance.
""" """
if chunk_length is not None:
self.n_samples = chunk_length * self.sampling_rate
self.nb_max_frames = self.n_samples // self.hop_length
if padding: if padding:
waveform = np.pad(waveform, [(0, self.n_samples)]) waveform = np.pad(waveform, [(0, self.n_samples)])

View File

@@ -19,21 +19,15 @@ class Tokenizer:
self.tokenizer = tokenizer self.tokenizer = tokenizer
if multilingual: if multilingual:
if task not in _TASKS:
raise ValueError(
"'%s' is not a valid task (accepted tasks: %s)"
% (task, ", ".join(_TASKS))
)
if language not in _LANGUAGE_CODES:
raise ValueError(
"'%s' is not a valid language code (accepted language codes: %s)"
% (language, ", ".join(_LANGUAGE_CODES))
)
self.task = self.tokenizer.token_to_id("<|%s|>" % task) self.task = self.tokenizer.token_to_id("<|%s|>" % task)
self.language = self.tokenizer.token_to_id("<|%s|>" % language) if self.task is None:
raise ValueError("%s is not a valid task" % task)
self.language_code = language self.language_code = language
self.language = self.tokenizer.token_to_id("<|%s|>" % language)
if self.language is None:
raise ValueError("%s is not a valid language code" % language)
else: else:
self.task = None self.task = None
self.language = None self.language = None
@@ -108,7 +102,7 @@ class Tokenizer:
def split_to_word_tokens( def split_to_word_tokens(
self, tokens: List[int] self, tokens: List[int]
) -> Tuple[List[str], List[List[int]]]: ) -> Tuple[List[str], List[List[int]]]:
if self.language_code in {"zh", "ja", "th", "lo", "my", "yue"}: if self.language_code in {"zh", "ja", "th", "lo", "my"}:
# These languages don't typically use spaces, so it is difficult to split words # These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any # without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points # position where the tokens are decoded as valid unicode points
@@ -167,112 +161,3 @@ class Tokenizer:
word_tokens[-1].extend(subword_tokens) word_tokens[-1].extend(subword_tokens)
return words, word_tokens return words, word_tokens
_TASKS = (
"transcribe",
"translate",
)
_LANGUAGE_CODES = (
"af",
"am",
"ar",
"as",
"az",
"ba",
"be",
"bg",
"bn",
"bo",
"br",
"bs",
"ca",
"cs",
"cy",
"da",
"de",
"el",
"en",
"es",
"et",
"eu",
"fa",
"fi",
"fo",
"fr",
"gl",
"gu",
"ha",
"haw",
"he",
"hi",
"hr",
"ht",
"hu",
"hy",
"id",
"is",
"it",
"ja",
"jw",
"ka",
"kk",
"km",
"kn",
"ko",
"la",
"lb",
"ln",
"lo",
"lt",
"lv",
"mg",
"mi",
"mk",
"ml",
"mn",
"mr",
"ms",
"mt",
"my",
"ne",
"nl",
"nn",
"no",
"oc",
"pa",
"pl",
"ps",
"pt",
"ro",
"ru",
"sa",
"sd",
"si",
"sk",
"sl",
"sn",
"so",
"sq",
"sr",
"su",
"sv",
"sw",
"ta",
"te",
"tg",
"th",
"tk",
"tl",
"tr",
"tt",
"uk",
"ur",
"uz",
"vi",
"yi",
"yo",
"zh",
"yue",
)

View File

@@ -1,20 +1,18 @@
import itertools import itertools
import json
import logging import logging
import os import os
import zlib import zlib
from inspect import signature
from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
import ctranslate2 import ctranslate2
import numpy as np import numpy as np
import tokenizers import tokenizers
from faster_whisper.audio import decode_audio, pad_or_trim from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer from faster_whisper.tokenizer import Tokenizer
from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger from faster_whisper.utils import download_model, format_timestamp, get_logger
from faster_whisper.vad import ( from faster_whisper.vad import (
SpeechTimestampsMap, SpeechTimestampsMap,
VadOptions, VadOptions,
@@ -49,13 +47,10 @@ class TranscriptionOptions(NamedTuple):
best_of: int best_of: int
patience: float patience: float
length_penalty: float length_penalty: float
repetition_penalty: float
no_repeat_ngram_size: int
log_prob_threshold: Optional[float] log_prob_threshold: Optional[float]
no_speech_threshold: Optional[float] no_speech_threshold: Optional[float]
compression_ratio_threshold: Optional[float] compression_ratio_threshold: Optional[float]
condition_on_previous_text: bool condition_on_previous_text: bool
prompt_reset_on_temperature: float
temperatures: List[float] temperatures: List[float]
initial_prompt: Optional[Union[str, Iterable[int]]] initial_prompt: Optional[Union[str, Iterable[int]]]
prefix: Optional[str] prefix: Optional[str]
@@ -66,16 +61,12 @@ class TranscriptionOptions(NamedTuple):
word_timestamps: bool word_timestamps: bool
prepend_punctuations: str prepend_punctuations: str
append_punctuations: str append_punctuations: str
max_new_tokens: Optional[int]
clip_timestamps: Union[str, List[float]]
hallucination_silence_threshold: Optional[float]
class TranscriptionInfo(NamedTuple): class TranscriptionInfo(NamedTuple):
language: str language: str
language_probability: float language_probability: float
duration: float duration: float
duration_after_vad: float
all_language_probs: Optional[List[Tuple[str, float]]] all_language_probs: Optional[List[Tuple[str, float]]]
transcription_options: TranscriptionOptions transcription_options: TranscriptionOptions
vad_options: VadOptions vad_options: VadOptions
@@ -97,9 +88,8 @@ class WhisperModel:
Args: Args:
model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en, model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
small, small.en, medium, medium.en, large-v1, large-v2, large-v3, or large), a path to a small, small.en, medium, medium.en, large-v1, or large-v2) or a path to a converted
converted model directory, or a CTranslate2-converted Whisper model ID from the HF Hub. model directory. When a size is configured, the converted model is downloaded
When a size or a model ID is configured, the converted model is downloaded
from the Hugging Face Hub. from the Hugging Face Hub.
device: Device to use for computation ("cpu", "cuda", "auto"). device: Device to use for computation ("cpu", "cuda", "auto").
device_index: Device ID to use. device_index: Device ID to use.
@@ -147,8 +137,7 @@ class WhisperModel:
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en") "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
) )
self.feat_kwargs = self._get_feature_kwargs(model_path) self.feature_extractor = FeatureExtractor()
self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
self.num_samples_per_token = self.feature_extractor.hop_length * 2 self.num_samples_per_token = self.feature_extractor.hop_length * 2
self.frames_per_second = ( self.frames_per_second = (
self.feature_extractor.sampling_rate // self.feature_extractor.hop_length self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
@@ -160,27 +149,6 @@ class WhisperModel:
self.time_precision = 0.02 self.time_precision = 0.02
self.max_length = 448 self.max_length = 448
@property
def supported_languages(self) -> List[str]:
"""The languages supported by the model."""
return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
def _get_feature_kwargs(self, model_path) -> dict:
preprocessor_config_file = os.path.join(model_path, "preprocessor_config.json")
config = {}
if os.path.isfile(preprocessor_config_file):
try:
with open(preprocessor_config_file, "r", encoding="utf-8") as json_file:
config = json.load(json_file)
valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
config = {k: v for k, v in config.items() if k in valid_keys}
except json.JSONDecodeError as e:
self.logger.warning(
"Could not load preprocessor_config.json: %s", str(e)
)
return config
def transcribe( def transcribe(
self, self,
audio: Union[str, BinaryIO, np.ndarray], audio: Union[str, BinaryIO, np.ndarray],
@@ -190,8 +158,6 @@ class WhisperModel:
best_of: int = 5, best_of: int = 5,
patience: float = 1, patience: float = 1,
length_penalty: float = 1, length_penalty: float = 1,
repetition_penalty: float = 1,
no_repeat_ngram_size: int = 0,
temperature: Union[float, List[float], Tuple[float, ...]] = [ temperature: Union[float, List[float], Tuple[float, ...]] = [
0.0, 0.0,
0.2, 0.2,
@@ -204,7 +170,6 @@ class WhisperModel:
log_prob_threshold: Optional[float] = -1.0, log_prob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6, no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True, condition_on_previous_text: bool = True,
prompt_reset_on_temperature: float = 0.5,
initial_prompt: Optional[Union[str, Iterable[int]]] = None, initial_prompt: Optional[Union[str, Iterable[int]]] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
suppress_blank: bool = True, suppress_blank: bool = True,
@@ -216,12 +181,6 @@ class WhisperModel:
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
vad_filter: bool = False, vad_filter: bool = False,
vad_parameters: Optional[Union[dict, VadOptions]] = None, vad_parameters: Optional[Union[dict, VadOptions]] = None,
max_new_tokens: Optional[int] = None,
chunk_length: Optional[int] = None,
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
language_detection_threshold: Optional[float] = None,
language_detection_segments: int = 1,
) -> Tuple[Iterable[Segment], TranscriptionInfo]: ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""Transcribes an input file. """Transcribes an input file.
@@ -235,9 +194,6 @@ class WhisperModel:
best_of: Number of candidates when sampling with non-zero temperature. best_of: Number of candidates when sampling with non-zero temperature.
patience: Beam search patience factor. patience: Beam search patience factor.
length_penalty: Exponential length penalty constant. length_penalty: Exponential length penalty constant.
repetition_penalty: Penalty applied to the score of previously generated tokens
(set > 1 to penalize).
no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
temperature: Temperature for sampling. It can be a tuple of temperatures, temperature: Temperature for sampling. It can be a tuple of temperatures,
which will be successively used upon failures according to either which will be successively used upon failures according to either
`compression_ratio_threshold` or `log_prob_threshold`. `compression_ratio_threshold` or `log_prob_threshold`.
@@ -252,8 +208,6 @@ class WhisperModel:
as a prompt for the next window; disabling may make the text inconsistent across as a prompt for the next window; disabling may make the text inconsistent across
windows, but the model becomes less prone to getting stuck in a failure loop, windows, but the model becomes less prone to getting stuck in a failure loop,
such as repetition looping or timestamps going out of sync. such as repetition looping or timestamps going out of sync.
prompt_reset_on_temperature: Resets prompt if temperature is above this value.
Arg has effect only if condition_on_previous_text is True.
initial_prompt: Optional text string or iterable of token ids to provide as a initial_prompt: Optional text string or iterable of token ids to provide as a
prompt for the first window. prompt for the first window.
prefix: Optional text to provide as a prefix for the first window. prefix: Optional text to provide as a prefix for the first window.
@@ -273,20 +227,6 @@ class WhisperModel:
https://github.com/snakers4/silero-vad. https://github.com/snakers4/silero-vad.
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
parameters and default values in the class `VadOptions`). parameters and default values in the class `VadOptions`).
max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
the maximum will be set by the default max_length.
chunk_length: The length of audio segments. If it is not None, it will overwrite the
default chunk_length of the FeatureExtractor.
clip_timestamps: Union[str, List[float]]
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
process. The last end timestamp defaults to the end of the file.
vad_filter will be ignored if clip_timestamps is used.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected
language_detection_threshold: If the maximum probability of the language tokens is higher
than this value, the language is detected.
language_detection_segments: Number of segments to consider for the language detection.
Returns: Returns:
A tuple with: A tuple with:
@@ -300,24 +240,22 @@ class WhisperModel:
audio = decode_audio(audio, sampling_rate=sampling_rate) audio = decode_audio(audio, sampling_rate=sampling_rate)
duration = audio.shape[0] / sampling_rate duration = audio.shape[0] / sampling_rate
duration_after_vad = duration
self.logger.info( self.logger.info(
"Processing audio with duration %s", format_timestamp(duration) "Processing audio with duration %s", format_timestamp(duration)
) )
if vad_filter and clip_timestamps == "0": if vad_filter:
if vad_parameters is None: if vad_parameters is None:
vad_parameters = VadOptions() vad_parameters = VadOptions()
elif isinstance(vad_parameters, dict): elif isinstance(vad_parameters, dict):
vad_parameters = VadOptions(**vad_parameters) vad_parameters = VadOptions(**vad_parameters)
speech_chunks = get_speech_timestamps(audio, vad_parameters) speech_chunks = get_speech_timestamps(audio, vad_parameters)
audio = collect_chunks(audio, speech_chunks) audio = collect_chunks(audio, speech_chunks)
duration_after_vad = audio.shape[0] / sampling_rate
self.logger.info( self.logger.info(
"VAD filter removed %s of audio", "VAD filter removed %s of audio",
format_timestamp(duration - duration_after_vad), format_timestamp(duration - (audio.shape[0] / sampling_rate)),
) )
if self.logger.isEnabledFor(logging.DEBUG): if self.logger.isEnabledFor(logging.DEBUG):
@@ -336,7 +274,7 @@ class WhisperModel:
else: else:
speech_chunks = None speech_chunks = None
features = self.feature_extractor(audio, chunk_length=chunk_length) features = self.feature_extractor(audio)
encoder_output = None encoder_output = None
all_language_probs = None all_language_probs = None
@@ -346,51 +284,15 @@ class WhisperModel:
language = "en" language = "en"
language_probability = 1 language_probability = 1
else: else:
if ( segment = features[:, : self.feature_extractor.nb_max_frames]
language_detection_segments is None
or language_detection_segments < 1
):
language_detection_segments = 1
seek = 0
detected_language_info = {}
content_frames = (
features.shape[-1] - self.feature_extractor.nb_max_frames
)
while (
seek <= content_frames
and seek
< self.feature_extractor.nb_max_frames * language_detection_segments
):
segment = features[
:, seek : seek + self.feature_extractor.nb_max_frames
]
encoder_output = self.encode(segment) encoder_output = self.encode(segment)
# results is a list of tuple[str, float] with language names and # results is a list of tuple[str, float] with language names and
# probabilities. # probabilities.
results = self.model.detect_language(encoder_output)[0] results = self.model.detect_language(encoder_output)[0]
# Parse language names to strip out markers # Parse language names to strip out markers
all_language_probs = [ all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
(token[2:-2], prob) for (token, prob) in results
]
# Get top language token and probability # Get top language token and probability
language, language_probability = all_language_probs[0] language, language_probability = all_language_probs[0]
if (
language_detection_threshold is None
or language_probability > language_detection_threshold
):
break
detected_language_info.setdefault(language, []).append(
language_probability
)
seek += segment.shape[-1]
else:
# If no language detected for all segments, the majority vote of the highest
# projected languages for all segments is used to determine the language.
language = max(
detected_language_info,
key=lambda lang: len(detected_language_info[lang]),
)
language_probability = max(detected_language_info[language])
self.logger.info( self.logger.info(
"Detected language '%s' with probability %.2f", "Detected language '%s' with probability %.2f",
@@ -398,13 +300,6 @@ class WhisperModel:
language_probability, language_probability,
) )
else: else:
if not self.model.is_multilingual and language != "en":
self.logger.warning(
"The current model is English-only but the language parameter is set to '%s'; "
"using 'en' instead." % language
)
language = "en"
language_probability = 1 language_probability = 1
tokenizer = Tokenizer( tokenizer = Tokenizer(
@@ -419,13 +314,10 @@ class WhisperModel:
best_of=best_of, best_of=best_of,
patience=patience, patience=patience,
length_penalty=length_penalty, length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
log_prob_threshold=log_prob_threshold, log_prob_threshold=log_prob_threshold,
no_speech_threshold=no_speech_threshold, no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold, compression_ratio_threshold=compression_ratio_threshold,
condition_on_previous_text=condition_on_previous_text, condition_on_previous_text=condition_on_previous_text,
prompt_reset_on_temperature=prompt_reset_on_temperature,
temperatures=( temperatures=(
temperature if isinstance(temperature, (list, tuple)) else [temperature] temperature if isinstance(temperature, (list, tuple)) else [temperature]
), ),
@@ -438,9 +330,6 @@ class WhisperModel:
word_timestamps=word_timestamps, word_timestamps=word_timestamps,
prepend_punctuations=prepend_punctuations, prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations, append_punctuations=append_punctuations,
max_new_tokens=max_new_tokens,
clip_timestamps=clip_timestamps,
hallucination_silence_threshold=hallucination_silence_threshold,
) )
segments = self.generate_segments(features, tokenizer, options, encoder_output) segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -452,7 +341,6 @@ class WhisperModel:
language=language, language=language,
language_probability=language_probability, language_probability=language_probability,
duration=duration, duration=duration,
duration_after_vad=duration_after_vad,
transcription_options=options, transcription_options=options,
vad_options=vad_parameters, vad_options=vad_parameters,
all_language_probs=all_language_probs, all_language_probs=all_language_probs,
@@ -468,34 +356,10 @@ class WhisperModel:
encoder_output: Optional[ctranslate2.StorageView] = None, encoder_output: Optional[ctranslate2.StorageView] = None,
) -> Iterable[Segment]: ) -> Iterable[Segment]:
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
content_duration = float(content_frames * self.feature_extractor.time_per_frame)
if isinstance(options.clip_timestamps, str):
TranscriptionOptions.clip_timestamps = [
float(ts)
for ts in (
options.clip_timestamps.split(",")
if options.clip_timestamps
else []
)
]
seek_points: List[int] = [
round(ts * self.frames_per_second) for ts in options.clip_timestamps
]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(
zip(seek_points[::2], seek_points[1::2])
)
punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
idx = 0 idx = 0
clip_idx = 0 seek = 0
seek = seek_clips[clip_idx][0]
all_tokens = [] all_tokens = []
all_prompt_text = []
prompt_reset_since = 0 prompt_reset_since = 0
if options.initial_prompt is not None: if options.initial_prompt is not None:
@@ -506,35 +370,13 @@ class WhisperModel:
else: else:
all_tokens.extend(options.initial_prompt) all_tokens.extend(options.initial_prompt)
last_speech_timestamp = 0.0 while seek < content_frames:
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek_clip_end > content_frames:
seek_clip_end = content_frames
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = seek * self.feature_extractor.time_per_frame time_offset = seek * self.feature_extractor.time_per_frame
window_end_time = float( segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]
(seek + self.feature_extractor.nb_max_frames)
* self.feature_extractor.time_per_frame
)
segment_size = min( segment_size = min(
self.feature_extractor.nb_max_frames, self.feature_extractor.nb_max_frames, content_frames - seek
content_frames - seek,
seek_clip_end - seek,
) )
segment = features[:, seek : seek + segment_size]
segment_duration = segment_size * self.feature_extractor.time_per_frame segment_duration = segment_size * self.feature_extractor.time_per_frame
segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)
if self.logger.isEnabledFor(logging.DEBUG): if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug( self.logger.debug(
@@ -549,7 +391,7 @@ class WhisperModel:
prefix=options.prefix if seek == 0 else None, prefix=options.prefix if seek == 0 else None,
) )
if seek > 0 or encoder_output is None: if encoder_output is None:
encoder_output = self.encode(segment) encoder_output = self.encode(segment)
( (
@@ -586,33 +428,10 @@ class WhisperModel:
previous_seek = seek previous_seek = seek
current_segments = [] current_segments = []
# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score
def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [w for w in segment["words"] if w["word"] not in punctuation]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)
single_timestamp_ending = ( single_timestamp_ending = (
len(tokens) >= 2 len(tokens) >= 2
and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1] and tokens[-2] < tokenizer.timestamp_begin
and tokens[-1] >= tokenizer.timestamp_begin
) )
consecutive_timestamps = [ consecutive_timestamps = [
@@ -692,65 +511,21 @@ class WhisperModel:
segment_size, segment_size,
options.prepend_punctuations, options.prepend_punctuations,
options.append_punctuations, options.append_punctuations,
last_speech_timestamp=last_speech_timestamp,
) )
if not single_timestamp_ending: word_end_timestamps = [
last_word_end = get_end(current_segments) w["end"] for s in current_segments for w in s["words"]
if last_word_end is not None and last_word_end > time_offset: ]
seek = round(last_word_end * self.frames_per_second)
# skip silence before possible hallucinations if not single_timestamp_ending and len(word_end_timestamps) > 0:
if options.hallucination_silence_threshold is not None: seek_shift = round(
threshold = options.hallucination_silence_threshold (word_end_timestamps[-1] - time_offset) * self.frames_per_second
)
# if first segment might be a hallucination, skip leading silence if seek_shift > 0:
first_segment = next_words_segment(current_segments) seek = previous_seek + seek_shift
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * self.frames_per_second)
continue
# skip silence before any possible hallucination that is surrounded encoder_output = None
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
)
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
or segment["start"] - time_offset < 2.0
)
silence_after = (
hal_next_start - segment["end"] > threshold
or is_segment_anomaly(next_segment)
or window_end_time - segment["end"] < 2.0
)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
* self.frames_per_second
)
if content_duration - segment["end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]
last_word_end = get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
for segment in current_segments: for segment in current_segments:
tokens = segment["tokens"] tokens = segment["tokens"]
@@ -759,7 +534,15 @@ class WhisperModel:
if segment["start"] == segment["end"] or not text.strip(): if segment["start"] == segment["end"] or not text.strip():
continue continue
check_prompt_num = 1
if all(
[
text.strip() != i.strip()
for i in all_prompt_text[-check_prompt_num:]
]
):
all_tokens.extend(tokens) all_tokens.extend(tokens)
all_prompt_text.append(text)
idx += 1 idx += 1
yield Segment( yield Segment(
@@ -780,17 +563,7 @@ class WhisperModel:
), ),
) )
if ( if not options.condition_on_previous_text or temperature > 0.5:
not options.condition_on_previous_text
or temperature > options.prompt_reset_on_temperature
):
if options.condition_on_previous_text:
self.logger.debug(
"Reset prompt. prompt_reset_on_temperature threshold is met %f > %f",
temperature,
options.prompt_reset_on_temperature,
)
prompt_reset_since = len(all_tokens) prompt_reset_since = len(all_tokens)
def encode(self, features: np.ndarray) -> ctranslate2.StorageView: def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
@@ -810,28 +583,14 @@ class WhisperModel:
tokenizer: Tokenizer, tokenizer: Tokenizer,
options: TranscriptionOptions, options: TranscriptionOptions,
) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]: ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
decode_result = None result = None
all_results = [] avg_logprob = None
below_cr_threshold_results = [] final_temperature = None
compression_ratio = None
max_initial_timestamp_index = int( max_initial_timestamp_index = int(
round(options.max_initial_timestamp / self.time_precision) round(options.max_initial_timestamp / self.time_precision)
) )
if options.max_new_tokens is not None:
max_length = len(prompt) + options.max_new_tokens
else:
max_length = self.max_length
if max_length > self.max_length:
raise ValueError(
f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
f"and `max_new_tokens` is: {max_length}. This exceeds the "
f"`max_length` of the Whisper model: {self.max_length}. "
"You should either reduce the length of your prompt, or "
"reduce the value of `max_new_tokens`, "
f"so that their combined length is less that {self.max_length}."
)
for temperature in options.temperatures: for temperature in options.temperatures:
if temperature > 0: if temperature > 0:
@@ -847,13 +606,12 @@ class WhisperModel:
"patience": options.patience, "patience": options.patience,
} }
final_temperature = temperature
result = self.model.generate( result = self.model.generate(
encoder_output, encoder_output,
[prompt], [prompt],
length_penalty=options.length_penalty, length_penalty=options.length_penalty,
repetition_penalty=options.repetition_penalty, max_length=self.max_length,
no_repeat_ngram_size=options.no_repeat_ngram_size,
max_length=max_length,
return_scores=True, return_scores=True,
return_no_speech_prob=True, return_no_speech_prob=True,
suppress_blank=options.suppress_blank, suppress_blank=options.suppress_blank,
@@ -872,18 +630,12 @@ class WhisperModel:
text = tokenizer.decode(tokens).strip() text = tokenizer.decode(tokens).strip()
compression_ratio = get_compression_ratio(text) compression_ratio = get_compression_ratio(text)
decode_result = (
result,
avg_logprob,
temperature,
compression_ratio,
)
all_results.append(decode_result)
needs_fallback = False needs_fallback = False
if options.compression_ratio_threshold is not None: if (
if compression_ratio > options.compression_ratio_threshold: options.compression_ratio_threshold is not None
and compression_ratio > options.compression_ratio_threshold
):
needs_fallback = True # too repetitive needs_fallback = True # too repetitive
self.logger.debug( self.logger.debug(
@@ -892,8 +644,6 @@ class WhisperModel:
compression_ratio, compression_ratio,
options.compression_ratio_threshold, options.compression_ratio_threshold,
) )
else:
below_cr_threshold_results.append(decode_result)
if ( if (
options.log_prob_threshold is not None options.log_prob_threshold is not None
@@ -908,30 +658,10 @@ class WhisperModel:
options.log_prob_threshold, options.log_prob_threshold,
) )
if (
options.no_speech_threshold is not None
and result.no_speech_prob > options.no_speech_threshold
and options.log_prob_threshold is not None
and avg_logprob < options.log_prob_threshold
):
needs_fallback = False # silence
if not needs_fallback: if not needs_fallback:
break break
else:
# all failed, select the result with the highest average log probability
decode_result = max(
below_cr_threshold_results or all_results, key=lambda x: x[1]
)
# to pass final temperature for prompt_reset_on_temperature
decode_result = (
decode_result[0],
decode_result[1],
temperature,
decode_result[3],
)
return decode_result return result, avg_logprob, final_temperature, compression_ratio
def get_prompt( def get_prompt(
self, self,
@@ -955,8 +685,6 @@ class WhisperModel:
prefix_tokens = tokenizer.encode(" " + prefix.strip()) prefix_tokens = tokenizer.encode(" " + prefix.strip())
if len(prefix_tokens) >= self.max_length // 2: if len(prefix_tokens) >= self.max_length // 2:
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1] prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
if not without_timestamps:
prompt.append(tokenizer.timestamp_begin)
prompt.extend(prefix_tokens) prompt.extend(prefix_tokens)
return prompt return prompt
@@ -969,8 +697,7 @@ class WhisperModel:
num_frames: int, num_frames: int,
prepend_punctuations: str, prepend_punctuations: str,
append_punctuations: str, append_punctuations: str,
last_speech_timestamp: float, ):
) -> None:
if len(segments) == 0: if len(segments) == 0:
return return
@@ -983,25 +710,6 @@ class WhisperModel:
alignment = self.find_alignment( alignment = self.find_alignment(
tokenizer, text_tokens, encoder_output, num_frames tokenizer, text_tokens, encoder_output, num_frames
) )
word_durations = np.array([word["end"] - word["start"] for word in alignment])
word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
median_duration = min(0.7, float(median_duration))
max_duration = median_duration * 2
# hack: truncate long words at sentence boundaries.
# a better segmentation algorithm based on VAD should be able to replace this.
if len(word_durations) > 0:
sentence_end_marks = ".。!?"
# ensure words at sentence boundaries
# are not longer than twice the median word duration.
for i in range(1, len(alignment)):
if alignment[i]["end"] - alignment[i]["start"] > max_duration:
if alignment[i]["word"] in sentence_end_marks:
alignment[i]["end"] = alignment[i]["start"] + max_duration
elif alignment[i - 1]["word"] in sentence_end_marks:
alignment[i]["start"] = alignment[i]["end"] - max_duration
merge_punctuations(alignment, prepend_punctuations, append_punctuations) merge_punctuations(alignment, prepend_punctuations, append_punctuations)
time_offset = ( time_offset = (
@@ -1032,52 +740,11 @@ class WhisperModel:
saved_tokens += len(timing["tokens"]) saved_tokens += len(timing["tokens"])
word_index += 1 word_index += 1
# hack: truncate long words at segment boundaries.
# a better segmentation algorithm based on VAD should be able to replace this.
if len(words) > 0: if len(words) > 0:
# ensure the first and second word after a pause is not longer than # adjust the segment-level timestamps based on the word-level timestamps
# twice the median word duration.
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
words[0]["end"] - words[0]["start"] > max_duration
or (
len(words) > 1
and words[1]["end"] - words[0]["start"] > max_duration * 2
)
):
if (
len(words) > 1
and words[1]["end"] - words[1]["start"] > max_duration
):
boundary = max(
words[1]["end"] / 2, words[1]["end"] - max_duration
)
words[0]["end"] = words[1]["start"] = boundary
words[0]["start"] = max(0, words[0]["end"] - max_duration)
# prefer the segment-level start timestamp if the first word is too long.
if (
segment["start"] < words[0]["end"]
and segment["start"] - 0.5 > words[0]["start"]
):
words[0]["start"] = max(
0, min(words[0]["end"] - median_duration, segment["start"])
)
else:
segment["start"] = words[0]["start"] segment["start"] = words[0]["start"]
# prefer the segment-level end timestamp if the last word is too long.
if (
segment["end"] > words[-1]["start"]
and segment["end"] + 0.5 < words[-1]["end"]
):
words[-1]["end"] = max(
words[-1]["start"] + median_duration, segment["end"]
)
else:
segment["end"] = words[-1]["end"] segment["end"] = words[-1]["end"]
last_speech_timestamp = segment["end"]
segment["words"] = words segment["words"] = words
def find_alignment( def find_alignment(
@@ -1108,13 +775,6 @@ class WhisperModel:
words, word_tokens = tokenizer.split_to_word_tokens( words, word_tokens = tokenizer.split_to_word_tokens(
text_tokens + [tokenizer.eot] text_tokens + [tokenizer.eot]
) )
if len(word_tokens) <= 1:
# return on eot only
# >>> np.pad([], (1, 0))
# array([0.])
# This results in crashes when we lookup jump_times with float, like
# IndexError: arrays used as indices must be of integer (or boolean) type
return []
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
if len(word_boundaries) <= 1: if len(word_boundaries) <= 1:
return [] return []
@@ -1128,6 +788,22 @@ class WhisperModel:
for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
] ]
# hack: ensure the first and second word is not longer than twice the median word duration.
# a better segmentation algorithm based on VAD should be able to replace this.
word_durations = end_times - start_times
word_durations = word_durations[word_durations.nonzero()]
if len(word_durations) > 0:
median_duration = np.median(word_durations)
max_duration = median_duration * 2
if len(word_durations) >= 2 and word_durations[1] > max_duration:
boundary = max(end_times[2] / 2, end_times[2] - max_duration)
end_times[0] = start_times[1] = boundary
if (
len(word_durations) >= 1
and end_times[0] - start_times[0] > max_duration
):
start_times[0] = max(0, end_times[0] - max_duration)
return [ return [
dict( dict(
word=word, tokens=tokens, start=start, end=end, probability=probability word=word, tokens=tokens, start=start, end=end, probability=probability
@@ -1184,10 +860,7 @@ def get_compression_ratio(text: str) -> float:
return len(text_bytes) / len(zlib.compress(text_bytes)) return len(text_bytes) / len(zlib.compress(text_bytes))
def get_suppressed_tokens( def get_suppressed_tokens(tokenizer, suppress_tokens):
tokenizer: Tokenizer,
suppress_tokens: Optional[List[int]],
) -> Optional[List[int]]:
if not suppress_tokens or -1 in suppress_tokens: if not suppress_tokens or -1 in suppress_tokens:
return suppress_tokens return suppress_tokens
@@ -1208,7 +881,7 @@ def get_suppressed_tokens(
return sorted(set(suppress_tokens)) return sorted(set(suppress_tokens))
def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None: def merge_punctuations(alignment: List[dict], prepended: str, appended: str):
# merge prepended punctuations # merge prepended punctuations
i = len(alignment) - 2 i = len(alignment) - 2
j = len(alignment) - 1 j = len(alignment) - 1

View File

@@ -1,37 +1,25 @@
import logging import logging
import os import os
import re
from typing import List, Optional from typing import Optional
import huggingface_hub import huggingface_hub
import requests import requests
from tqdm.auto import tqdm from tqdm.auto import tqdm
_MODELS = { _MODELS = (
"tiny.en": "Systran/faster-whisper-tiny.en", "tiny.en",
"tiny": "Systran/faster-whisper-tiny", "tiny",
"base.en": "Systran/faster-whisper-base.en", "base.en",
"base": "Systran/faster-whisper-base", "base",
"small.en": "Systran/faster-whisper-small.en", "small.en",
"small": "Systran/faster-whisper-small", "small",
"medium.en": "Systran/faster-whisper-medium.en", "medium.en",
"medium": "Systran/faster-whisper-medium", "medium",
"large-v1": "Systran/faster-whisper-large-v1", "large-v1",
"large-v2": "Systran/faster-whisper-large-v2", "large-v2",
"large-v3": "Systran/faster-whisper-large-v3", )
"large": "Systran/faster-whisper-large-v3",
"distil-large-v2": "Systran/faster-distil-whisper-large-v2",
"distil-medium.en": "Systran/faster-distil-whisper-medium.en",
"distil-small.en": "Systran/faster-distil-whisper-small.en",
"distil-large-v3": "Systran/faster-distil-whisper-large-v3",
}
def available_models() -> List[str]:
"""Returns the names of available models."""
return list(_MODELS.keys())
def get_assets_path(): def get_assets_path():
@@ -45,18 +33,18 @@ def get_logger():
def download_model( def download_model(
size_or_id: str, size: str,
output_dir: Optional[str] = None, output_dir: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
): ):
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub. """Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
The model is downloaded from https://huggingface.co/guillaumekln.
Args: Args:
size_or_id: Size of the model to download from https://huggingface.co/Systran size: Size of the model to download (tiny, tiny.en, base, base.en, small, small.en,
(tiny, tiny.en, base, base.en, small, small.en medium, medium.en, large-v1, large-v2, medium, medium.en, large-v1, or large-v2).
large-v3, large), or a CTranslate2-converted model ID from the Hugging Face Hub
(e.g. Systran/faster-whisper-large-v3).
output_dir: Directory where the model should be saved. If not set, the model is saved in output_dir: Directory where the model should be saved. If not set, the model is saved in
the cache directory. the cache directory.
local_files_only: If True, avoid downloading the file and return the path to the local local_files_only: If True, avoid downloading the file and return the path to the local
@@ -69,19 +57,15 @@ def download_model(
Raises: Raises:
ValueError: if the model size is invalid. ValueError: if the model size is invalid.
""" """
if re.match(r".*/.*", size_or_id): if size not in _MODELS:
repo_id = size_or_id
else:
repo_id = _MODELS.get(size_or_id)
if repo_id is None:
raise ValueError( raise ValueError(
"Invalid model size '%s', expected one of: %s" "Invalid model size '%s', expected one of: %s" % (size, ", ".join(_MODELS))
% (size_or_id, ", ".join(_MODELS.keys()))
) )
repo_id = "guillaumekln/faster-whisper-%s" % size
allow_patterns = [ allow_patterns = [
"config.json", "config.json",
"preprocessor_config.json",
"model.bin", "model.bin",
"tokenizer.json", "tokenizer.json",
"vocabulary.*", "vocabulary.*",
@@ -147,10 +131,3 @@ class disabled_tqdm(tqdm):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs["disable"] = True kwargs["disable"] = True
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)

View File

@@ -1,3 +1,3 @@
"""Version information.""" """Version information."""
__version__ = "1.0.1" __version__ = "0.6.0"

View File

@@ -1,5 +1,5 @@
av==11.* av==10.*
ctranslate2>=4.0,<5 ctranslate2>=3.10,<4
huggingface_hub>=0.13 huggingface_hub>=0.13
tokenizers>=0.13,<0.16 tokenizers==0.13.*
onnxruntime>=1.14,<2 onnxruntime>=1.14,<2

View File

@@ -37,7 +37,7 @@ setup(
long_description=get_long_description(), long_description=get_long_description(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
author="Guillaume Klein", author="Guillaume Klein",
url="https://github.com/SYSTRAN/faster-whisper", url="https://github.com/guillaumekln/faster-whisper",
classifiers=[ classifiers=[
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Intended Audience :: Developers", "Intended Audience :: Developers",

View File

@@ -3,11 +3,6 @@ import os
from faster_whisper import WhisperModel, decode_audio from faster_whisper import WhisperModel, decode_audio
def test_supported_languages():
model = WhisperModel("tiny.en")
assert model.supported_languages == ["en"]
def test_transcribe(jfk_path): def test_transcribe(jfk_path):
model = WhisperModel("tiny") model = WhisperModel("tiny")
segments, info = model.transcribe(jfk_path, word_timestamps=True) segments, info = model.transcribe(jfk_path, word_timestamps=True)
@@ -39,24 +34,6 @@ def test_transcribe(jfk_path):
assert segment.end == segment.words[-1].end assert segment.end == segment.words[-1].end
def test_prefix_with_timestamps(jfk_path):
model = WhisperModel("tiny")
segments, _ = model.transcribe(jfk_path, prefix="And so my fellow Americans")
segments = list(segments)
assert len(segments) == 1
segment = segments[0]
assert segment.text == (
" And so my fellow Americans ask not what your country can do for you, "
"ask what you can do for your country."
)
assert segment.start == 0
assert 10 < segment.end < 11
def test_vad(jfk_path): def test_vad(jfk_path):
model = WhisperModel("tiny") model = WhisperModel("tiny")
segments, info = model.transcribe( segments, info = model.transcribe(

View File

@@ -1,12 +1,6 @@
import os import os
from faster_whisper import available_models, download_model from faster_whisper import download_model
def test_available_models():
models = available_models()
assert isinstance(models, list)
assert "tiny" in models
def test_download_model(tmpdir): def test_download_model(tmpdir):