Compare commits

...

10 Commits

Author SHA1 Message Date
9a646b69e6 format code 2023-04-20 02:00:57 +08:00
49af9564ab Ignore repeated prompt 2023-04-20 01:49:10 +08:00
Guillaume Klein
3adcc12d0f Clarify that the returned segments value is a generator (#144)
* Clarify that the returned segments value is a generator

* Update README.md
2023-04-13 09:50:53 +02:00
Ewald Enzinger
2b53dee6b6 Expose download location in WhisperModel constructor (#126)
This increases compatibility with OpenAI Whisper's whisper.load_model() and is useful for downstream integrations
2023-04-08 10:02:36 +02:00
Bekir Bakar
06d24056e9 Configure ignore for more files. (#122) 2023-04-06 19:13:09 +02:00
Guillaume Klein
e9a082dcf2 Keep segment timestamps aligned with words timestamps after VAD (#119) 2023-04-06 11:54:40 +02:00
Guillaume Klein
051b3350e5 Add some info and debug logs (#113) 2023-04-05 16:57:59 +02:00
Guillaume Klein
746f2698db Bump version to 0.4.1 2023-04-04 12:16:23 +02:00
Guillaume Klein
a5d03e55fa Prevent out of range error in method split_tokens_on_unicode (#111) 2023-04-04 10:51:14 +02:00
Guillaume Klein
9fa1989073 Revert "Prevent out of range error in method split_tokens_on_unicode"
This reverts commit 36160c1e7e.
2023-04-04 10:25:41 +02:00
6 changed files with 126 additions and 26 deletions

14
.gitignore vendored
View File

@@ -1 +1,15 @@
# Byte-compiled / Optimized / DLL Files
*.pyc
*.pyo
*.pyd
__pycache__/
# Distribution / Packaging
venv/
# Unit Test
.pytest_cache/
# Ignore IDE, Editor Files
.idea/
.vscode/

View File

@@ -87,6 +87,13 @@ for segment in segments:
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
```
**Warning:** `segments` is a *generator* so the transcription only starts when you iterate over it. The transcription can be run to completion by gathering the segments in a list or a `for` loop:
```python
segments, _ = model.transcribe("audio.mp3")
segments = list(segments) # The transcription will actually run here.
```
#### Word-level timestamps
```python

View File

@@ -125,19 +125,21 @@ class Tokenizer:
current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens)
if (
replacement_char not in decoded
or decoded_full[unicode_offset + decoded.index(replacement_char)]
== replacement_char
try:
replacement_char_index = decoded.index(replacement_char)
replacement_char_index += unicode_offset
except ValueError:
replacement_char_index = None
if replacement_char_index is None or (
replacement_char_index < len(decoded_full)
and decoded_full[replacement_char_index] == replacement_char
):
words.append(decoded)
word_tokens.append(current_tokens)
current_tokens = []
unicode_offset += len(decoded)
if unicode_offset >= len(decoded_full):
break
return words, word_tokens
def split_tokens_on_spaces(

View File

@@ -1,4 +1,5 @@
import itertools
import logging
import os
import zlib
@@ -11,7 +12,7 @@ import tokenizers
from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.utils import download_model
from faster_whisper.utils import download_model, format_timestamp, get_logger
from faster_whisper.vad import (
SpeechTimestampsMap,
collect_chunks,
@@ -71,6 +72,7 @@ class WhisperModel:
compute_type: str = "default",
cpu_threads: int = 0,
num_workers: int = 1,
download_root: Optional[str] = None,
):
"""Initializes the Whisper model.
@@ -92,11 +94,15 @@ class WhisperModel:
having multiple workers enables true parallelism when running the model
(concurrent calls to self.model.generate() will run in parallel).
This can improve the global throughput at the cost of increased memory usage.
download_root: Directory where the model should be saved. If not set, the model
is saved in the standard Hugging Face cache directory.
"""
self.logger = get_logger()
if os.path.isdir(model_size_or_path):
model_path = model_size_or_path
else:
model_path = download_model(model_size_or_path)
model_path = download_model(model_size_or_path, download_root)
self.model = ctranslate2.models.Whisper(
model_path,
@@ -211,17 +217,40 @@ class WhisperModel:
- a generator over transcribed segments
- an instance of AudioInfo
"""
if not isinstance(audio, np.ndarray):
audio = decode_audio(
audio, sampling_rate=self.feature_extractor.sampling_rate
)
sampling_rate = self.feature_extractor.sampling_rate
duration = audio.shape[0] / self.feature_extractor.sampling_rate
if not isinstance(audio, np.ndarray):
audio = decode_audio(audio, sampling_rate=sampling_rate)
duration = audio.shape[0] / sampling_rate
self.logger.info(
"Processing audio with duration %s", format_timestamp(duration)
)
if vad_filter:
vad_parameters = {} if vad_parameters is None else vad_parameters
speech_chunks = get_speech_timestamps(audio, **vad_parameters)
audio = collect_chunks(audio, speech_chunks)
self.logger.info(
"VAD filter removed %s of audio",
format_timestamp(duration - (audio.shape[0] / sampling_rate)),
)
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
"VAD filter kept the following audio segments: %s",
", ".join(
"[%s -> %s]"
% (
format_timestamp(chunk["start"] / sampling_rate),
format_timestamp(chunk["end"] / sampling_rate),
)
for chunk in speech_chunks
),
)
else:
speech_chunks = None
@@ -239,6 +268,12 @@ class WhisperModel:
results = self.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
self.logger.info(
"Detected language '%s' with probability %.2f",
language,
language_probability,
)
else:
language_probability = 1
@@ -275,9 +310,7 @@ class WhisperModel:
segments = self.generate_segments(features, tokenizer, options, encoder_output)
if speech_chunks:
segments = restore_speech_timestamps(
segments, speech_chunks, self.feature_extractor.sampling_rate
)
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
audio_info = AudioInfo(
language=language,
@@ -297,6 +330,7 @@ class WhisperModel:
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
seek = 0
all_tokens = []
all_prompt_text = []
prompt_reset_since = 0
if options.initial_prompt is not None:
@@ -312,6 +346,11 @@ class WhisperModel:
)
segment_duration = segment_size * self.feature_extractor.time_per_frame
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
"Processing segment at %s", format_timestamp(time_offset)
)
previous_tokens = all_tokens[prompt_reset_since:]
prompt = self.get_prompt(
tokenizer,
@@ -339,6 +378,12 @@ class WhisperModel:
should_skip = False
if should_skip:
self.logger.debug(
"No speech threshold is met (%f > %f)",
result.no_speech_prob,
options.no_speech_threshold,
)
# fast-forward to the next segment boundary
seek += segment_size
continue
@@ -457,7 +502,15 @@ class WhisperModel:
if segment["start"] == segment["end"] or not text.strip():
continue
all_tokens.extend(tokens)
check_prompt_num = 1
if all(
[
text.strip() != i.strip()
for i in all_prompt_text[-check_prompt_num:]
]
):
all_tokens.extend(tokens)
all_prompt_text.append(text)
yield Segment(
start=segment["start"],
@@ -543,12 +596,26 @@ class WhisperModel:
):
needs_fallback = True # too repetitive
self.logger.debug(
"Compression ratio threshold is not met with temperature %.1f (%f > %f)",
temperature,
compression_ratio,
options.compression_ratio_threshold,
)
if (
options.log_prob_threshold is not None
and avg_log_prob < options.log_prob_threshold
):
needs_fallback = True # average log probability is too low
self.logger.debug(
"Log probability threshold is not met with temperature %.1f (%f < %f)",
temperature,
avg_log_prob,
options.log_prob_threshold,
)
if not needs_fallback:
break
@@ -721,14 +788,18 @@ def restore_speech_timestamps(
end=ts_map.get_original_time(word.end, chunk_index),
)
words.append(word)
else:
words = segment.words
segment = segment._replace(
start=ts_map.get_original_time(segment.start),
end=ts_map.get_original_time(segment.end),
words=words,
)
segment = segment._replace(
start=words[0].start,
end=words[-1].end,
words=words,
)
else:
segment = segment._replace(
start=ts_map.get_original_time(segment.start),
end=ts_map.get_original_time(segment.end),
)
yield segment

View File

@@ -1,3 +1,4 @@
import logging
import os
from typing import Optional
@@ -25,6 +26,11 @@ def get_assets_path():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
def get_logger():
"""Returns the module logger."""
return logging.getLogger("faster_whisper")
def download_model(size: str, output_dir: Optional[str] = None):
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.

View File

@@ -23,7 +23,7 @@ conversion_requires = get_requirements(
setup(
name="faster-whisper",
version="0.4.0",
version="0.4.1",
license="MIT",
description="Faster Whisper transcription with CTranslate2",
long_description=get_long_description(),