Compare commits
10 Commits
36160c1e7e
...
9a646b69e6
| Author | SHA1 | Date | |
|---|---|---|---|
|
9a646b69e6
|
|||
|
49af9564ab
|
|||
|
|
3adcc12d0f | ||
|
|
2b53dee6b6 | ||
|
|
06d24056e9 | ||
|
|
e9a082dcf2 | ||
|
|
051b3350e5 | ||
|
|
746f2698db | ||
|
|
a5d03e55fa | ||
|
|
9fa1989073 |
14
.gitignore
vendored
14
.gitignore
vendored
@@ -1 +1,15 @@
|
|||||||
|
# Byte-compiled / Optimized / DLL Files
|
||||||
*.pyc
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
__pycache__/
|
||||||
|
|
||||||
|
# Distribution / Packaging
|
||||||
|
venv/
|
||||||
|
|
||||||
|
# Unit Test
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Ignore IDE, Editor Files
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
|||||||
@@ -87,6 +87,13 @@ for segment in segments:
|
|||||||
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
|
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Warning:** `segments` is a *generator*, so the transcription only starts when you iterate over it. The transcription can be run to completion by gathering the segments in a list or by iterating over them in a `for` loop:
|
||||||
|
|
||||||
|
```python
|
||||||
|
segments, _ = model.transcribe("audio.mp3")
|
||||||
|
segments = list(segments) # The transcription will actually run here.
|
||||||
|
```
|
||||||
|
|
||||||
#### Word-level timestamps
|
#### Word-level timestamps
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|||||||
@@ -125,19 +125,21 @@ class Tokenizer:
|
|||||||
current_tokens.append(token)
|
current_tokens.append(token)
|
||||||
decoded = self.decode_with_timestamps(current_tokens)
|
decoded = self.decode_with_timestamps(current_tokens)
|
||||||
|
|
||||||
if (
|
try:
|
||||||
replacement_char not in decoded
|
replacement_char_index = decoded.index(replacement_char)
|
||||||
or decoded_full[unicode_offset + decoded.index(replacement_char)]
|
replacement_char_index += unicode_offset
|
||||||
== replacement_char
|
except ValueError:
|
||||||
|
replacement_char_index = None
|
||||||
|
|
||||||
|
if replacement_char_index is None or (
|
||||||
|
replacement_char_index < len(decoded_full)
|
||||||
|
and decoded_full[replacement_char_index] == replacement_char
|
||||||
):
|
):
|
||||||
words.append(decoded)
|
words.append(decoded)
|
||||||
word_tokens.append(current_tokens)
|
word_tokens.append(current_tokens)
|
||||||
current_tokens = []
|
current_tokens = []
|
||||||
unicode_offset += len(decoded)
|
unicode_offset += len(decoded)
|
||||||
|
|
||||||
if unicode_offset >= len(decoded_full):
|
|
||||||
break
|
|
||||||
|
|
||||||
return words, word_tokens
|
return words, word_tokens
|
||||||
|
|
||||||
def split_tokens_on_spaces(
|
def split_tokens_on_spaces(
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import itertools
|
import itertools
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import zlib
|
import zlib
|
||||||
|
|
||||||
@@ -11,7 +12,7 @@ import tokenizers
|
|||||||
from faster_whisper.audio import decode_audio
|
from faster_whisper.audio import decode_audio
|
||||||
from faster_whisper.feature_extractor import FeatureExtractor
|
from faster_whisper.feature_extractor import FeatureExtractor
|
||||||
from faster_whisper.tokenizer import Tokenizer
|
from faster_whisper.tokenizer import Tokenizer
|
||||||
from faster_whisper.utils import download_model
|
from faster_whisper.utils import download_model, format_timestamp, get_logger
|
||||||
from faster_whisper.vad import (
|
from faster_whisper.vad import (
|
||||||
SpeechTimestampsMap,
|
SpeechTimestampsMap,
|
||||||
collect_chunks,
|
collect_chunks,
|
||||||
@@ -71,6 +72,7 @@ class WhisperModel:
|
|||||||
compute_type: str = "default",
|
compute_type: str = "default",
|
||||||
cpu_threads: int = 0,
|
cpu_threads: int = 0,
|
||||||
num_workers: int = 1,
|
num_workers: int = 1,
|
||||||
|
download_root: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Initializes the Whisper model.
|
"""Initializes the Whisper model.
|
||||||
|
|
||||||
@@ -92,11 +94,15 @@ class WhisperModel:
|
|||||||
having multiple workers enables true parallelism when running the model
|
having multiple workers enables true parallelism when running the model
|
||||||
(concurrent calls to self.model.generate() will run in parallel).
|
(concurrent calls to self.model.generate() will run in parallel).
|
||||||
This can improve the global throughput at the cost of increased memory usage.
|
This can improve the global throughput at the cost of increased memory usage.
|
||||||
|
download_root: Directory where the model should be saved. If not set, the model
|
||||||
|
is saved in the standard Hugging Face cache directory.
|
||||||
"""
|
"""
|
||||||
|
self.logger = get_logger()
|
||||||
|
|
||||||
if os.path.isdir(model_size_or_path):
|
if os.path.isdir(model_size_or_path):
|
||||||
model_path = model_size_or_path
|
model_path = model_size_or_path
|
||||||
else:
|
else:
|
||||||
model_path = download_model(model_size_or_path)
|
model_path = download_model(model_size_or_path, download_root)
|
||||||
|
|
||||||
self.model = ctranslate2.models.Whisper(
|
self.model = ctranslate2.models.Whisper(
|
||||||
model_path,
|
model_path,
|
||||||
@@ -211,17 +217,40 @@ class WhisperModel:
|
|||||||
- a generator over transcribed segments
|
- a generator over transcribed segments
|
||||||
- an instance of AudioInfo
|
- an instance of AudioInfo
|
||||||
"""
|
"""
|
||||||
if not isinstance(audio, np.ndarray):
|
sampling_rate = self.feature_extractor.sampling_rate
|
||||||
audio = decode_audio(
|
|
||||||
audio, sampling_rate=self.feature_extractor.sampling_rate
|
|
||||||
)
|
|
||||||
|
|
||||||
duration = audio.shape[0] / self.feature_extractor.sampling_rate
|
if not isinstance(audio, np.ndarray):
|
||||||
|
audio = decode_audio(audio, sampling_rate=sampling_rate)
|
||||||
|
|
||||||
|
duration = audio.shape[0] / sampling_rate
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"Processing audio with duration %s", format_timestamp(duration)
|
||||||
|
)
|
||||||
|
|
||||||
if vad_filter:
|
if vad_filter:
|
||||||
vad_parameters = {} if vad_parameters is None else vad_parameters
|
vad_parameters = {} if vad_parameters is None else vad_parameters
|
||||||
speech_chunks = get_speech_timestamps(audio, **vad_parameters)
|
speech_chunks = get_speech_timestamps(audio, **vad_parameters)
|
||||||
audio = collect_chunks(audio, speech_chunks)
|
audio = collect_chunks(audio, speech_chunks)
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"VAD filter removed %s of audio",
|
||||||
|
format_timestamp(duration - (audio.shape[0] / sampling_rate)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.logger.isEnabledFor(logging.DEBUG):
|
||||||
|
self.logger.debug(
|
||||||
|
"VAD filter kept the following audio segments: %s",
|
||||||
|
", ".join(
|
||||||
|
"[%s -> %s]"
|
||||||
|
% (
|
||||||
|
format_timestamp(chunk["start"] / sampling_rate),
|
||||||
|
format_timestamp(chunk["end"] / sampling_rate),
|
||||||
|
)
|
||||||
|
for chunk in speech_chunks
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
speech_chunks = None
|
speech_chunks = None
|
||||||
|
|
||||||
@@ -239,6 +268,12 @@ class WhisperModel:
|
|||||||
results = self.model.detect_language(encoder_output)
|
results = self.model.detect_language(encoder_output)
|
||||||
language_token, language_probability = results[0][0]
|
language_token, language_probability = results[0][0]
|
||||||
language = language_token[2:-2]
|
language = language_token[2:-2]
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"Detected language '%s' with probability %.2f",
|
||||||
|
language,
|
||||||
|
language_probability,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
language_probability = 1
|
language_probability = 1
|
||||||
|
|
||||||
@@ -275,9 +310,7 @@ class WhisperModel:
|
|||||||
segments = self.generate_segments(features, tokenizer, options, encoder_output)
|
segments = self.generate_segments(features, tokenizer, options, encoder_output)
|
||||||
|
|
||||||
if speech_chunks:
|
if speech_chunks:
|
||||||
segments = restore_speech_timestamps(
|
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
|
||||||
segments, speech_chunks, self.feature_extractor.sampling_rate
|
|
||||||
)
|
|
||||||
|
|
||||||
audio_info = AudioInfo(
|
audio_info = AudioInfo(
|
||||||
language=language,
|
language=language,
|
||||||
@@ -297,6 +330,7 @@ class WhisperModel:
|
|||||||
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
||||||
seek = 0
|
seek = 0
|
||||||
all_tokens = []
|
all_tokens = []
|
||||||
|
all_prompt_text = []
|
||||||
prompt_reset_since = 0
|
prompt_reset_since = 0
|
||||||
|
|
||||||
if options.initial_prompt is not None:
|
if options.initial_prompt is not None:
|
||||||
@@ -312,6 +346,11 @@ class WhisperModel:
|
|||||||
)
|
)
|
||||||
segment_duration = segment_size * self.feature_extractor.time_per_frame
|
segment_duration = segment_size * self.feature_extractor.time_per_frame
|
||||||
|
|
||||||
|
if self.logger.isEnabledFor(logging.DEBUG):
|
||||||
|
self.logger.debug(
|
||||||
|
"Processing segment at %s", format_timestamp(time_offset)
|
||||||
|
)
|
||||||
|
|
||||||
previous_tokens = all_tokens[prompt_reset_since:]
|
previous_tokens = all_tokens[prompt_reset_since:]
|
||||||
prompt = self.get_prompt(
|
prompt = self.get_prompt(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -339,6 +378,12 @@ class WhisperModel:
|
|||||||
should_skip = False
|
should_skip = False
|
||||||
|
|
||||||
if should_skip:
|
if should_skip:
|
||||||
|
self.logger.debug(
|
||||||
|
"No speech threshold is met (%f > %f)",
|
||||||
|
result.no_speech_prob,
|
||||||
|
options.no_speech_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
# fast-forward to the next segment boundary
|
# fast-forward to the next segment boundary
|
||||||
seek += segment_size
|
seek += segment_size
|
||||||
continue
|
continue
|
||||||
@@ -457,7 +502,15 @@ class WhisperModel:
|
|||||||
if segment["start"] == segment["end"] or not text.strip():
|
if segment["start"] == segment["end"] or not text.strip():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
all_tokens.extend(tokens)
|
check_prompt_num = 1
|
||||||
|
if all(
|
||||||
|
[
|
||||||
|
text.strip() != i.strip()
|
||||||
|
for i in all_prompt_text[-check_prompt_num:]
|
||||||
|
]
|
||||||
|
):
|
||||||
|
all_tokens.extend(tokens)
|
||||||
|
all_prompt_text.append(text)
|
||||||
|
|
||||||
yield Segment(
|
yield Segment(
|
||||||
start=segment["start"],
|
start=segment["start"],
|
||||||
@@ -543,12 +596,26 @@ class WhisperModel:
|
|||||||
):
|
):
|
||||||
needs_fallback = True # too repetitive
|
needs_fallback = True # too repetitive
|
||||||
|
|
||||||
|
self.logger.debug(
|
||||||
|
"Compression ratio threshold is not met with temperature %.1f (%f > %f)",
|
||||||
|
temperature,
|
||||||
|
compression_ratio,
|
||||||
|
options.compression_ratio_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
options.log_prob_threshold is not None
|
options.log_prob_threshold is not None
|
||||||
and avg_log_prob < options.log_prob_threshold
|
and avg_log_prob < options.log_prob_threshold
|
||||||
):
|
):
|
||||||
needs_fallback = True # average log probability is too low
|
needs_fallback = True # average log probability is too low
|
||||||
|
|
||||||
|
self.logger.debug(
|
||||||
|
"Log probability threshold is not met with temperature %.1f (%f < %f)",
|
||||||
|
temperature,
|
||||||
|
avg_log_prob,
|
||||||
|
options.log_prob_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
if not needs_fallback:
|
if not needs_fallback:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -721,14 +788,18 @@ def restore_speech_timestamps(
|
|||||||
end=ts_map.get_original_time(word.end, chunk_index),
|
end=ts_map.get_original_time(word.end, chunk_index),
|
||||||
)
|
)
|
||||||
words.append(word)
|
words.append(word)
|
||||||
else:
|
|
||||||
words = segment.words
|
|
||||||
|
|
||||||
segment = segment._replace(
|
segment = segment._replace(
|
||||||
start=ts_map.get_original_time(segment.start),
|
start=words[0].start,
|
||||||
end=ts_map.get_original_time(segment.end),
|
end=words[-1].end,
|
||||||
words=words,
|
words=words,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
segment = segment._replace(
|
||||||
|
start=ts_map.get_original_time(segment.start),
|
||||||
|
end=ts_map.get_original_time(segment.end),
|
||||||
|
)
|
||||||
|
|
||||||
yield segment
|
yield segment
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -25,6 +26,11 @@ def get_assets_path():
|
|||||||
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
|
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger():
|
||||||
|
"""Returns the module logger."""
|
||||||
|
return logging.getLogger("faster_whisper")
|
||||||
|
|
||||||
|
|
||||||
def download_model(size: str, output_dir: Optional[str] = None):
|
def download_model(size: str, output_dir: Optional[str] = None):
|
||||||
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
|
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
|
||||||
|
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -23,7 +23,7 @@ conversion_requires = get_requirements(
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="faster-whisper",
|
name="faster-whisper",
|
||||||
version="0.4.0",
|
version="0.4.1",
|
||||||
license="MIT",
|
license="MIT",
|
||||||
description="Faster Whisper transcription with CTranslate2",
|
description="Faster Whisper transcription with CTranslate2",
|
||||||
long_description=get_long_description(),
|
long_description=get_long_description(),
|
||||||
|
|||||||
Reference in New Issue
Block a user