Merge branch 'master' into prompt
@@ -15,6 +15,7 @@ from faster_whisper.tokenizer import Tokenizer
 from faster_whisper.utils import download_model, format_timestamp, get_logger
 from faster_whisper.vad import (
     SpeechTimestampsMap,
+    VadOptions,
     collect_chunks,
     get_speech_timestamps,
 )
@@ -28,18 +29,17 @@ class Word(NamedTuple):
 
 
 class Segment(NamedTuple):
     id: int
     seek: int
     start: float
     end: float
     text: str
-    words: Optional[List[Word]]
-    avg_log_prob: float
     tokens: List[int]
+    temperature: float
+    avg_logprob: float
+    compression_ratio: float
     no_speech_prob: float
-
-
-class AudioInfo(NamedTuple):
-    language: str
-    language_probability: float
-    duration: float
+    words: Optional[List[Word]]
 
 
 class TranscriptionOptions(NamedTuple):
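
Usage sketch (not part of this commit): the merged Segment tuple carries the same per-segment quality signals as openai/whisper, which downstream code can use to filter dubious output. The model name, audio path, and thresholds below are illustrative; the thresholds mirror transcribe()'s defaults.

    from faster_whisper import WhisperModel

    model = WhisperModel("small")
    segments, info = model.transcribe("audio.wav")

    for segment in segments:
        # Heuristics matching the default thresholds: low average log
        # probability, highly compressible (repetitive) text, or likely silence.
        suspect = (
            segment.avg_logprob < -1.0
            or segment.compression_ratio > 2.4
            or segment.no_speech_prob > 0.6
        )
        print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text} "
              f"({'suspect' if suspect else 'ok'})")
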
@@ -52,7 +52,7 @@ class TranscriptionOptions(NamedTuple):
     compression_ratio_threshold: Optional[float]
     condition_on_previous_text: bool
     temperatures: List[float]
-    initial_prompt: Optional[str]
+    initial_prompt: Optional[Union[str, Iterable[int]]]
     prefix: Optional[str]
     suppress_blank: bool
     suppress_tokens: Optional[List[int]]
@@ -63,6 +63,15 @@ class TranscriptionOptions(NamedTuple):
     append_punctuations: str
 
 
+class TranscriptionInfo(NamedTuple):
+    language: str
+    language_probability: float
+    duration: float
+    all_language_probs: Optional[List[Tuple[str, float]]]
+    transcription_options: TranscriptionOptions
+    vad_options: VadOptions
+
+
 class WhisperModel:
     def __init__(
         self,
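
TranscriptionInfo replaces AudioInfo and additionally echoes back the options the run actually used, which helps with logging and reproducibility. A minimal sketch (model and path are placeholders):

    from faster_whisper import WhisperModel

    model = WhisperModel("small")
    segments, info = model.transcribe("audio.wav")

    print(info.language, info.language_probability, info.duration)
    print(info.transcription_options.initial_prompt)
    print(info.vad_options)  # stays None unless vad_filter=True was passed
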
@@ -73,6 +82,7 @@ class WhisperModel:
         cpu_threads: int = 0,
         num_workers: int = 1,
         download_root: Optional[str] = None,
+        local_files_only: bool = False,
     ):
         """Initializes the Whisper model.
 
@@ -94,15 +104,21 @@ class WhisperModel:
             having multiple workers enables true parallelism when running the model
             (concurrent calls to self.model.generate() will run in parallel).
             This can improve the global throughput at the cost of increased memory usage.
-          download_root: Directory where the model should be saved. If not set, the model
-            is saved in the standard Hugging Face cache directory.
+          download_root: Directory where the models should be saved. If not set, the models
+            are saved in the standard Hugging Face cache directory.
+          local_files_only: If True, avoid downloading the file and return the path to the
+            local cached file if it exists.
         """
         self.logger = get_logger()
 
         if os.path.isdir(model_size_or_path):
             model_path = model_size_or_path
         else:
-            model_path = download_model(model_size_or_path, download_root)
+            model_path = download_model(
+                model_size_or_path,
+                local_files_only=local_files_only,
+                cache_dir=download_root,
+            )
 
         self.model = ctranslate2.models.Whisper(
             model_path,
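
With local_files_only forwarded to download_model(), a previously cached model can be loaded without contacting the Hub. A sketch under the assumption that the cache directory below is writable (paths are placeholders):

    from faster_whisper import WhisperModel

    # First run, online: downloads into the chosen cache directory.
    model = WhisperModel("small", download_root="/models/whisper")

    # Later runs, offline: resolve from the cache only; if the files were never
    # downloaded, this should fail rather than attempt a download.
    model = WhisperModel(
        "small",
        download_root="/models/whisper",
        local_files_only=True,
    )
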
@@ -154,7 +170,7 @@ class WhisperModel:
         log_prob_threshold: Optional[float] = -1.0,
         no_speech_threshold: Optional[float] = 0.6,
         condition_on_previous_text: bool = True,
-        initial_prompt: Optional[str] = None,
+        initial_prompt: Optional[Union[str, Iterable[int]]] = None,
         prefix: Optional[str] = None,
         suppress_blank: bool = True,
         suppress_tokens: Optional[List[int]] = [-1],
@@ -164,8 +180,8 @@ class WhisperModel:
         prepend_punctuations: str = "\"'“¿([{-",
         append_punctuations: str = "\"'.。,,!!??::”)]}、",
         vad_filter: bool = False,
-        vad_parameters: Optional[dict] = None,
-    ) -> Tuple[Iterable[Segment], AudioInfo]:
+        vad_parameters: Optional[Union[dict, VadOptions]] = None,
+    ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """Transcribes an input file.
 
         Arguments:
@@ -192,7 +208,8 @@ class WhisperModel:
             as a prompt for the next window; disabling may make the text inconsistent across
             windows, but the model becomes less prone to getting stuck in a failure loop,
             such as repetition looping or timestamps going out of sync.
-          initial_prompt: Optional text to provide as a prompt for the first window.
+          initial_prompt: Optional text string or iterable of token ids to provide as a
+            prompt for the first window.
           prefix: Optional text to provide as a prefix for the first window.
           suppress_blank: Suppress blank outputs at the beginning of the sampling.
           suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
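
With the widened initial_prompt type, callers can pass pre-encoded token ids instead of text; per the change in generate_segments below, ids are extended onto the prompt history unchanged. A sketch; the token ids are purely illustrative and must come from the Whisper vocabulary of the loaded model:

    from faster_whisper import WhisperModel

    model = WhisperModel("small")

    # String form: stripped, prefixed with a space, and tokenized internally.
    segments, info = model.transcribe(
        "audio.wav", initial_prompt="Glossary: CTranslate2, faster-whisper."
    )

    # Token-id form: passed through as-is (hypothetical ids).
    segments, info = model.transcribe("audio.wav", initial_prompt=[1234, 5678])
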
@@ -208,14 +225,14 @@ class WhisperModel:
           vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
             without speech. This step is using the Silero VAD model
             https://github.com/snakers4/silero-vad.
-          vad_parameters: Dictionary of Silero VAD parameters (see available parameters and
-            default values in the function `get_speech_timestamps`).
+          vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
+            parameters and default values in the class `VadOptions`).
 
         Returns:
           A tuple with:
 
             - a generator over transcribed segments
-            - an instance of AudioInfo
+            - an instance of TranscriptionInfo
         """
         sampling_rate = self.feature_extractor.sampling_rate
 
@@ -229,8 +246,11 @@ class WhisperModel:
         )
 
         if vad_filter:
-            vad_parameters = {} if vad_parameters is None else vad_parameters
-            speech_chunks = get_speech_timestamps(audio, **vad_parameters)
+            if vad_parameters is None:
+                vad_parameters = VadOptions()
+            elif isinstance(vad_parameters, dict):
+                vad_parameters = VadOptions(**vad_parameters)
+            speech_chunks = get_speech_timestamps(audio, vad_parameters)
             audio = collect_chunks(audio, speech_chunks)
 
         self.logger.info(
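
Both accepted forms of vad_parameters are normalized to a VadOptions instance before get_speech_timestamps is called, so the two calls below are equivalent (field name as defined in faster_whisper.vad; model and path are placeholders):

    from faster_whisper import WhisperModel
    from faster_whisper.vad import VadOptions

    model = WhisperModel("small")

    # Dict form, converted internally via VadOptions(**vad_parameters).
    segments, info = model.transcribe(
        "audio.wav",
        vad_filter=True,
        vad_parameters={"min_silence_duration_ms": 500},
    )

    # Explicit form.
    segments, info = model.transcribe(
        "audio.wav",
        vad_filter=True,
        vad_parameters=VadOptions(min_silence_duration_ms=500),
    )
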
@@ -257,6 +277,7 @@ class WhisperModel:
         features = self.feature_extractor(audio)
 
         encoder_output = None
+        all_language_probs = None
 
         if language is None:
             if not self.model.is_multilingual:
@@ -265,9 +286,13 @@ class WhisperModel:
             else:
                 segment = features[:, : self.feature_extractor.nb_max_frames]
                 encoder_output = self.encode(segment)
-                results = self.model.detect_language(encoder_output)
-                language_token, language_probability = results[0][0]
-                language = language_token[2:-2]
+                # results is a list of tuple[str, float] with language names and
+                # probabilities.
+                results = self.model.detect_language(encoder_output)[0]
+                # Parse language names to strip out markers
+                all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
+                # Get top language token and probability
+                language, language_probability = all_language_probs[0]
 
             self.logger.info(
                 "Detected language '%s' with probability %.2f",
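
Since the detect_language results are now kept in full, the whole probability distribution is available on the returned info whenever the language was auto-detected. Sketch, reusing a model as above:

    segments, info = model.transcribe("audio.wav")  # no language= argument

    print(f"detected {info.language} (p={info.language_probability:.2f})")
    if info.all_language_probs is not None:  # None if the language was forced
        for lang, prob in info.all_language_probs[:5]:  # sorted, best first
            print(f"  {lang}: {prob:.3f}")
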
@@ -312,13 +337,16 @@ class WhisperModel:
         if speech_chunks:
             segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
 
-        audio_info = AudioInfo(
+        info = TranscriptionInfo(
             language=language,
             language_probability=language_probability,
             duration=duration,
+            transcription_options=options,
+            vad_options=vad_parameters,
+            all_language_probs=all_language_probs,
         )
 
-        return segments, audio_info
+        return segments, info
 
     def generate_segments(
         self,
@@ -328,15 +356,19 @@ class WhisperModel:
         encoder_output: Optional[ctranslate2.StorageView] = None,
     ) -> Iterable[Segment]:
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
+        idx = 0
         seek = 0
         all_tokens = []
         all_prompt_text = []
         prompt_reset_since = 0
 
         if options.initial_prompt is not None:
-            initial_prompt = " " + options.initial_prompt.strip()
-            initial_prompt_tokens = tokenizer.encode(initial_prompt)
-            all_tokens.extend(initial_prompt_tokens)
+            if isinstance(options.initial_prompt, str):
+                initial_prompt = " " + options.initial_prompt.strip()
+                initial_prompt_tokens = tokenizer.encode(initial_prompt)
+                all_tokens.extend(initial_prompt_tokens)
+            else:
+                all_tokens.extend(options.initial_prompt)
 
         while seek < content_frames:
             time_offset = seek * self.feature_extractor.time_per_frame
@@ -362,9 +394,12 @@ class WhisperModel:
             if encoder_output is None:
                 encoder_output = self.encode(segment)
 
-            result, avg_log_prob, temperature = self.generate_with_fallback(
-                encoder_output, prompt, tokenizer, options
-            )
+            (
+                result,
+                avg_logprob,
+                temperature,
+                compression_ratio,
+            ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
 
             if options.no_speech_threshold is not None:
                 # no voice activity check
@@ -372,7 +407,7 @@ class WhisperModel:
 
                 if (
                     options.log_prob_threshold is not None
-                    and avg_log_prob > options.log_prob_threshold
+                    and avg_logprob > options.log_prob_threshold
                 ):
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     should_skip = False
@@ -468,9 +503,6 @@ class WhisperModel:
 
             seek += segment_size
 
-            if not options.condition_on_previous_text or temperature > 0.5:
-                prompt_reset_since = len(all_tokens)
-
             if options.word_timestamps:
                 self.add_word_timestamps(
                     current_segments,
@@ -511,20 +543,29 @@ class WhisperModel:
             ):
                 all_tokens.extend(tokens)
                 all_prompt_text.append(text)
                 idx += 1
 
                 yield Segment(
                     id=idx,
                     seek=seek,
                     start=segment["start"],
                     end=segment["end"],
                     text=text,
                     tokens=tokens,
-                    avg_log_prob=avg_log_prob,
-                    no_speech_prob=result.no_speech_prob,
+                    temperature=temperature,
+                    avg_logprob=avg_logprob,
+                    compression_ratio=compression_ratio,
+                    no_speech_prob=result.no_speech_prob,
+                    words=(
+                        [Word(**word) for word in segment["words"]]
+                        if options.word_timestamps
+                        else None
+                    ),
                 )
 
+            if not options.condition_on_previous_text or temperature > 0.5:
+                prompt_reset_since = len(all_tokens)
+
     def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
         # When the model is running on multiple GPUs, the encoder output should be moved
         # to the CPU since we don't know which GPU will handle the next job.
@@ -541,10 +582,11 @@ class WhisperModel:
         prompt: List[int],
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
-    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
+    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
         result = None
-        avg_log_prob = None
+        avg_logprob = None
         final_temperature = None
+        compression_ratio = None
 
         max_initial_timestamp_index = int(
             round(options.max_initial_timestamp / self.time_precision)
@@ -582,8 +624,8 @@ class WhisperModel:
 
             # Recover the average log prob from the returned score.
             seq_len = len(tokens)
-            cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
-            avg_log_prob = cum_log_prob / (seq_len + 1)
+            cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
+            avg_logprob = cum_logprob / (seq_len + 1)
 
             text = tokenizer.decode(tokens).strip()
             compression_ratio = get_compression_ratio(text)
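
CTranslate2 reports a score that is already normalized by seq_len**length_penalty; the two renamed lines undo that normalization to recover the total log probability, then divide by seq_len + 1 so the average also accounts for the end-of-text token, matching openai/whisper. A worked check with made-up numbers:

    seq_len = 9
    length_penalty = 1.0
    score = -0.5  # normalized score as found in result.scores[0] (hypothetical)

    cum_logprob = score * (seq_len ** length_penalty)  # -4.5: total log probability
    avg_logprob = cum_logprob / (seq_len + 1)          # -0.45: per token, incl. EOT
    assert abs(avg_logprob + 0.45) < 1e-12
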
@@ -605,21 +647,21 @@ class WhisperModel:
 
             if (
                 options.log_prob_threshold is not None
-                and avg_log_prob < options.log_prob_threshold
+                and avg_logprob < options.log_prob_threshold
             ):
                 needs_fallback = True  # average log probability is too low
 
                 self.logger.debug(
                     "Log probability threshold is not met with temperature %.1f (%f < %f)",
                     temperature,
-                    avg_log_prob,
+                    avg_logprob,
                     options.log_prob_threshold,
                 )
 
             if not needs_fallback:
                 break
 
-        return result, avg_log_prob, final_temperature
+        return result, avg_logprob, final_temperature, compression_ratio
 
     def get_prompt(
         self,
@@ -734,6 +776,8 @@ class WhisperModel:
             text_tokens + [tokenizer.eot]
         )
         word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
+        if len(word_boundaries) <= 1:
+            return []
 
         jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
         jump_times = time_indices[jumps] / self.tokens_per_second
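
The new guard covers the degenerate case where no word boundaries exist, e.g. when the word split yields only the trailing EOT group. The boundary expression can be checked standalone (the token id is a placeholder):

    import numpy as np

    word_tokens = [[50257]]  # only one token group (hypothetical eot id)
    # Cumulative token counts over all groups but the last, with a leading 0.
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
    print(word_boundaries)            # [0.]
    print(len(word_boundaries) <= 1)  # True -> the method now returns [] early
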
@@ -782,7 +826,8 @@ def restore_speech_timestamps(
         words = []
         for word in segment.words:
             # Ensure the word start and end times are resolved to the same chunk.
-            chunk_index = ts_map.get_chunk_index(word.start)
+            middle = (word.start + word.end) / 2
+            chunk_index = ts_map.get_chunk_index(middle)
             word = word._replace(
                 start=ts_map.get_original_time(word.start, chunk_index),
                 end=ts_map.get_original_time(word.end, chunk_index),
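
Resolving the chunk from the word midpoint keeps both timestamps in the same chunk even when a word straddles a chunk boundary, so start and end are shifted back by one consistent offset. A self-contained illustration of the indexing idea (the chunk layout is invented and the lookup is a simplified stand-in for SpeechTimestampsMap):

    # End times, in seconds, of two chunks of VAD-filtered audio.
    chunk_ends = [10.0, 20.0]

    def get_chunk_index(t: float) -> int:
        # Return the first chunk whose end lies beyond t.
        for i, end in enumerate(chunk_ends):
            if t < end:
                return i
        return len(chunk_ends) - 1

    word_start, word_end = 9.95, 10.05    # word straddling the first boundary
    middle = (word_start + word_end) / 2  # 10.0
    chunk_index = get_chunk_index(middle)  # 1, used for BOTH start and end
    # Before the fix the index came from word_start alone (chunk 0) while the
    # end lay in chunk 1, so the two times could be offset by different amounts.
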