Merge branch 'master' into prompt

2023-06-24 18:03:05 +08:00
12 changed files with 266 additions and 95 deletions

faster_whisper/transcribe.py

@@ -15,6 +15,7 @@ from faster_whisper.tokenizer import Tokenizer
from faster_whisper.utils import download_model, format_timestamp, get_logger
from faster_whisper.vad import (
SpeechTimestampsMap,
+ VadOptions,
collect_chunks,
get_speech_timestamps,
)
@@ -28,18 +29,17 @@ class Word(NamedTuple):
class Segment(NamedTuple):
+ id: int
+ seek: int
start: float
end: float
text: str
- words: Optional[List[Word]]
- avg_log_prob: float
+ tokens: List[int]
+ temperature: float
+ avg_logprob: float
+ compression_ratio: float
no_speech_prob: float
- class AudioInfo(NamedTuple):
- language: str
- language_probability: float
- duration: float
+ words: Optional[List[Word]]
class TranscriptionOptions(NamedTuple):
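
Note: the reshaped Segment above now carries the raw tokens, the sampling temperature, and the avg_logprob and compression_ratio quality signals next to no_speech_prob. A minimal consumer-side sketch, given the generator returned by transcribe() (the thresholds are illustrative, not values taken from this diff):

    for segment in segments:
        # Flag segments whose decoding-quality signals look suspicious.
        if segment.avg_logprob < -1.0 or segment.compression_ratio > 2.4:
            print(f"[check] {segment.start:.2f}s-{segment.end:.2f}s: {segment.text}")
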
@@ -52,7 +52,7 @@ class TranscriptionOptions(NamedTuple):
compression_ratio_threshold: Optional[float]
condition_on_previous_text: bool
temperatures: List[float]
- initial_prompt: Optional[str]
+ initial_prompt: Optional[Union[str, Iterable[int]]]
prefix: Optional[str]
suppress_blank: bool
suppress_tokens: Optional[List[int]]
@@ -63,6 +63,15 @@ class TranscriptionOptions(NamedTuple):
append_punctuations: str
+ class TranscriptionInfo(NamedTuple):
+ language: str
+ language_probability: float
+ duration: float
+ all_language_probs: Optional[List[Tuple[str, float]]]
+ transcription_options: TranscriptionOptions
+ vad_options: VadOptions
class WhisperModel:
def __init__(
self,
@@ -73,6 +82,7 @@ class WhisperModel:
cpu_threads: int = 0,
num_workers: int = 1,
download_root: Optional[str] = None,
+ local_files_only: bool = False,
):
"""Initializes the Whisper model.
@@ -94,15 +104,21 @@ class WhisperModel:
having multiple workers enables true parallelism when running the model
(concurrent calls to self.model.generate() will run in parallel).
This can improve the global throughput at the cost of increased memory usage.
- download_root: Directory where the model should be saved. If not set, the model
- is saved in the standard Hugging Face cache directory.
+ download_root: Directory where the models should be saved. If not set, the models
+ are saved in the standard Hugging Face cache directory.
+ local_files_only: If True, avoid downloading the file and return the path to the
+ local cached file if it exists.
"""
self.logger = get_logger()
if os.path.isdir(model_size_or_path):
model_path = model_size_or_path
else:
- model_path = download_model(model_size_or_path, download_root)
+ model_path = download_model(
+ model_size_or_path,
+ local_files_only=local_files_only,
+ cache_dir=download_root,
+ )
self.model = ctranslate2.models.Whisper(
model_path,
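
Taken together, download_root and the new local_files_only flag allow fully offline loading once a model has been fetched. A usage sketch (the model size is illustrative):

    from faster_whisper import WhisperModel

    # First run: download into a project-local cache.
    model = WhisperModel("small", download_root="./models")

    # Later runs: reuse the cached copy without touching the network.
    model = WhisperModel("small", download_root="./models", local_files_only=True)
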
@@ -154,7 +170,7 @@ class WhisperModel:
log_prob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
- initial_prompt: Optional[str] = None,
+ initial_prompt: Optional[Union[str, Iterable[int]]] = None,
prefix: Optional[str] = None,
suppress_blank: bool = True,
suppress_tokens: Optional[List[int]] = [-1],
@@ -164,8 +180,8 @@ class WhisperModel:
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
vad_filter: bool = False,
- vad_parameters: Optional[dict] = None,
- ) -> Tuple[Iterable[Segment], AudioInfo]:
+ vad_parameters: Optional[Union[dict, VadOptions]] = None,
+ ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""Transcribes an input file.
Arguments:
@@ -192,7 +208,8 @@ class WhisperModel:
as a prompt for the next window; disabling may make the text inconsistent across
windows, but the model becomes less prone to getting stuck in a failure loop,
such as repetition looping or timestamps going out of sync.
- initial_prompt: Optional text to provide as a prompt for the first window.
+ initial_prompt: Optional text string or iterable of token ids to provide as a
+ prompt for the first window.
prefix: Optional text to provide as a prefix for the first window.
suppress_blank: Suppress blank outputs at the beginning of the sampling.
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
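
Since initial_prompt now also accepts an iterable of token ids, the prompt can be supplied pre-tokenized. A sketch of both call styles (the token ids are placeholder values):

    # As plain text, tokenized internally:
    segments, info = model.transcribe("audio.wav", initial_prompt="Names: CTranslate2, Whisper.")

    # As pre-computed token ids, extended onto the context verbatim:
    segments, info = model.transcribe("audio.wav", initial_prompt=[220, 1169, 3797])
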
@@ -208,14 +225,14 @@ class WhisperModel:
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad.
- vad_parameters: Dictionary of Silero VAD parameters (see available parameters and
- default values in the function `get_speech_timestamps`).
+ vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
+ parameters and default values in the class `VadOptions`).
Returns:
A tuple with:
- a generator over transcribed segments
- - an instance of AudioInfo
+ - an instance of TranscriptionInfo
"""
sampling_rate = self.feature_extractor.sampling_rate
@@ -229,8 +246,11 @@ class WhisperModel:
)
if vad_filter:
- vad_parameters = {} if vad_parameters is None else vad_parameters
- speech_chunks = get_speech_timestamps(audio, **vad_parameters)
+ if vad_parameters is None:
+ vad_parameters = VadOptions()
+ elif isinstance(vad_parameters, dict):
+ vad_parameters = VadOptions(**vad_parameters)
+ speech_chunks = get_speech_timestamps(audio, vad_parameters)
audio = collect_chunks(audio, speech_chunks)
self.logger.info(
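
With the branch above, VAD tuning accepts either a plain dict or a VadOptions instance, so the two calls below are equivalent (min_silence_duration_ms is one of the VadOptions fields):

    from faster_whisper.vad import VadOptions

    segments, info = model.transcribe(
        "audio.wav",
        vad_filter=True,
        vad_parameters=dict(min_silence_duration_ms=500),
    )
    segments, info = model.transcribe(
        "audio.wav",
        vad_filter=True,
        vad_parameters=VadOptions(min_silence_duration_ms=500),
    )
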
@@ -257,6 +277,7 @@ class WhisperModel:
features = self.feature_extractor(audio)
encoder_output = None
+ all_language_probs = None
if language is None:
if not self.model.is_multilingual:
@@ -265,9 +286,13 @@ class WhisperModel:
else:
segment = features[:, : self.feature_extractor.nb_max_frames]
encoder_output = self.encode(segment)
- results = self.model.detect_language(encoder_output)
- language_token, language_probability = results[0][0]
- language = language_token[2:-2]
+ # results is a list of tuple[str, float] with language names and
+ # probabilities.
+ results = self.model.detect_language(encoder_output)[0]
+ # Parse language names to strip out markers
+ all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
+ # Get top language token and probability
+ language, language_probability = all_language_probs[0]
self.logger.info(
"Detected language '%s' with probability %.2f",
@@ -312,13 +337,16 @@ class WhisperModel:
if speech_chunks:
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
- audio_info = AudioInfo(
+ info = TranscriptionInfo(
language=language,
language_probability=language_probability,
duration=duration,
+ transcription_options=options,
+ vad_options=vad_parameters,
+ all_language_probs=all_language_probs,
)
- return segments, audio_info
+ return segments, info
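
Callers only need to rename the second return value; the extra fields come along for free. A sketch of reading the richer result:

    segments, info = model.transcribe("audio.wav")  # language=None, so detection runs
    print(info.language, info.language_probability, info.duration)
    if info.all_language_probs is not None:  # None when detection was skipped
        for lang, prob in info.all_language_probs[:5]:
            print(f"  {lang}: {prob:.3f}")
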
def generate_segments(
self,
@@ -328,15 +356,19 @@ class WhisperModel:
encoder_output: Optional[ctranslate2.StorageView] = None,
) -> Iterable[Segment]:
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
+ idx = 0
seek = 0
all_tokens = []
all_prompt_text = []
prompt_reset_since = 0
if options.initial_prompt is not None:
initial_prompt = " " + options.initial_prompt.strip()
initial_prompt_tokens = tokenizer.encode(initial_prompt)
all_tokens.extend(initial_prompt_tokens)
if isinstance(options.initial_prompt, str):
initial_prompt = " " + options.initial_prompt.strip()
initial_prompt_tokens = tokenizer.encode(initial_prompt)
all_tokens.extend(initial_prompt_tokens)
else:
all_tokens.extend(options.initial_prompt)
while seek < content_frames:
time_offset = seek * self.feature_extractor.time_per_frame
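
A condensed sketch of the dispatch above: a string is stripped and re-prefixed with a single space before encoding (so it tokenizes the way the words would appear mid-transcript), while an iterable of token ids is trusted as-is. The helper name is hypothetical:

    def resolve_initial_prompt(initial_prompt, tokenizer):
        if isinstance(initial_prompt, str):
            return tokenizer.encode(" " + initial_prompt.strip())
        return list(initial_prompt)
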
@@ -362,9 +394,12 @@ class WhisperModel:
if encoder_output is None:
encoder_output = self.encode(segment)
- result, avg_log_prob, temperature = self.generate_with_fallback(
- encoder_output, prompt, tokenizer, options
- )
+ (
+ result,
+ avg_logprob,
+ temperature,
+ compression_ratio,
+ ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
if options.no_speech_threshold is not None:
# no voice activity check
@@ -372,7 +407,7 @@ class WhisperModel:
if (
options.log_prob_threshold is not None
- and avg_log_prob > options.log_prob_threshold
+ and avg_logprob > options.log_prob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
@@ -468,9 +503,6 @@ class WhisperModel:
seek += segment_size
- if not options.condition_on_previous_text or temperature > 0.5:
- prompt_reset_since = len(all_tokens)
if options.word_timestamps:
self.add_word_timestamps(
current_segments,
@@ -511,20 +543,29 @@ class WhisperModel:
):
all_tokens.extend(tokens)
all_prompt_text.append(text)
+ idx += 1
yield Segment(
+ id=idx,
+ seek=seek,
start=segment["start"],
end=segment["end"],
text=text,
+ tokens=tokens,
+ temperature=temperature,
+ avg_logprob=avg_logprob,
+ compression_ratio=compression_ratio,
+ no_speech_prob=result.no_speech_prob,
words=(
[Word(**word) for word in segment["words"]]
if options.word_timestamps
else None
),
- avg_log_prob=avg_log_prob,
- no_speech_prob=result.no_speech_prob,
)
+ if not options.condition_on_previous_text or temperature > 0.5:
+ prompt_reset_since = len(all_tokens)
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
# When the model is running on multiple GPUs, the encoder output should be moved
# to the CPU since we don't know which GPU will handle the next job.
@@ -541,10 +582,11 @@ class WhisperModel:
prompt: List[int],
tokenizer: Tokenizer,
options: TranscriptionOptions,
- ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
+ ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
result = None
- avg_log_prob = None
+ avg_logprob = None
final_temperature = None
+ compression_ratio = None
max_initial_timestamp_index = int(
round(options.max_initial_timestamp / self.time_precision)
@@ -582,8 +624,8 @@ class WhisperModel:
# Recover the average log prob from the returned score.
seq_len = len(tokens)
- cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
- avg_log_prob = cum_log_prob / (seq_len + 1)
+ cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
+ avg_logprob = cum_logprob / (seq_len + 1)
text = tokenizer.decode(tokens).strip()
compression_ratio = get_compression_ratio(text)
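
The score recovery above undoes CTranslate2's length penalty (the returned score is the cumulative log prob divided by seq_len**length_penalty) and then averages over seq_len + 1 tokens. The compression ratio follows the usual Whisper heuristic: repetitive, degenerate text compresses unusually well. A self-contained sketch of such a helper (faster-whisper's actual get_compression_ratio may differ in detail):

    import zlib

    def get_compression_ratio(text: str) -> float:
        data = text.encode("utf-8")
        # A high ratio (original size / compressed size) flags degenerate output.
        return len(data) / len(zlib.compress(data))
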
@@ -605,21 +647,21 @@ class WhisperModel:
if (
options.log_prob_threshold is not None
- and avg_log_prob < options.log_prob_threshold
+ and avg_logprob < options.log_prob_threshold
):
needs_fallback = True # average log probability is too low
self.logger.debug(
"Log probability threshold is not met with temperature %.1f (%f < %f)",
temperature,
- avg_log_prob,
+ avg_logprob,
options.log_prob_threshold,
)
if not needs_fallback:
break
- return result, avg_log_prob, final_temperature
+ return result, avg_logprob, final_temperature, compression_ratio
def get_prompt(
self,
@@ -734,6 +776,8 @@ class WhisperModel:
text_tokens + [tokenizer.eot]
)
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
+ if len(word_boundaries) <= 1:
+ return []
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] / self.tokens_per_second
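
The new guard returns early when the transcript yields at most one word, since there is nothing to align. For intuition, word_boundaries holds each word's starting offset in the flattened token sequence:

    import numpy as np

    word_tokens = [[220, 1169], [220, 3797]]  # illustrative token groups
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
    print(word_boundaries)  # [0 2]
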
@@ -782,7 +826,8 @@ def restore_speech_timestamps(
words = []
for word in segment.words:
# Ensure the word start and end times are resolved to the same chunk.
- chunk_index = ts_map.get_chunk_index(word.start)
+ middle = (word.start + word.end) / 2
+ chunk_index = ts_map.get_chunk_index(middle)
word = word._replace(
start=ts_map.get_original_time(word.start, chunk_index),
end=ts_map.get_original_time(word.end, chunk_index),
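
Resolving both ends through the word's midpoint pins the whole word to a single chunk, and therefore a single time offset. With per-end lookup, a word straddling a chunk seam (say 4.9s-5.1s around a seam at 5.0s) could have its start and end shifted by different offsets and even come out with start > end:

    middle = (word.start + word.end) / 2          # (4.9 + 5.1) / 2 = 5.0
    chunk_index = ts_map.get_chunk_index(middle)  # one chunk for both ends
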