diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..379b9ad --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to faster-whisper + +Contributions are welcome! Here are some pointers to help you install the library for development and validate your changes before submitting a pull request. + +## Install the library for development + +We recommend installing the module in editable mode with the `dev` extra requirements: + +```bash +git clone https://github.com/guillaumekln/faster-whisper.git +cd faster-whisper/ +pip install -e .[dev] +``` + +## Validate the changes before creating a pull request + +1. Make sure the existing tests are still passing (and consider adding new tests as well!): + +```bash +pytest tests/ +``` + +2. Reformat and validate the code with the following tools: + +```bash +black . +isort . +flake8 . +``` + +These steps are also run automatically in the CI when you open the pull request. diff --git a/MANIFEST.in b/MANIFEST.in index e2fff83..6f6187c 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,3 @@ include faster_whisper/assets/silero_vad.onnx +include requirements.txt +include requirements.conversion.txt diff --git a/README.md b/README.md index 02a2be6..daee860 100644 --- a/README.md +++ b/README.md @@ -52,10 +52,6 @@ pip install --force-reinstall "faster-whisper @ https://github.com/guillaumekln/ # Install a specific commit: pip install --force-reinstall "faster-whisper @ https://github.com/guillaumekln/faster-whisper/archive/a4f1cc8f11433e454c3934442b5e1a4ed5e865c3.tar.gz" - -# Install for development: -git clone https://github.com/guillaumekln/faster-whisper.git -pip install -e faster-whisper/ ``` ### GPU support @@ -64,8 +60,6 @@ GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be inst ## Usage -### Library - ```python from faster_whisper import WhisperModel @@ -94,7 +88,7 @@ segments, _ = model.transcribe("audio.mp3") segments = list(segments) # The transcription will actually run here. ``` -#### Word-level timestamps +### Word-level timestamps ```python segments, _ = model.transcribe("audio.mp3", word_timestamps=True) @@ -104,7 +98,7 @@ for segment in segments: print("[%.2fs -> %.2fs] %s" % (word.start, word.end, word.word)) ``` -#### VAD filter +### VAD filter The library integrates the [Silero VAD](https://github.com/snakers4/silero-vad) model to filter out parts of the audio without speech: @@ -112,19 +106,40 @@ The library integrates the [Silero VAD](https://github.com/snakers4/silero-vad) segments, _ = model.transcribe("audio.mp3", vad_filter=True) ``` -The default behavior is conservative and only removes silence longer than 2 seconds. See the available VAD parameters and default values in the function [`get_speech_timestamps`](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/vad.py). They can be customized with the dictionary argument `vad_parameters`: +The default behavior is conservative and only removes silence longer than 2 seconds. See the available VAD parameters and default values in the [source code](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/vad.py). They can be customized with the dictionary argument `vad_parameters`: ```python -segments, _ = model.transcribe("audio.mp3", vad_filter=True, vad_parameters=dict(min_silence_duration_ms=500)) +segments, _ = model.transcribe( + "audio.mp3", + vad_filter=True, + vad_parameters=dict(min_silence_duration_ms=500), +) ``` -#### Going further +### Logging + +The library logging level can be configured like this: + +```python +import logging + +logging.basicConfig() +logging.getLogger("faster_whisper").setLevel(logging.DEBUG) +``` + +### Going further See more model and transcription options in the [`WhisperModel`](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py) class implementation. -### CLI +## Community integrations -You can use [jordimas/whisper-ctranslate2](https://github.com/jordimas/whisper-ctranslate2) to access `faster-whisper` through a CLI interface similar to what is offered by Whisper. +Here is a non exhaustive list of open-source projects using faster-whisper. Feel free to add your project to the list! + +* [whisper-ctranslate2](https://github.com/Softcatala/whisper-ctranslate2) is a command line client based on faster-whisper and compatible with the original client from openai/whisper. +* [whisper-diarize](https://github.com/MahmoudAshraf97/whisper-diarization) is a speaker diarization tool that is based on faster-whisper and NVIDIA NeMo. +* [whisper-standalone-win](https://github.com/Purfview/whisper-standalone-win) contains the portable ready to run binaries of faster-whisper for Windows. +* [asr-sd-pipeline](https://github.com/hedrergudene/asr-sd-pipeline) provides a scalable, modular, end to end multi-speaker speech to text solution implemented using AzureML pipelines. +* [Open-Lyrics](https://github.com/zh-plus/Open-Lyrics) is a Python library that transcribes voice files using faster-whisper, and translates/polishes the resulting text into `.lrc` files in the desired language using OpenAI-GPT. ## Model conversion diff --git a/faster_whisper/__init__.py b/faster_whisper/__init__.py index add677e..e2fe00d 100644 --- a/faster_whisper/__init__.py +++ b/faster_whisper/__init__.py @@ -1,10 +1,12 @@ from faster_whisper.audio import decode_audio from faster_whisper.transcribe import WhisperModel from faster_whisper.utils import download_model, format_timestamp +from faster_whisper.version import __version__ __all__ = [ "decode_audio", "WhisperModel", "download_model", "format_timestamp", + "__version__", ] diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index c5c5525..1ef655a 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -15,6 +15,7 @@ from faster_whisper.tokenizer import Tokenizer from faster_whisper.utils import download_model, format_timestamp, get_logger from faster_whisper.vad import ( SpeechTimestampsMap, + VadOptions, collect_chunks, get_speech_timestamps, ) @@ -28,18 +29,17 @@ class Word(NamedTuple): class Segment(NamedTuple): + id: int + seek: int start: float end: float text: str - words: Optional[List[Word]] - avg_log_prob: float + tokens: List[int] + temperature: float + avg_logprob: float + compression_ratio: float no_speech_prob: float - - -class AudioInfo(NamedTuple): - language: str - language_probability: float - duration: float + words: Optional[List[Word]] class TranscriptionOptions(NamedTuple): @@ -52,7 +52,7 @@ class TranscriptionOptions(NamedTuple): compression_ratio_threshold: Optional[float] condition_on_previous_text: bool temperatures: List[float] - initial_prompt: Optional[str] + initial_prompt: Optional[Union[str, Iterable[int]]] prefix: Optional[str] suppress_blank: bool suppress_tokens: Optional[List[int]] @@ -63,6 +63,15 @@ class TranscriptionOptions(NamedTuple): append_punctuations: str +class TranscriptionInfo(NamedTuple): + language: str + language_probability: float + duration: float + all_language_probs: Optional[List[Tuple[str, float]]] + transcription_options: TranscriptionOptions + vad_options: VadOptions + + class WhisperModel: def __init__( self, @@ -73,6 +82,7 @@ class WhisperModel: cpu_threads: int = 0, num_workers: int = 1, download_root: Optional[str] = None, + local_files_only: bool = False, ): """Initializes the Whisper model. @@ -94,15 +104,21 @@ class WhisperModel: having multiple workers enables true parallelism when running the model (concurrent calls to self.model.generate() will run in parallel). This can improve the global throughput at the cost of increased memory usage. - download_root: Directory where the model should be saved. If not set, the model - is saved in the standard Hugging Face cache directory. + download_root: Directory where the models should be saved. If not set, the models + are saved in the standard Hugging Face cache directory. + local_files_only: If True, avoid downloading the file and return the path to the + local cached file if it exists. """ self.logger = get_logger() if os.path.isdir(model_size_or_path): model_path = model_size_or_path else: - model_path = download_model(model_size_or_path, download_root) + model_path = download_model( + model_size_or_path, + local_files_only=local_files_only, + cache_dir=download_root, + ) self.model = ctranslate2.models.Whisper( model_path, @@ -154,7 +170,7 @@ class WhisperModel: log_prob_threshold: Optional[float] = -1.0, no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, - initial_prompt: Optional[str] = None, + initial_prompt: Optional[Union[str, Iterable[int]]] = None, prefix: Optional[str] = None, suppress_blank: bool = True, suppress_tokens: Optional[List[int]] = [-1], @@ -164,8 +180,8 @@ class WhisperModel: prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", vad_filter: bool = False, - vad_parameters: Optional[dict] = None, - ) -> Tuple[Iterable[Segment], AudioInfo]: + vad_parameters: Optional[Union[dict, VadOptions]] = None, + ) -> Tuple[Iterable[Segment], TranscriptionInfo]: """Transcribes an input file. Arguments: @@ -192,7 +208,8 @@ class WhisperModel: as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. - initial_prompt: Optional text to provide as a prompt for the first window. + initial_prompt: Optional text string or iterable of token ids to provide as a + prompt for the first window. prefix: Optional text to provide as a prefix for the first window. suppress_blank: Suppress blank outputs at the beginning of the sampling. suppress_tokens: List of token IDs to suppress. -1 will suppress a default set @@ -208,14 +225,14 @@ class WhisperModel: vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio without speech. This step is using the Silero VAD model https://github.com/snakers4/silero-vad. - vad_parameters: Dictionary of Silero VAD parameters (see available parameters and - default values in the function `get_speech_timestamps`). + vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available + parameters and default values in the class `VadOptions`). Returns: A tuple with: - a generator over transcribed segments - - an instance of AudioInfo + - an instance of TranscriptionInfo """ sampling_rate = self.feature_extractor.sampling_rate @@ -229,8 +246,11 @@ class WhisperModel: ) if vad_filter: - vad_parameters = {} if vad_parameters is None else vad_parameters - speech_chunks = get_speech_timestamps(audio, **vad_parameters) + if vad_parameters is None: + vad_parameters = VadOptions() + elif isinstance(vad_parameters, dict): + vad_parameters = VadOptions(**vad_parameters) + speech_chunks = get_speech_timestamps(audio, vad_parameters) audio = collect_chunks(audio, speech_chunks) self.logger.info( @@ -257,6 +277,7 @@ class WhisperModel: features = self.feature_extractor(audio) encoder_output = None + all_language_probs = None if language is None: if not self.model.is_multilingual: @@ -265,9 +286,13 @@ class WhisperModel: else: segment = features[:, : self.feature_extractor.nb_max_frames] encoder_output = self.encode(segment) - results = self.model.detect_language(encoder_output) - language_token, language_probability = results[0][0] - language = language_token[2:-2] + # results is a list of tuple[str, float] with language names and + # probabilities. + results = self.model.detect_language(encoder_output)[0] + # Parse language names to strip out markers + all_language_probs = [(token[2:-2], prob) for (token, prob) in results] + # Get top language token and probability + language, language_probability = all_language_probs[0] self.logger.info( "Detected language '%s' with probability %.2f", @@ -312,13 +337,16 @@ class WhisperModel: if speech_chunks: segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate) - audio_info = AudioInfo( + info = TranscriptionInfo( language=language, language_probability=language_probability, duration=duration, + transcription_options=options, + vad_options=vad_parameters, + all_language_probs=all_language_probs, ) - return segments, audio_info + return segments, info def generate_segments( self, @@ -328,15 +356,19 @@ class WhisperModel: encoder_output: Optional[ctranslate2.StorageView] = None, ) -> Iterable[Segment]: content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames + idx = 0 seek = 0 all_tokens = [] all_prompt_text = [] prompt_reset_since = 0 if options.initial_prompt is not None: - initial_prompt = " " + options.initial_prompt.strip() - initial_prompt_tokens = tokenizer.encode(initial_prompt) - all_tokens.extend(initial_prompt_tokens) + if isinstance(options.initial_prompt, str): + initial_prompt = " " + options.initial_prompt.strip() + initial_prompt_tokens = tokenizer.encode(initial_prompt) + all_tokens.extend(initial_prompt_tokens) + else: + all_tokens.extend(options.initial_prompt) while seek < content_frames: time_offset = seek * self.feature_extractor.time_per_frame @@ -362,9 +394,12 @@ class WhisperModel: if encoder_output is None: encoder_output = self.encode(segment) - result, avg_log_prob, temperature = self.generate_with_fallback( - encoder_output, prompt, tokenizer, options - ) + ( + result, + avg_logprob, + temperature, + compression_ratio, + ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options) if options.no_speech_threshold is not None: # no voice activity check @@ -372,7 +407,7 @@ class WhisperModel: if ( options.log_prob_threshold is not None - and avg_log_prob > options.log_prob_threshold + and avg_logprob > options.log_prob_threshold ): # don't skip if the logprob is high enough, despite the no_speech_prob should_skip = False @@ -468,9 +503,6 @@ class WhisperModel: seek += segment_size - if not options.condition_on_previous_text or temperature > 0.5: - prompt_reset_since = len(all_tokens) - if options.word_timestamps: self.add_word_timestamps( current_segments, @@ -511,20 +543,29 @@ class WhisperModel: ): all_tokens.extend(tokens) all_prompt_text.append(text) + idx += 1 yield Segment( + id=idx, + seek=seek, start=segment["start"], end=segment["end"], text=text, + tokens=tokens, + temperature=temperature, + avg_logprob=avg_logprob, + compression_ratio=compression_ratio, + no_speech_prob=result.no_speech_prob, words=( [Word(**word) for word in segment["words"]] if options.word_timestamps else None ), - avg_log_prob=avg_log_prob, - no_speech_prob=result.no_speech_prob, ) + if not options.condition_on_previous_text or temperature > 0.5: + prompt_reset_since = len(all_tokens) + def encode(self, features: np.ndarray) -> ctranslate2.StorageView: # When the model is running on multiple GPUs, the encoder output should be moved # to the CPU since we don't know which GPU will handle the next job. @@ -541,10 +582,11 @@ class WhisperModel: prompt: List[int], tokenizer: Tokenizer, options: TranscriptionOptions, - ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]: + ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]: result = None - avg_log_prob = None + avg_logprob = None final_temperature = None + compression_ratio = None max_initial_timestamp_index = int( round(options.max_initial_timestamp / self.time_precision) @@ -582,8 +624,8 @@ class WhisperModel: # Recover the average log prob from the returned score. seq_len = len(tokens) - cum_log_prob = result.scores[0] * (seq_len**options.length_penalty) - avg_log_prob = cum_log_prob / (seq_len + 1) + cum_logprob = result.scores[0] * (seq_len**options.length_penalty) + avg_logprob = cum_logprob / (seq_len + 1) text = tokenizer.decode(tokens).strip() compression_ratio = get_compression_ratio(text) @@ -605,21 +647,21 @@ class WhisperModel: if ( options.log_prob_threshold is not None - and avg_log_prob < options.log_prob_threshold + and avg_logprob < options.log_prob_threshold ): needs_fallback = True # average log probability is too low self.logger.debug( "Log probability threshold is not met with temperature %.1f (%f < %f)", temperature, - avg_log_prob, + avg_logprob, options.log_prob_threshold, ) if not needs_fallback: break - return result, avg_log_prob, final_temperature + return result, avg_logprob, final_temperature, compression_ratio def get_prompt( self, @@ -734,6 +776,8 @@ class WhisperModel: text_tokens + [tokenizer.eot] ) word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) + if len(word_boundaries) <= 1: + return [] jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) jump_times = time_indices[jumps] / self.tokens_per_second @@ -782,7 +826,8 @@ def restore_speech_timestamps( words = [] for word in segment.words: # Ensure the word start and end times are resolved to the same chunk. - chunk_index = ts_map.get_chunk_index(word.start) + middle = (word.start + word.end) / 2 + chunk_index = ts_map.get_chunk_index(middle) word = word._replace( start=ts_map.get_original_time(word.start, chunk_index), end=ts_map.get_original_time(word.end, chunk_index), diff --git a/faster_whisper/utils.py b/faster_whisper/utils.py index 66c7161..950b0da 100644 --- a/faster_whisper/utils.py +++ b/faster_whisper/utils.py @@ -4,6 +4,7 @@ import os from typing import Optional import huggingface_hub +import requests from tqdm.auto import tqdm @@ -31,7 +32,12 @@ def get_logger(): return logging.getLogger("faster_whisper") -def download_model(size: str, output_dir: Optional[str] = None): +def download_model( + size: str, + output_dir: Optional[str] = None, + local_files_only: bool = False, + cache_dir: Optional[str] = None, +): """Downloads a CTranslate2 Whisper model from the Hugging Face Hub. The model is downloaded from https://huggingface.co/guillaumekln. @@ -40,7 +46,10 @@ def download_model(size: str, output_dir: Optional[str] = None): size: Size of the model to download (tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, or large-v2). output_dir: Directory where the model should be saved. If not set, the model is saved in - the standard Hugging Face cache directory. + the cache directory. + local_files_only: If True, avoid downloading the file and return the path to the local + cached file if it exists. + cache_dir: Path to the folder where cached files are stored. Returns: The path to the downloaded model. @@ -54,25 +63,45 @@ def download_model(size: str, output_dir: Optional[str] = None): ) repo_id = "guillaumekln/faster-whisper-%s" % size - kwargs = {} - - if output_dir is not None: - kwargs["local_dir"] = output_dir - kwargs["local_dir_use_symlinks"] = False allow_patterns = [ "config.json", "model.bin", "tokenizer.json", - "vocabulary.txt", + "vocabulary.*", ] - return huggingface_hub.snapshot_download( - repo_id, - allow_patterns=allow_patterns, - tqdm_class=disabled_tqdm, - **kwargs, - ) + kwargs = { + "local_files_only": local_files_only, + "allow_patterns": allow_patterns, + "tqdm_class": disabled_tqdm, + } + + if output_dir is not None: + kwargs["local_dir"] = output_dir + kwargs["local_dir_use_symlinks"] = False + + if cache_dir is not None: + kwargs["cache_dir"] = cache_dir + + try: + return huggingface_hub.snapshot_download(repo_id, **kwargs) + except ( + huggingface_hub.utils.HfHubHTTPError, + requests.exceptions.ConnectionError, + ) as exception: + logger = get_logger() + logger.warning( + "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s", + repo_id, + exception, + ) + logger.warning( + "Trying to load the model directly from the local cache, if it exists." + ) + + kwargs["local_files_only"] = True + return huggingface_hub.snapshot_download(repo_id, **kwargs) def format_timestamp( diff --git a/faster_whisper/vad.py b/faster_whisper/vad.py index 080795d..487dfa0 100644 --- a/faster_whisper/vad.py +++ b/faster_whisper/vad.py @@ -3,47 +3,67 @@ import functools import os import warnings -from typing import List, Optional +from typing import List, NamedTuple, Optional import numpy as np from faster_whisper.utils import get_assets_path + # The code below is adapted from https://github.com/snakers4/silero-vad. +class VadOptions(NamedTuple): + """VAD options. - -def get_speech_timestamps( - audio: np.ndarray, - *, - threshold: float = 0.5, - min_speech_duration_ms: int = 250, - max_speech_duration_s: float = float("inf"), - min_silence_duration_ms: int = 2000, - window_size_samples: int = 1024, - speech_pad_ms: int = 200, -) -> List[dict]: - """This method is used for splitting long audios into speech chunks using silero VAD. - - Args: - audio: One dimensional float array. + Attributes: threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH. It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out. max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that - lasts more than 100s (if any), to prevent agressive cutting. Otherwise, they will be + lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be split aggressively just before max_speech_duration_s. min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before separating it window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model. WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate. - Values other than these may affect model perfomance!! + Values other than these may affect model performance!! speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side + """ + + threshold: float = 0.5 + min_speech_duration_ms: int = 250 + max_speech_duration_s: float = float("inf") + min_silence_duration_ms: int = 2000 + window_size_samples: int = 1024 + speech_pad_ms: int = 400 + + +def get_speech_timestamps( + audio: np.ndarray, + vad_options: Optional[VadOptions] = None, + **kwargs, +) -> List[dict]: + """This method is used for splitting long audios into speech chunks using silero VAD. + + Args: + audio: One dimensional float array. + vad_options: Options for VAD processing. + kwargs: VAD options passed as keyword arguments for backward compatibility. Returns: List of dicts containing begin and end samples of each speech chunk. """ + if vad_options is None: + vad_options = VadOptions(**kwargs) + + threshold = vad_options.threshold + min_speech_duration_ms = vad_options.min_speech_duration_ms + max_speech_duration_s = vad_options.max_speech_duration_s + min_silence_duration_ms = vad_options.min_silence_duration_ms + window_size_samples = vad_options.window_size_samples + speech_pad_ms = vad_options.speech_pad_ms + if window_size_samples not in [512, 1024, 1536]: warnings.warn( "Unusual window_size_samples! Supported window_size_samples:\n" diff --git a/faster_whisper/version.py b/faster_whisper/version.py new file mode 100644 index 0000000..bf288f0 --- /dev/null +++ b/faster_whisper/version.py @@ -0,0 +1,3 @@ +"""Version information.""" + +__version__ = "0.6.0" diff --git a/requirements.txt b/requirements.txt index 73c3b6d..4dd8bac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ av==10.* ctranslate2>=3.10,<4 huggingface_hub>=0.13 tokenizers==0.13.* -onnxruntime==1.14.* ; python_version < "3.11" +onnxruntime>=1.14,<2 diff --git a/setup.py b/setup.py index e3245db..1deca3b 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,14 @@ def get_long_description(): return readme_file.read() +def get_project_version(): + version_path = os.path.join(base_dir, "faster_whisper", "version.py") + version = {} + with open(version_path, encoding="utf-8") as fp: + exec(fp.read(), version) + return version["__version__"] + + def get_requirements(path): with open(path, encoding="utf-8") as requirements: return [requirement.strip() for requirement in requirements] @@ -23,7 +31,7 @@ conversion_requires = get_requirements( setup( name="faster-whisper", - version="0.4.1", + version=get_project_version(), license="MIT", description="Faster Whisper transcription with CTranslate2", long_description=get_long_description(), diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 5406535..6ecf2c4 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -6,11 +6,18 @@ from faster_whisper import WhisperModel, decode_audio def test_transcribe(jfk_path): model = WhisperModel("tiny") segments, info = model.transcribe(jfk_path, word_timestamps=True) + assert info.all_language_probs is not None assert info.language == "en" assert info.language_probability > 0.9 assert info.duration == 11 + # Get top language info from all results, which should match the + # already existing metadata + top_lang, top_lang_score = info.all_language_probs[0] + assert info.language == top_lang + assert abs(info.language_probability - top_lang_score) < 1e-16 + segments = list(segments) assert len(segments) == 1 @@ -29,10 +36,10 @@ def test_transcribe(jfk_path): def test_vad(jfk_path): model = WhisperModel("tiny") - segments, _ = model.transcribe( + segments, info = model.transcribe( jfk_path, vad_filter=True, - vad_parameters=dict(min_silence_duration_ms=500), + vad_parameters=dict(min_silence_duration_ms=500, speech_pad_ms=200), ) segments = list(segments) @@ -47,6 +54,9 @@ def test_vad(jfk_path): assert 0 < segment.start < 1 assert 10 < segment.end < 11 + assert info.vad_options.min_silence_duration_ms == 500 + assert info.vad_options.speech_pad_ms == 200 + def test_stereo_diarization(data_dir): model = WhisperModel("tiny") diff --git a/tests/test_utils.py b/tests/test_utils.py index 3e981f6..ee404bf 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,3 +15,9 @@ def test_download_model(tmpdir): for filename in os.listdir(model_dir): path = os.path.join(model_dir, filename) assert not os.path.islink(path) + + +def test_download_model_in_cache(tmpdir): + cache_dir = str(tmpdir.join("model")) + download_model("tiny", cache_dir=cache_dir) + assert os.path.isdir(cache_dir)