Merge branch 'master' into prompt
README.md
@@ -8,6 +8,8 @@ This implementation is up to 4 times faster than [openai/whisper](https://github
## Benchmark

### Whisper

For reference, here's the time and memory usage that are required to transcribe [**13 minutes**](https://www.youtube.com/watch?v=0u7tTptBo9I) of audio using different implementations:

* [openai/whisper](https://github.com/openai/whisper)@[6dea21fd](https://github.com/openai/whisper/commit/6dea21fd7f7253bfe450f1e2512a0fe47ee2d258)
@@ -36,6 +38,33 @@ For reference, here's the time and memory usage that are required to transcribe
*Executed with 8 threads on an Intel(R) Xeon(R) Gold 6226R.*

### Distil-whisper

| Implementation | Precision | Beam size | Time | Gigaspeech WER |
| --- | --- | --- | --- | --- |
| distil-whisper/distil-large-v2 | fp16 | 4 | - | 10.36 |
| [faster-distil-large-v2](https://huggingface.co/Systran/faster-distil-whisper-large-v2) | fp16 | 5 | - | 10.28 |
| distil-whisper/distil-medium.en | fp16 | 4 | - | 11.21 |
| [faster-distil-medium.en](https://huggingface.co/Systran/faster-distil-whisper-medium.en) | fp16 | 5 | - | 11.21 |

*Executed with CUDA 11.4 on an NVIDIA 3090.*
<details>
<summary>testing details (click to expand)</summary>

For `distil-whisper/distil-large-v2`, the WER is tested with the code sample from [link](https://huggingface.co/distil-whisper/distil-large-v2#evaluation). For `faster-distil-whisper`, the WER is tested with the following settings:
```python
from faster_whisper import WhisperModel

model_size = "distil-large-v2"
# model_size = "distil-medium.en"
# Run on GPU with FP16
model = WhisperModel(model_size, device="cuda", compute_type="float16")
segments, info = model.transcribe("audio.mp3", beam_size=5, language="en")
```
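The WER can then be computed from the collected references and predictions, for instance with the `evaluate` library used in the linked distil-whisper evaluation. A sketch (the lists below are placeholders for the Gigaspeech test split and the transcribed text):

```python
import evaluate  # pip install evaluate jiwer

wer_metric = evaluate.load("wer")

# Placeholders: in the real test these hold the Gigaspeech reference
# transcripts and the text joined from the segments returned above.
references = ["the quick brown fox jumps over the lazy dog"]
predictions = ["the quick brown fox jumps over the lazy dog"]

wer = 100 * wer_metric.compute(references=references, predictions=predictions)
print(f"WER: {wer:.2f}%")
```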
</details>

## Requirements

* Python 3.8 or greater
@@ -101,6 +130,8 @@ pip install --force-reinstall "faster-whisper @ https://github.com/guillaumekln/
## Usage

### Faster-whisper

```python
from faster_whisper import WhisperModel
@@ -128,6 +159,18 @@ for segment in segments:
segments, _ = model.transcribe("audio.mp3")
segments = list(segments)  # The transcription will actually run here.
```

### Faster-distil-whisper

For the usage of `faster-distil-whisper`, please refer to: https://github.com/guillaumekln/faster-whisper/issues/533

```python
model_size = "distil-large-v2"
# model_size = "distil-medium.en"
model = WhisperModel(model_size, device="cuda", compute_type="float16")
segments, info = model.transcribe(
    "audio.mp3",
    beam_size=5,
    language="en",
    max_new_tokens=128,
    condition_on_previous_text=False,
)
```

NOTE: Empirically, `condition_on_previous_text=True` degrades the performance of `faster-distil-whisper` on long audio. Degradation on the first chunk was also observed with `initial_prompt`.
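As with the regular model, `segments` is a lazy generator; iterating it is what actually runs the transcription, e.g.:

```python
for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
```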
### Word-level timestamps
@@ -182,6 +225,9 @@ Here is a non exhaustive list of open-source projects using faster-whisper. Feel
* [asr-sd-pipeline](https://github.com/hedrergudene/asr-sd-pipeline) provides a scalable, modular, end-to-end multi-speaker speech-to-text solution implemented using AzureML pipelines.
* [Open-Lyrics](https://github.com/zh-plus/Open-Lyrics) is a Python library that transcribes voice files using faster-whisper, and translates/polishes the resulting text into `.lrc` files in the desired language using OpenAI-GPT.
* [wscribe](https://github.com/geekodour/wscribe) is a flexible transcript generation tool supporting faster-whisper. It can export word-level transcripts, and the exported transcript can then be edited with [wscribe-editor](https://github.com/geekodour/wscribe-editor).
* [aTrain](https://github.com/BANDAS-Center/aTrain) is a graphical user interface implementation of faster-whisper developed at the BANDAS-Center at the University of Graz for transcription and diarization in Windows ([Windows Store App](https://apps.microsoft.com/detail/atrain/9N15Q44SZNS2)) and Linux.
* [Whisper-Streaming](https://github.com/ufal/whisper_streaming) implements real-time mode for offline Whisper-like speech-to-text models with faster-whisper as the most recommended back-end. It implements a streaming policy with self-adaptive latency based on the actual source complexity, and demonstrates the state of the art.
* [WhisperLive](https://github.com/collabora/WhisperLive) is a nearly-live implementation of OpenAI's Whisper which uses faster-whisper as the backend to transcribe audio in real time.
## Model conversion

faster_whisper/feature_extractor.py
@@ -142,11 +142,15 @@ class FeatureExtractor:
            data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]

        return data.T

-    def __call__(self, waveform, padding=True):
+    def __call__(self, waveform, padding=True, chunk_length=None):
        """
        Compute the log-Mel spectrogram of the provided audio. Gives results similar to
        whisper's original torch implementation, within a 1e-5 tolerance.
        """
+        if chunk_length is not None:
+            self.n_samples = chunk_length * self.sampling_rate
+            self.nb_max_frames = self.n_samples // self.hop_length

        if padding:
            waveform = np.pad(waveform, [(0, self.n_samples)])
faster_whisper/transcribe.py
@@ -66,6 +66,7 @@ class TranscriptionOptions(NamedTuple):
    word_timestamps: bool
    prepend_punctuations: str
    append_punctuations: str
+    max_new_tokens: Optional[int]


class TranscriptionInfo(NamedTuple):
@@ -213,6 +214,8 @@ class WhisperModel:
        append_punctuations: str = "\"'.。,,!!??::”)]}、",
        vad_filter: bool = False,
        vad_parameters: Optional[Union[dict, VadOptions]] = None,
+        max_new_tokens: Optional[int] = None,
+        chunk_length: Optional[int] = None,
    ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
        """Transcribes an input file.
@@ -264,6 +267,10 @@ class WhisperModel:
              https://github.com/snakers4/silero-vad.
          vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
            parameters and default values in the class `VadOptions`).
+          max_new_tokens: Maximum number of new tokens to generate per chunk. If not set,
+            the maximum will be set by the default max_length.
+          chunk_length: The length of audio segments in seconds. If not None, it overrides the
+            default chunk_length of the FeatureExtractor.

        Returns:
          A tuple with:
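As a usage sketch of the two new parameters (the values and `audio.mp3` are illustrative, not part of the change):

```python
from faster_whisper import WhisperModel

model = WhisperModel("distil-large-v2", device="cuda", compute_type="float16")

# Cap generation at 128 new tokens per chunk, and feed the model
# 20-second chunks instead of the FeatureExtractor default.
segments, info = model.transcribe(
    "audio.mp3",
    max_new_tokens=128,
    chunk_length=20,
)
```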
@@ -313,7 +320,7 @@ class WhisperModel:

        else:
            speech_chunks = None

-        features = self.feature_extractor(audio)
+        features = self.feature_extractor(audio, chunk_length=chunk_length)

        encoder_output = None
        all_language_probs = None

@@ -379,6 +386,7 @@ class WhisperModel:

            word_timestamps=word_timestamps,
            prepend_punctuations=prepend_punctuations,
            append_punctuations=append_punctuations,
+            max_new_tokens=max_new_tokens,
        )

        segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -651,6 +659,21 @@ class WhisperModel:
        max_initial_timestamp_index = int(
            round(options.max_initial_timestamp / self.time_precision)
        )
+        if options.max_new_tokens is not None:
+            max_length = len(prompt) + options.max_new_tokens
+        else:
+            max_length = self.max_length
+
+        if max_length > self.max_length:
+            raise ValueError(
+                f"The length of the prompt is {len(prompt)}, and `max_new_tokens` is "
+                f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
+                f"and `max_new_tokens` is {max_length}. This exceeds the "
+                f"`max_length` of the Whisper model: {self.max_length}. "
+                "You should either reduce the length of your prompt or "
+                "reduce the value of `max_new_tokens`, "
+                f"so that their combined length is less than {self.max_length}."
+            )
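        # For example: a 100-token prompt with max_new_tokens=400 gives
        # max_length=500, which exceeds Whisper's usual 448-token decoder
        # limit (self.max_length) and triggers the error above.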
        for temperature in options.temperatures:
            if temperature > 0:

@@ -672,7 +695,7 @@ class WhisperModel:

                length_penalty=options.length_penalty,
                repetition_penalty=options.repetition_penalty,
                no_repeat_ngram_size=options.no_repeat_ngram_size,
-                max_length=self.max_length,
+                max_length=max_length,
                return_scores=True,
                return_no_speech_prob=True,
                suppress_blank=options.suppress_blank,
@@ -730,6 +753,8 @@ class WhisperModel:
            if (
                options.no_speech_threshold is not None
                and result.no_speech_prob > options.no_speech_threshold
+                and options.log_prob_threshold is not None
+                and avg_logprob < options.log_prob_threshold
            ):
                needs_fallback = False  # silence

faster_whisper/utils.py
@@ -22,6 +22,9 @@ _MODELS = {
    "large-v2": "Systran/faster-whisper-large-v2",
    "large-v3": "Systran/faster-whisper-large-v3",
    "large": "Systran/faster-whisper-large-v3",
+    "distil-large-v2": "Systran/faster-distil-whisper-large-v2",
+    "distil-medium.en": "Systran/faster-distil-whisper-medium.en",
+    "distil-small.en": "Systran/faster-distil-whisper-small.en",
}
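With these aliases in place, a distilled checkpoint can be loaded by its short name; a minimal sketch mirroring the README usage above:

```python
from faster_whisper import WhisperModel

# "distil-small.en" resolves to Systran/faster-distil-whisper-small.en
model = WhisperModel("distil-small.en", device="cuda", compute_type="float16")
segments, info = model.transcribe("audio.mp3", language="en", condition_on_previous_text=False)
```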