diff --git a/README.md b/README.md
index daee860..01417a9 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,44 @@ For reference, here's the time and memory usage that are required to transcribe
 
 *Executed with 8 threads on a Intel(R) Xeon(R) Gold 6226R.*
 
+## Requirements
+
+* Python 3.8 or greater
+
+Unlike openai-whisper, FFmpeg does **not** need to be installed on the system. The audio is decoded with the Python library [PyAV](https://github.com/PyAV-Org/PyAV), which bundles the FFmpeg libraries in its package.
+
+### GPU
+
+GPU execution requires the following NVIDIA libraries to be installed:
+
+* [cuBLAS for CUDA 11](https://developer.nvidia.com/cublas)
+* [cuDNN 8 for CUDA 11](https://developer.nvidia.com/cudnn)
+
+There are multiple ways to install these libraries. The recommended way is described in the official NVIDIA documentation, but we also suggest other installation methods below.
+
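For a quick sanity check that a GPU is visible once these libraries are in place, something along these lines can be used (a sketch; it only assumes the `ctranslate2` package that faster-whisper already depends on):

```python
# Sketch: check GPU visibility from CTranslate2, the backend used by faster-whisper.
import ctranslate2

# 0 means no usable CUDA device was found, so execution would fall back to CPU.
print("CUDA devices visible to CTranslate2:", ctranslate2.get_cuda_device_count())
```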
+<details>
+<summary>Other installation methods (click to expand)</summary>
+
+#### Use Docker
+
+The libraries are installed in this official NVIDIA Docker image: `nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04`.
+
+#### Install with `pip` (Linux only)
+
+On Linux these libraries can be installed with `pip`. Note that `LD_LIBRARY_PATH` must be set before launching Python.
+
+```bash
+pip install nvidia-cublas-cu11 nvidia-cudnn-cu11
+
+export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'`
+```
+
+#### Download the libraries from Purfview's repository (Windows & Linux)
+
+Purfview's [whisper-standalone-win](https://github.com/Purfview/whisper-standalone-win) provides the required NVIDIA libraries for Windows & Linux in a [single archive](https://github.com/Purfview/whisper-standalone-win/releases/tag/libs). Decompress the archive and place the libraries in a directory included in the `PATH`.
+
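To confirm that the `pip`-installed packages above actually expose the shared libraries, a check along these lines can help (a sketch for Linux only; it assumes `nvidia-cublas-cu11` and `nvidia-cudnn-cu11` are installed, as in the snippet above):

```python
# Sketch: locate the cuBLAS/cuDNN shared libraries shipped by the pip packages.
import os

import nvidia.cublas.lib
import nvidia.cudnn.lib

for module in (nvidia.cublas.lib, nvidia.cudnn.lib):
    lib_dir = os.path.dirname(module.__file__)
    libs = [f for f in os.listdir(lib_dir) if ".so" in f]
    print(lib_dir, "->", sorted(libs))  # these directories belong in LD_LIBRARY_PATH
```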
+</details>
 
 ## Installation
 
 The module can be installed from [PyPI](https://pypi.org/project/faster-whisper/):
@@ -44,26 +82,29 @@ The module can be installed from [PyPI](https://pypi.org/project/faster-whisper/):
 
 ```bash
 pip install faster-whisper
 ```
 
-**Other installation methods:**
+<details>
+<summary>Other installation methods (click to expand)</summary>
+
+### Install the master branch
 
 ```bash
-# Install the master branch:
 pip install --force-reinstall "faster-whisper @ https://github.com/guillaumekln/faster-whisper/archive/refs/heads/master.tar.gz"
+```
 
-# Install a specific commit:
+### Install a specific commit
+
+```bash
 pip install --force-reinstall "faster-whisper @ https://github.com/guillaumekln/faster-whisper/archive/a4f1cc8f11433e454c3934442b5e1a4ed5e865c3.tar.gz"
 ```
 
-### GPU support
-
-GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
+</details>
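Whichever installation method is used, the installed release can be verified from Python (the version string below matches this diff's `faster_whisper/version.py`):

```python
import faster_whisper

print(faster_whisper.__version__)  # "0.10.0" for this release
```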
 
 ## Usage
 
 ```python
 from faster_whisper import WhisperModel
 
-model_size = "large-v2"
+model_size = "large-v3"
 
 # Run on GPU with FP16
 model = WhisperModel(model_size, device="cuda", compute_type="float16")
@@ -137,23 +178,24 @@ Here is a non exhaustive list of open-source projects using faster-whisper. Feel
 
 * [whisper-ctranslate2](https://github.com/Softcatala/whisper-ctranslate2) is a command line client based on faster-whisper and compatible with the original client from openai/whisper.
 * [whisper-diarize](https://github.com/MahmoudAshraf97/whisper-diarization) is a speaker diarization tool that is based on faster-whisper and NVIDIA NeMo.
-* [whisper-standalone-win](https://github.com/Purfview/whisper-standalone-win) contains the portable ready to run binaries of faster-whisper for Windows.
+* [whisper-standalone-win](https://github.com/Purfview/whisper-standalone-win) provides standalone CLI executables of faster-whisper for Windows, Linux & macOS.
 * [asr-sd-pipeline](https://github.com/hedrergudene/asr-sd-pipeline) provides a scalable, modular, end to end multi-speaker speech to text solution implemented using AzureML pipelines.
 * [Open-Lyrics](https://github.com/zh-plus/Open-Lyrics) is a Python library that transcribes voice files using faster-whisper, and translates/polishes the resulting text into `.lrc` files in the desired language using OpenAI-GPT.
+* [wscribe](https://github.com/geekodour/wscribe) is a flexible transcript generation tool supporting faster-whisper; it can export word-level transcripts, which can then be edited with [wscribe-editor](https://github.com/geekodour/wscribe-editor).
 
 ## Model conversion
 
-When loading a model from its size such as `WhisperModel("large-v2")`, the correspondig CTranslate2 model is automatically downloaded from the [Hugging Face Hub](https://huggingface.co/guillaumekln).
+When loading a model from its size such as `WhisperModel("large-v3")`, the corresponding CTranslate2 model is automatically downloaded from the [Hugging Face Hub](https://huggingface.co/Systran).
 
 We also provide a script to convert any Whisper models compatible with the Transformers library. They could be the original OpenAI models or user fine-tuned models.
 
-For example the command below converts the [original "large-v2" Whisper model](https://huggingface.co/openai/whisper-large-v2) and saves the weights in FP16:
+For example, the command below converts the [original "large-v3" Whisper model](https://huggingface.co/openai/whisper-large-v3) and saves the weights in FP16:
 
 ```bash
 pip install transformers[torch]>=4.23
 
-ct2-transformers-converter --model openai/whisper-large-v2 --output_dir whisper-large-v2-ct2 \
-    --copy_files tokenizer.json --quantization float16
+ct2-transformers-converter --model openai/whisper-large-v3 --output_dir whisper-large-v3-ct2 \
+--copy_files tokenizer.json preprocessor_config.json --quantization float16
 ```
 
 * The option `--model` accepts a model name on the Hub or a path to a model directory.
@@ -161,6 +203,18 @@ ct2-transformers-converter --model openai/whisper-large-v2 --output_dir whisper-
 
 Models can also be converted from the code. See the [conversion API](https://opennmt.net/CTranslate2/python/ctranslate2.converters.TransformersConverter.html).
 
+### Load a converted model
+
+1. Directly load the model from a local directory:
+```python
+model = faster_whisper.WhisperModel("whisper-large-v3-ct2")
+```
+
+2. [Upload your model to the Hugging Face Hub](https://huggingface.co/docs/transformers/model_sharing#upload-with-the-web-interface) and load it from its name:
+```python
+model = faster_whisper.WhisperModel("username/whisper-large-v3-ct2")
+```
+
 ## Comparing performance against other implementations
 
 If you are comparing the performance against other Whisper implementations, you should make sure to run the comparison with similar settings. In particular:
diff --git a/faster_whisper/__init__.py b/faster_whisper/__init__.py
index e2fe00d..9b56a39 100644
--- a/faster_whisper/__init__.py
+++ b/faster_whisper/__init__.py
@@ -1,9 +1,10 @@
 from faster_whisper.audio import decode_audio
 from faster_whisper.transcribe import WhisperModel
-from faster_whisper.utils import download_model, format_timestamp
+from faster_whisper.utils import available_models, download_model, format_timestamp
 from faster_whisper.version import __version__
 
 __all__ = [
+    "available_models",
     "decode_audio",
     "WhisperModel",
    "download_model",
diff --git a/faster_whisper/audio.py b/faster_whisper/audio.py
index fbecc48..3190619 100644
--- a/faster_whisper/audio.py
+++ b/faster_whisper/audio.py
@@ -6,6 +6,7 @@ system dependencies. FFmpeg does not need to be installed on the system.
 
 However, the API is quite low-level so we need to manipulate audio frames directly.
 """
+import gc
 import io
 import itertools
 
@@ -42,7 +43,7 @@ def decode_audio(
     raw_buffer = io.BytesIO()
     dtype = None
 
-    with av.open(input_file, metadata_errors="ignore") as container:
+    with av.open(input_file, mode="r", metadata_errors="ignore") as container:
         frames = container.decode(audio=0)
         frames = _ignore_invalid_frames(frames)
         frames = _group_frames(frames, 500000)
@@ -53,6 +54,11 @@
             dtype = array.dtype
             raw_buffer.write(array)
 
+    # It appears that some objects related to the resampler are not freed
+    # unless the garbage collector is manually run.
+    del resampler
+    gc.collect()
+
     audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)
 
     # Convert s16 back to f32.
diff --git a/faster_whisper/tokenizer.py b/faster_whisper/tokenizer.py
index b040044..c3b13b4 100644
--- a/faster_whisper/tokenizer.py
+++ b/faster_whisper/tokenizer.py
@@ -19,15 +19,21 @@ class Tokenizer:
         self.tokenizer = tokenizer
 
         if multilingual:
+            if task not in _TASKS:
+                raise ValueError(
+                    "'%s' is not a valid task (accepted tasks: %s)"
+                    % (task, ", ".join(_TASKS))
+                )
+
+            if language not in _LANGUAGE_CODES:
+                raise ValueError(
+                    "'%s' is not a valid language code (accepted language codes: %s)"
+                    % (language, ", ".join(_LANGUAGE_CODES))
+                )
+
             self.task = self.tokenizer.token_to_id("<|%s|>" % task)
-            if self.task is None:
-                raise ValueError("%s is not a valid task" % task)
-
-            self.language_code = language
             self.language = self.tokenizer.token_to_id("<|%s|>" % language)
-            if self.language is None:
-                raise ValueError("%s is not a valid language code" % language)
-
+            self.language_code = language
         else:
             self.task = None
             self.language = None
@@ -102,7 +108,7 @@ class Tokenizer:
     def split_to_word_tokens(
         self, tokens: List[int]
     ) -> Tuple[List[str], List[List[int]]]:
-        if self.language_code in {"zh", "ja", "th", "lo", "my"}:
+        if self.language_code in {"zh", "ja", "th", "lo", "my", "yue"}:
             # These languages don't typically use spaces, so it is difficult to split words
             # without morpheme analysis. Here, we instead split words at any
             # position where the tokens are decoded as valid unicode points
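With the validation above, an unknown language or task now fails fast with an explicit message instead of a bare token lookup error. Illustratively (a hypothetical session; it assumes `model` is an initialized multilingual `WhisperModel` and `audio.wav` is a placeholder path):

```python
# Hypothetical session showing the new error behavior:
model.transcribe("audio.wav", language="xx")
# ValueError: 'xx' is not a valid language code (accepted language codes: af, am, ar, ...)
```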
@@ -161,3 +167,112 @@ class Tokenizer:
             word_tokens[-1].extend(subword_tokens)
 
     return words, word_tokens
+
+
+_TASKS = (
+    "transcribe",
+    "translate",
+)
+
+_LANGUAGE_CODES = (
+    "af",
+    "am",
+    "ar",
+    "as",
+    "az",
+    "ba",
+    "be",
+    "bg",
+    "bn",
+    "bo",
+    "br",
+    "bs",
+    "ca",
+    "cs",
+    "cy",
+    "da",
+    "de",
+    "el",
+    "en",
+    "es",
+    "et",
+    "eu",
+    "fa",
+    "fi",
+    "fo",
+    "fr",
+    "gl",
+    "gu",
+    "ha",
+    "haw",
+    "he",
+    "hi",
+    "hr",
+    "ht",
+    "hu",
+    "hy",
+    "id",
+    "is",
+    "it",
+    "ja",
+    "jw",
+    "ka",
+    "kk",
+    "km",
+    "kn",
+    "ko",
+    "la",
+    "lb",
+    "ln",
+    "lo",
+    "lt",
+    "lv",
+    "mg",
+    "mi",
+    "mk",
+    "ml",
+    "mn",
+    "mr",
+    "ms",
+    "mt",
+    "my",
+    "ne",
+    "nl",
+    "nn",
+    "no",
+    "oc",
+    "pa",
+    "pl",
+    "ps",
+    "pt",
+    "ro",
+    "ru",
+    "sa",
+    "sd",
+    "si",
+    "sk",
+    "sl",
+    "sn",
+    "so",
+    "sq",
+    "sr",
+    "su",
+    "sv",
+    "sw",
+    "ta",
+    "te",
+    "tg",
+    "th",
+    "tk",
+    "tl",
+    "tr",
+    "tt",
+    "uk",
+    "ur",
+    "uz",
+    "vi",
+    "yi",
+    "yo",
+    "zh",
+    "yue",
+)
diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 1ef655a..227e529 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -1,8 +1,10 @@
 import itertools
+import json
 import logging
 import os
 import zlib
 
+from inspect import signature
 from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
 
 import ctranslate2
@@ -11,7 +13,7 @@ import tokenizers
 
 from faster_whisper.audio import decode_audio
 from faster_whisper.feature_extractor import FeatureExtractor
-from faster_whisper.tokenizer import Tokenizer
+from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
 from faster_whisper.utils import download_model, format_timestamp, get_logger
 from faster_whisper.vad import (
     SpeechTimestampsMap,
@@ -47,10 +49,13 @@ class TranscriptionOptions(NamedTuple):
     best_of: int
     patience: float
     length_penalty: float
+    repetition_penalty: float
+    no_repeat_ngram_size: int
     log_prob_threshold: Optional[float]
     no_speech_threshold: Optional[float]
     compression_ratio_threshold: Optional[float]
     condition_on_previous_text: bool
+    prompt_reset_on_temperature: float
     temperatures: List[float]
     initial_prompt: Optional[Union[str, Iterable[int]]]
     prefix: Optional[str]
@@ -67,6 +72,7 @@ class TranscriptionInfo(NamedTuple):
     language: str
     language_probability: float
     duration: float
+    duration_after_vad: float
     all_language_probs: Optional[List[Tuple[str, float]]]
     transcription_options: TranscriptionOptions
     vad_options: VadOptions
@@ -88,8 +94,9 @@ class WhisperModel:
 
         Args:
           model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
-            small, small.en, medium, medium.en, large-v1, or large-v2) or a path to a converted
-            model directory. When a size is configured, the converted model is downloaded
+            small, small.en, medium, medium.en, large-v1, large-v2, large-v3, or large), a path to a
+            converted model directory, or a CTranslate2-converted Whisper model ID from the HF Hub.
+            When a size or a model ID is configured, the converted model is downloaded
             from the Hugging Face Hub.
          device: Device to use for computation ("cpu", "cuda", "auto").
          device_index: Device ID to use.
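The widened `model_size_or_path` contract in the docstring above corresponds to calls like these (a sketch; the Hub repository name follows the `_MODELS` mapping introduced in `utils.py` below):

```python
from faster_whisper import WhisperModel

# By size alias; the converted model is downloaded from the Hugging Face Hub.
model = WhisperModel("large-v3", device="cpu", compute_type="int8")

# By CTranslate2-converted model ID on the Hub.
model = WhisperModel("Systran/faster-whisper-large-v3")
```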
@@ -137,7 +144,8 @@ class WhisperModel:
                 "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
             )
 
-        self.feature_extractor = FeatureExtractor()
+        self.feat_kwargs = self._get_feature_kwargs(model_path)
+        self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
         self.num_samples_per_token = self.feature_extractor.hop_length * 2
         self.frames_per_second = (
             self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
@@ -149,6 +157,27 @@ class WhisperModel:
         self.time_precision = 0.02
         self.max_length = 448
 
+    @property
+    def supported_languages(self) -> List[str]:
+        """The languages supported by the model."""
+        return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
+
+    def _get_feature_kwargs(self, model_path) -> dict:
+        preprocessor_config_file = os.path.join(model_path, "preprocessor_config.json")
+        config = {}
+        if os.path.isfile(preprocessor_config_file):
+            try:
+                with open(preprocessor_config_file, "r", encoding="utf-8") as json_file:
+                    config = json.load(json_file)
+                valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
+                config = {k: v for k, v in config.items() if k in valid_keys}
+            except json.JSONDecodeError as e:
+                self.logger.warning(
+                    "Could not load preprocessor_config.json: %s", str(e)
+                )
+
+        return config
+
     def transcribe(
         self,
         audio: Union[str, BinaryIO, np.ndarray],
@@ -158,6 +187,8 @@ class WhisperModel:
         best_of: int = 5,
         patience: float = 1,
         length_penalty: float = 1,
+        repetition_penalty: float = 1,
+        no_repeat_ngram_size: int = 0,
         temperature: Union[float, List[float], Tuple[float, ...]] = [
             0.0,
             0.2,
@@ -170,6 +201,7 @@ class WhisperModel:
         log_prob_threshold: Optional[float] = -1.0,
         no_speech_threshold: Optional[float] = 0.6,
         condition_on_previous_text: bool = True,
+        prompt_reset_on_temperature: float = 0.5,
         initial_prompt: Optional[Union[str, Iterable[int]]] = None,
         prefix: Optional[str] = None,
         suppress_blank: bool = True,
@@ -194,6 +226,9 @@ class WhisperModel:
             best_of: Number of candidates when sampling with non-zero temperature.
             patience: Beam search patience factor.
             length_penalty: Exponential length penalty constant.
+            repetition_penalty: Penalty applied to the score of previously generated tokens
+                (set > 1 to penalize).
+            no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
             temperature: Temperature for sampling. It can be a tuple of temperatures,
                 which will be successively used upon failures according to either
                 `compression_ratio_threshold` or `log_prob_threshold`.
@@ -208,6 +243,8 @@ class WhisperModel:
                 as a prompt for the next window; disabling may make the text inconsistent across
                 windows, but the model becomes less prone to getting stuck in a failure loop,
                 such as repetition looping or timestamps going out of sync.
+            prompt_reset_on_temperature: Resets prompt if temperature is above this value.
+                This argument has an effect only if condition_on_previous_text is True.
             initial_prompt: Optional text string or iterable of token ids to provide as a
                 prompt for the first window.
             prefix: Optional text to provide as a prefix for the first window.
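Taken together, the new decoding knobs documented above can be combined like this (a sketch; it assumes `model` is an initialized `WhisperModel`, `audio.wav` is a placeholder path, and the values are illustrative rather than recommendations):

```python
segments, info = model.transcribe(
    "audio.wav",                      # placeholder path
    repetition_penalty=1.1,           # > 1 penalizes previously generated tokens
    no_repeat_ngram_size=3,           # forbid repeating any 3-gram
    prompt_reset_on_temperature=0.5,  # reset the prompt after a hot fallback step
)
```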
@@ -240,6 +277,7 @@ class WhisperModel:
             audio = decode_audio(audio, sampling_rate=sampling_rate)
 
         duration = audio.shape[0] / sampling_rate
+        duration_after_vad = duration
 
         self.logger.info(
             "Processing audio with duration %s", format_timestamp(duration)
@@ -252,10 +290,11 @@ class WhisperModel:
                 vad_parameters = VadOptions(**vad_parameters)
             speech_chunks = get_speech_timestamps(audio, vad_parameters)
             audio = collect_chunks(audio, speech_chunks)
+            duration_after_vad = audio.shape[0] / sampling_rate
 
             self.logger.info(
                 "VAD filter removed %s of audio",
-                format_timestamp(duration - (audio.shape[0] / sampling_rate)),
+                format_timestamp(duration - duration_after_vad),
             )
 
             if self.logger.isEnabledFor(logging.DEBUG):
@@ -300,6 +339,13 @@ class WhisperModel:
                     language_probability,
                 )
         else:
+            if not self.model.is_multilingual and language != "en":
+                self.logger.warning(
+                    "The current model is English-only but the language parameter is set to '%s'; "
+                    "using 'en' instead." % language
+                )
+                language = "en"
+
             language_probability = 1
 
         tokenizer = Tokenizer(
@@ -314,10 +360,13 @@ class WhisperModel:
             best_of=best_of,
             patience=patience,
             length_penalty=length_penalty,
+            repetition_penalty=repetition_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
             log_prob_threshold=log_prob_threshold,
             no_speech_threshold=no_speech_threshold,
             compression_ratio_threshold=compression_ratio_threshold,
             condition_on_previous_text=condition_on_previous_text,
+            prompt_reset_on_temperature=prompt_reset_on_temperature,
             temperatures=(
                 temperature if isinstance(temperature, (list, tuple)) else [temperature]
             ),
@@ -341,6 +390,7 @@ class WhisperModel:
             language=language,
             language_probability=language_probability,
             duration=duration,
+            duration_after_vad=duration_after_vad,
             transcription_options=options,
             vad_options=vad_parameters,
             all_language_probs=all_language_probs,
@@ -370,6 +420,7 @@ class WhisperModel:
         else:
             all_tokens.extend(options.initial_prompt)
 
+        last_speech_timestamp = 0.0
         while seek < content_frames:
             time_offset = seek * self.feature_extractor.time_per_frame
             segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]
@@ -391,7 +442,7 @@ class WhisperModel:
                 prefix=options.prefix if seek == 0 else None,
             )
 
-            if encoder_output is None:
+            if seek > 0 or encoder_output is None:
                 encoder_output = self.encode(segment)
 
             (
@@ -511,12 +562,14 @@ class WhisperModel:
                     segment_size,
                     options.prepend_punctuations,
                     options.append_punctuations,
+                    last_speech_timestamp=last_speech_timestamp,
                 )
 
                 word_end_timestamps = [
                     w["end"] for s in current_segments for w in s["words"]
                 ]
-
+                if len(word_end_timestamps) > 0:
+                    last_speech_timestamp = word_end_timestamps[-1]
                 if not single_timestamp_ending and len(word_end_timestamps) > 0:
                     seek_shift = round(
                         (word_end_timestamps[-1] - time_offset) * self.frames_per_second
@@ -525,8 +578,6 @@ class WhisperModel:
                     if seek_shift > 0:
                         seek = previous_seek + seek_shift
 
-            encoder_output = None
-
             for segment in current_segments:
                 tokens = segment["tokens"]
                 text = tokenizer.decode(tokens)
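The `duration_after_vad` value threaded through above surfaces on the returned `TranscriptionInfo` (a sketch; it assumes `model` is an initialized `WhisperModel` and `audio.wav` is a placeholder path):

```python
segments, info = model.transcribe("audio.wav", vad_filter=True)
print(info.duration, info.duration_after_vad)  # total duration vs. duration kept by VAD
```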
@@ -563,7 +614,17 @@ class WhisperModel:
                 ),
             )
 
-            if not options.condition_on_previous_text or temperature > 0.5:
+            if (
+                not options.condition_on_previous_text
+                or temperature > options.prompt_reset_on_temperature
+            ):
+                if options.condition_on_previous_text:
+                    self.logger.debug(
+                        "Reset prompt: prompt_reset_on_temperature threshold met (%f > %f)",
+                        temperature,
+                        options.prompt_reset_on_temperature,
+                    )
+
                 prompt_reset_since = len(all_tokens)
 
     def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
@@ -583,10 +644,9 @@ class WhisperModel:
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
     ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
-        result = None
-        avg_logprob = None
-        final_temperature = None
-        compression_ratio = None
+        decode_result = None
+        all_results = []
+        below_cr_threshold_results = []
 
         max_initial_timestamp_index = int(
             round(options.max_initial_timestamp / self.time_precision)
@@ -606,11 +666,12 @@ class WhisperModel:
                     "patience": options.patience,
                 }
 
-            final_temperature = temperature
             result = self.model.generate(
                 encoder_output,
                 [prompt],
                 length_penalty=options.length_penalty,
+                repetition_penalty=options.repetition_penalty,
+                no_repeat_ngram_size=options.no_repeat_ngram_size,
                 max_length=self.max_length,
                 return_scores=True,
                 return_no_speech_prob=True,
@@ -630,20 +691,28 @@ class WhisperModel:
             text = tokenizer.decode(tokens).strip()
             compression_ratio = get_compression_ratio(text)
 
+            decode_result = (
+                result,
+                avg_logprob,
+                temperature,
+                compression_ratio,
+            )
+            all_results.append(decode_result)
+
             needs_fallback = False
 
-            if (
-                options.compression_ratio_threshold is not None
-                and compression_ratio > options.compression_ratio_threshold
-            ):
-                needs_fallback = True  # too repetitive
+            if options.compression_ratio_threshold is not None:
+                if compression_ratio > options.compression_ratio_threshold:
+                    needs_fallback = True  # too repetitive
 
-                self.logger.debug(
-                    "Compression ratio threshold is not met with temperature %.1f (%f > %f)",
-                    temperature,
-                    compression_ratio,
-                    options.compression_ratio_threshold,
-                )
+                    self.logger.debug(
+                        "Compression ratio threshold is not met with temperature %.1f (%f > %f)",
+                        temperature,
+                        compression_ratio,
+                        options.compression_ratio_threshold,
+                    )
+                else:
+                    below_cr_threshold_results.append(decode_result)
 
             if (
                 options.log_prob_threshold is not None
@@ -658,10 +727,28 @@ class WhisperModel:
                     options.log_prob_threshold,
                 )
 
+            if (
+                options.no_speech_threshold is not None
+                and result.no_speech_prob > options.no_speech_threshold
+            ):
+                needs_fallback = False  # silence
+
             if not needs_fallback:
                 break
+        else:
+            # all failed, select the result with the highest average log probability
+            decode_result = max(
+                below_cr_threshold_results or all_results, key=lambda x: x[1]
+            )
+            # to pass final temperature for prompt_reset_on_temperature
+            decode_result = (
+                decode_result[0],
+                decode_result[1],
+                temperature,
+                decode_result[3],
+            )
 
-        return result, avg_logprob, final_temperature, compression_ratio
+        return decode_result
 
     def get_prompt(
         self,
@@ -685,6 +772,8 @@ class WhisperModel:
             prefix_tokens = tokenizer.encode(" " + prefix.strip())
             if len(prefix_tokens) >= self.max_length // 2:
                 prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
+            if not without_timestamps:
+                prompt.append(tokenizer.timestamp_begin)
             prompt.extend(prefix_tokens)
 
         return prompt
@@ -697,7 +786,8 @@ class WhisperModel:
         num_frames: int,
         prepend_punctuations: str,
         append_punctuations: str,
-    ):
+        last_speech_timestamp: float,
+    ) -> None:
         if len(segments) == 0:
             return
 
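The reworked fallback loop above keeps every candidate result and, when every temperature fails the thresholds, returns the candidate with the highest average log probability. The thresholds involved are existing `transcribe` parameters (defaults shown; `model` and `audio.wav` are placeholders):

```python
segments, info = model.transcribe(
    "audio.wav",
    temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],  # tried in order upon failure
    compression_ratio_threshold=2.4,             # "too repetitive" cutoff
    log_prob_threshold=-1.0,                     # "too unlikely" cutoff
)
```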
@@ -710,6 +800,24 @@ class WhisperModel:
         alignment = self.find_alignment(
             tokenizer, text_tokens, encoder_output, num_frames
         )
+        word_durations = np.array([word["end"] - word["start"] for word in alignment])
+        word_durations = word_durations[word_durations.nonzero()]
+        median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
+        max_duration = median_duration * 2
+
+        # hack: truncate long words at sentence boundaries.
+        # a better segmentation algorithm based on VAD should be able to replace this.
+        if len(word_durations) > 0:
+            sentence_end_marks = ".。!!??"
+            # ensure words at sentence boundaries
+            # are not longer than twice the median word duration.
+            for i in range(1, len(alignment)):
+                if alignment[i]["end"] - alignment[i]["start"] > max_duration:
+                    if alignment[i]["word"] in sentence_end_marks:
+                        alignment[i]["end"] = alignment[i]["start"] + max_duration
+                    elif alignment[i - 1]["word"] in sentence_end_marks:
+                        alignment[i]["start"] = alignment[i]["end"] - max_duration
+
         merge_punctuations(alignment, prepend_punctuations, append_punctuations)
 
         time_offset = (
@@ -740,10 +848,51 @@ class WhisperModel:
                     saved_tokens += len(timing["tokens"])
                     word_index += 1
 
+            # hack: truncate long words at segment boundaries.
+            # a better segmentation algorithm based on VAD should be able to replace this.
             if len(words) > 0:
-                # adjust the segment-level timestamps based on the word-level timestamps
-                segment["start"] = words[0]["start"]
-                segment["end"] = words[-1]["end"]
+                # ensure the first and second word after a pause is not longer than
+                # twice the median word duration.
+                if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
+                    words[0]["end"] - words[0]["start"] > max_duration
+                    or (
+                        len(words) > 1
+                        and words[1]["end"] - words[0]["start"] > max_duration * 2
+                    )
+                ):
+                    if (
+                        len(words) > 1
+                        and words[1]["end"] - words[1]["start"] > max_duration
+                    ):
+                        boundary = max(
+                            words[1]["end"] / 2, words[1]["end"] - max_duration
+                        )
+                        words[0]["end"] = words[1]["start"] = boundary
+                    words[0]["start"] = max(0, words[0]["end"] - max_duration)
+
+                # prefer the segment-level start timestamp if the first word is too long.
+                if (
+                    segment["start"] < words[0]["end"]
+                    and segment["start"] - 0.5 > words[0]["start"]
+                ):
+                    words[0]["start"] = max(
+                        0, min(words[0]["end"] - median_duration, segment["start"])
+                    )
+                else:
+                    segment["start"] = words[0]["start"]
+
+                # prefer the segment-level end timestamp if the last word is too long.
+                if (
+                    segment["end"] > words[-1]["start"]
+                    and segment["end"] + 0.5 < words[-1]["end"]
+                ):
+                    words[-1]["end"] = max(
+                        words[-1]["start"] + median_duration, segment["end"]
+                    )
+                else:
+                    segment["end"] = words[-1]["end"]
+
+                last_speech_timestamp = segment["end"]
 
             segment["words"] = words
 
@@ -775,6 +924,13 @@ class WhisperModel:
         words, word_tokens = tokenizer.split_to_word_tokens(
             text_tokens + [tokenizer.eot]
         )
+        if len(word_tokens) <= 1:
+            # return on eot only
+            # >>> np.pad([], (1, 0))
+            # array([0.])
+            # This results in crashes when we lookup jump_times with float, like
+            # IndexError: arrays used as indices must be of integer (or boolean) type
+            return []
         word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
         if len(word_boundaries) <= 1:
             return []
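The boundary heuristics above only take effect when word timestamps are requested; downstream they surface as `start`/`end` on each word (a sketch; `model` and `audio.wav` are placeholders):

```python
segments, _ = model.transcribe("audio.wav", word_timestamps=True)
for segment in segments:
    for word in segment.words:
        print("[%.2fs -> %.2fs] %s" % (word.start, word.end, word.word))
```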
@@ -788,22 +944,6 @@ class WhisperModel:
             for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
         ]
 
-        # hack: ensure the first and second word is not longer than twice the median word duration.
-        # a better segmentation algorithm based on VAD should be able to replace this.
-        word_durations = end_times - start_times
-        word_durations = word_durations[word_durations.nonzero()]
-        if len(word_durations) > 0:
-            median_duration = np.median(word_durations)
-            max_duration = median_duration * 2
-            if len(word_durations) >= 2 and word_durations[1] > max_duration:
-                boundary = max(end_times[2] / 2, end_times[2] - max_duration)
-                end_times[0] = start_times[1] = boundary
-            if (
-                len(word_durations) >= 1
-                and end_times[0] - start_times[0] > max_duration
-            ):
-                start_times[0] = max(0, end_times[0] - max_duration)
-
         return [
             dict(
                 word=word, tokens=tokens, start=start, end=end, probability=probability
@@ -860,7 +1000,10 @@ def get_compression_ratio(text: str) -> float:
     return len(text_bytes) / len(zlib.compress(text_bytes))
 
 
-def get_suppressed_tokens(tokenizer, suppress_tokens):
+def get_suppressed_tokens(
+    tokenizer: Tokenizer,
+    suppress_tokens: Optional[List[int]],
+) -> Optional[List[int]]:
     if not suppress_tokens or -1 in suppress_tokens:
         return suppress_tokens
 
@@ -881,7 +1024,7 @@ def get_suppressed_tokens(tokenizer, suppress_tokens):
     return sorted(set(suppress_tokens))
 
 
-def merge_punctuations(alignment: List[dict], prepended: str, appended: str):
+def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
     # merge prepended punctuations
     i = len(alignment) - 2
     j = len(alignment) - 1
diff --git a/faster_whisper/utils.py b/faster_whisper/utils.py
index 950b0da..343a635 100644
--- a/faster_whisper/utils.py
+++ b/faster_whisper/utils.py
@@ -1,25 +1,33 @@
 import logging
 import os
+import re
 
-from typing import Optional
+from typing import List, Optional
 
 import huggingface_hub
 import requests
 
 from tqdm.auto import tqdm
 
-_MODELS = (
-    "tiny.en",
-    "tiny",
-    "base.en",
-    "base",
-    "small.en",
-    "small",
-    "medium.en",
-    "medium",
-    "large-v1",
-    "large-v2",
-)
+_MODELS = {
+    "tiny.en": "Systran/faster-whisper-tiny.en",
+    "tiny": "Systran/faster-whisper-tiny",
+    "base.en": "Systran/faster-whisper-base.en",
+    "base": "Systran/faster-whisper-base",
+    "small.en": "Systran/faster-whisper-small.en",
+    "small": "Systran/faster-whisper-small",
+    "medium.en": "Systran/faster-whisper-medium.en",
+    "medium": "Systran/faster-whisper-medium",
+    "large-v1": "Systran/faster-whisper-large-v1",
+    "large-v2": "Systran/faster-whisper-large-v2",
+    "large-v3": "Systran/faster-whisper-large-v3",
+    "large": "Systran/faster-whisper-large-v3",
+}
+
+
+def available_models() -> List[str]:
+    """Returns the names of available models."""
+    return list(_MODELS.keys())
 
 
 def get_assets_path():
@@ -33,18 +41,18 @@ def get_logger():
 
 
 def download_model(
-    size: str,
+    size_or_id: str,
     output_dir: Optional[str] = None,
     local_files_only: bool = False,
     cache_dir: Optional[str] = None,
 ):
     """Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
 
-    The model is downloaded from https://huggingface.co/guillaumekln.
-
     Args:
-        size: Size of the model to download (tiny, tiny.en, base, base.en, small, small.en,
-            medium, medium.en, large-v1, or large-v2).
+        size_or_id: Size of the model to download from https://huggingface.co/Systran
+            (tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2,
+            large-v3, large), or a CTranslate2-converted model ID from the Hugging Face Hub
+            (e.g. Systran/faster-whisper-large-v3).
        output_dir: Directory where the model should be saved. If not set, the model is saved
            in the cache directory.
        local_files_only: If True, avoid downloading the file and return the path to the local
            cached file if it exists.
        cache_dir: Path to the folder where cached files are stored.
 
     Returns:
       The path to the downloaded model.
 
     Raises:
       ValueError: if the model size is invalid.
     """
-    if size not in _MODELS:
-        raise ValueError(
-            "Invalid model size '%s', expected one of: %s" % (size, ", ".join(_MODELS))
-        )
-
-    repo_id = "guillaumekln/faster-whisper-%s" % size
+    if re.match(r".*/.*", size_or_id):
+        repo_id = size_or_id
+    else:
+        repo_id = _MODELS.get(size_or_id)
+        if repo_id is None:
+            raise ValueError(
+                "Invalid model size '%s', expected one of: %s"
+                % (size_or_id, ", ".join(_MODELS.keys()))
+            )
 
     allow_patterns = [
         "config.json",
+        "preprocessor_config.json",
         "model.bin",
         "tokenizer.json",
         "vocabulary.*",
diff --git a/faster_whisper/version.py b/faster_whisper/version.py
index bf288f0..e1f6d31 100644
--- a/faster_whisper/version.py
+++ b/faster_whisper/version.py
@@ -1,3 +1,3 @@
 """Version information."""
 
-__version__ = "0.6.0"
+__version__ = "0.10.0"
diff --git a/requirements.txt b/requirements.txt
index 4dd8bac..ba0da20 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 av==10.*
-ctranslate2>=3.10,<4
+ctranslate2>=3.22,<4
 huggingface_hub>=0.13
-tokenizers==0.13.*
+tokenizers>=0.13,<0.16
 onnxruntime>=1.14,<2
diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py
index 6ecf2c4..d30a0fb 100644
--- a/tests/test_transcribe.py
+++ b/tests/test_transcribe.py
@@ -3,6 +3,11 @@ import os
 from faster_whisper import WhisperModel, decode_audio
 
 
+def test_supported_languages():
+    model = WhisperModel("tiny.en")
+    assert model.supported_languages == ["en"]
+
+
 def test_transcribe(jfk_path):
     model = WhisperModel("tiny")
     segments, info = model.transcribe(jfk_path, word_timestamps=True)
@@ -34,6 +39,24 @@ def test_transcribe(jfk_path):
         assert segment.end == segment.words[-1].end
 
 
+def test_prefix_with_timestamps(jfk_path):
+    model = WhisperModel("tiny")
+    segments, _ = model.transcribe(jfk_path, prefix="And so my fellow Americans")
+    segments = list(segments)
+
+    assert len(segments) == 1
+
+    segment = segments[0]
+
+    assert segment.text == (
+        " And so my fellow Americans ask not what your country can do for you, "
+        "ask what you can do for your country."
+    )
+
+    assert segment.start == 0
+    assert 10 < segment.end < 11
+
+
 def test_vad(jfk_path):
     model = WhisperModel("tiny")
     segments, info = model.transcribe(
diff --git a/tests/test_utils.py b/tests/test_utils.py
index ee404bf..bb488fe 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,12 @@
 import os
 
-from faster_whisper import download_model
+from faster_whisper import available_models, download_model
+
+
+def test_available_models():
+    models = available_models()
+    assert isinstance(models, list)
+    assert "tiny" in models
 
 
 def test_download_model(tmpdir):
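For reference, the new download path added in `utils.py` can be exercised directly (a sketch; the output directory name is illustrative):

```python
from faster_whisper import available_models, download_model

print(available_models())  # size aliases, including "large-v3" and the "large" alias

# A bare size resolves through the _MODELS mapping; anything containing "/" is
# treated as a Hugging Face Hub model ID and used as-is.
model_path = download_model("large-v3", output_dir="whisper-large-v3-ct2")
print(model_path)
```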