Add clip_timestamps and hallucination_silence_threshold options (#646)
This commit is contained in:
@@ -14,7 +14,7 @@ import tokenizers
|
||||
from faster_whisper.audio import decode_audio
|
||||
from faster_whisper.feature_extractor import FeatureExtractor
|
||||
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
|
||||
from faster_whisper.utils import download_model, format_timestamp, get_logger
|
||||
from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
|
||||
from faster_whisper.vad import (
|
||||
SpeechTimestampsMap,
|
||||
VadOptions,
|
||||
@@ -67,6 +67,8 @@ class TranscriptionOptions(NamedTuple):
|
||||
prepend_punctuations: str
|
||||
append_punctuations: str
|
||||
max_new_tokens: Optional[int]
|
||||
clip_timestamps: Union[str, List[float]]
|
||||
hallucination_silence_threshold: Optional[float]
|
||||
|
||||
|
||||
class TranscriptionInfo(NamedTuple):
|
||||
@@ -216,6 +218,8 @@ class WhisperModel:
|
||||
vad_parameters: Optional[Union[dict, VadOptions]] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
chunk_length: Optional[int] = None,
|
||||
clip_timestamps: Union[str, List[float]] = "0",
|
||||
hallucination_silence_threshold: Optional[float] = None,
|
||||
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
|
||||
"""Transcribes an input file.
|
||||
|
||||
@@ -271,6 +275,12 @@ class WhisperModel:
|
||||
the maximum will be set by the default max_length.
|
||||
chunk_length: The length of audio segments. If it is not None, it will overwrite the
|
||||
default chunk_length of the FeatureExtractor.
|
||||
clip_timestamps: Union[str, List[float]]
|
||||
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
|
||||
process. The last end timestamp defaults to the end of the file.
|
||||
hallucination_silence_threshold: Optional[float]
|
||||
When word_timestamps is True, skip silent periods longer than this threshold
|
||||
(in seconds) when a possible hallucination is detected
|
||||
|
||||
Returns:
|
||||
A tuple with:
|
||||
@@ -387,6 +397,8 @@ class WhisperModel:
|
||||
prepend_punctuations=prepend_punctuations,
|
||||
append_punctuations=append_punctuations,
|
||||
max_new_tokens=max_new_tokens,
|
||||
clip_timestamps=clip_timestamps,
|
||||
hallucination_silence_threshold=hallucination_silence_threshold,
|
||||
)
|
||||
|
||||
segments = self.generate_segments(features, tokenizer, options, encoder_output)
|
||||
@@ -414,8 +426,33 @@ class WhisperModel:
|
||||
encoder_output: Optional[ctranslate2.StorageView] = None,
|
||||
) -> Iterable[Segment]:
|
||||
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
||||
content_duration = float(content_frames * self.feature_extractor.time_per_frame)
|
||||
|
||||
if isinstance(options.clip_timestamps, str):
|
||||
TranscriptionOptions.clip_timestamps = [
|
||||
float(ts)
|
||||
for ts in (
|
||||
options.clip_timestamps.split(",")
|
||||
if options.clip_timestamps
|
||||
else []
|
||||
)
|
||||
]
|
||||
seek_points: List[int] = [
|
||||
round(ts * self.frames_per_second) for ts in options.clip_timestamps
|
||||
]
|
||||
if len(seek_points) == 0:
|
||||
seek_points.append(0)
|
||||
if len(seek_points) % 2 == 1:
|
||||
seek_points.append(content_frames)
|
||||
seek_clips: List[Tuple[int, int]] = list(
|
||||
zip(seek_points[::2], seek_points[1::2])
|
||||
)
|
||||
|
||||
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
||||
|
||||
idx = 0
|
||||
seek = 0
|
||||
clip_idx = 0
|
||||
seek = seek_clips[clip_idx][0]
|
||||
all_tokens = []
|
||||
prompt_reset_since = 0
|
||||
|
||||
@@ -428,12 +465,30 @@ class WhisperModel:
|
||||
all_tokens.extend(options.initial_prompt)
|
||||
|
||||
last_speech_timestamp = 0.0
|
||||
while seek < content_frames:
|
||||
# NOTE: This loop is obscurely flattened to make the diff readable.
|
||||
# A later commit should turn this into a simpler nested loop.
|
||||
# for seek_clip_start, seek_clip_end in seek_clips:
|
||||
# while seek < seek_clip_end
|
||||
while clip_idx < len(seek_clips):
|
||||
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
||||
if seek < seek_clip_start:
|
||||
seek = seek_clip_start
|
||||
if seek >= seek_clip_end:
|
||||
clip_idx += 1
|
||||
if clip_idx < len(seek_clips):
|
||||
seek = seek_clips[clip_idx][0]
|
||||
continue
|
||||
time_offset = seek * self.feature_extractor.time_per_frame
|
||||
segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]
|
||||
segment_size = min(
|
||||
self.feature_extractor.nb_max_frames, content_frames - seek
|
||||
window_end_time = float(
|
||||
(seek + self.feature_extractor.nb_max_frames)
|
||||
* self.feature_extractor.time_per_frame
|
||||
)
|
||||
segment_size = min(
|
||||
self.feature_extractor.nb_max_frames,
|
||||
content_frames - seek,
|
||||
seek_clip_end - seek,
|
||||
)
|
||||
segment = features[:, seek : seek + segment_size]
|
||||
segment_duration = segment_size * self.feature_extractor.time_per_frame
|
||||
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
@@ -486,10 +541,33 @@ class WhisperModel:
|
||||
previous_seek = seek
|
||||
current_segments = []
|
||||
|
||||
# anomalous words are very long/short/improbable
|
||||
def word_anomaly_score(word: dict) -> float:
    """Rate how anomalous a single word looks; 0.0 means nothing suspicious.

    Args:
        word: Mapping with "start" and "end" entries and an optional
            "probability" entry (treated as 0.0 when missing).

    Returns:
        The accumulated penalty: 1.0 for a probability below 0.15, plus a
        penalty that grows as the duration drops below 0.133, plus the
        excess of the duration over 2.0 (durations presumably in seconds).
    """
    # A low-confidence word costs a flat point.
    penalty = 1.0 if word.get("probability", 0.0) < 0.15 else 0.0
    length = word["end"] - word["start"]
    # Implausibly short words are penalised proportionally to the shortfall.
    if length < 0.133:
        penalty += (0.133 - length) * 15
    # Implausibly long words are penalised by the amount they overrun.
    if length > 2.0:
        penalty += length - 2.0
    return penalty
|
||||
|
||||
def is_segment_anomaly(segment: Optional[dict]) -> bool:
    """Tell whether a segment's leading words look anomalous as a whole.

    Args:
        segment: Segment mapping with a "words" list, or None.

    Returns:
        False when there is no segment or it has no words. Otherwise True
        when the summed anomaly score of the first (up to) eight
        non-punctuation words is at least 3, or is within 0.01 of the
        number of words considered.
    """
    if segment is None:
        return False
    entries = segment["words"]
    if not entries:
        return False
    # Drop pure-punctuation entries first, then keep only the leading eight.
    spoken = [entry for entry in entries if entry["word"] not in punctuation]
    spoken = spoken[:8]
    total = sum(map(word_anomaly_score, spoken))
    if total >= 3:
        return True
    return total + 0.01 >= len(spoken)
|
||||
|
||||
def next_words_segment(segments: List[dict]) -> Optional[dict]:
    """Return the first segment whose "words" entry is non-empty, else None."""
    for candidate in segments:
        if candidate["words"]:
            return candidate
    return None
|
||||
|
||||
single_timestamp_ending = (
|
||||
len(tokens) >= 2
|
||||
and tokens[-2] < tokenizer.timestamp_begin
|
||||
and tokens[-1] >= tokenizer.timestamp_begin
|
||||
and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
|
||||
)
|
||||
|
||||
consecutive_timestamps = [
|
||||
@@ -572,18 +650,70 @@ class WhisperModel:
|
||||
last_speech_timestamp=last_speech_timestamp,
|
||||
)
|
||||
|
||||
word_end_timestamps = [
|
||||
w["end"] for s in current_segments for w in s["words"]
|
||||
]
|
||||
if len(word_end_timestamps) > 0:
|
||||
last_speech_timestamp = word_end_timestamps[-1]
|
||||
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
||||
seek_shift = round(
|
||||
(word_end_timestamps[-1] - time_offset) * self.frames_per_second
|
||||
)
|
||||
if not single_timestamp_ending:
|
||||
last_word_end = get_end(current_segments)
|
||||
if last_word_end is not None and last_word_end > time_offset:
|
||||
seek = round(last_word_end * self.frames_per_second)
|
||||
|
||||
if seek_shift > 0:
|
||||
seek = previous_seek + seek_shift
|
||||
# skip silence before possible hallucinations
|
||||
if options.hallucination_silence_threshold is not None:
|
||||
threshold = options.hallucination_silence_threshold
|
||||
if not single_timestamp_ending:
|
||||
last_word_end = get_end(current_segments)
|
||||
if last_word_end is not None and last_word_end > time_offset:
|
||||
remaining_duration = window_end_time - last_word_end
|
||||
if remaining_duration > threshold:
|
||||
seek = round(last_word_end * self.frames_per_second)
|
||||
else:
|
||||
seek = previous_seek + segment_size
|
||||
|
||||
# if first segment might be a hallucination, skip leading silence
|
||||
first_segment = next_words_segment(current_segments)
|
||||
if first_segment is not None and is_segment_anomaly(first_segment):
|
||||
gap = first_segment["start"] - time_offset
|
||||
if gap > threshold:
|
||||
seek = previous_seek + round(gap * self.frames_per_second)
|
||||
continue
|
||||
|
||||
# skip silence before any possible hallucination that is surrounded
|
||||
# by silence or more hallucinations
|
||||
hal_last_end = last_speech_timestamp
|
||||
for si in range(len(current_segments)):
|
||||
segment = current_segments[si]
|
||||
if not segment["words"]:
|
||||
continue
|
||||
if is_segment_anomaly(segment):
|
||||
next_segment = next_words_segment(
|
||||
current_segments[si + 1 :]
|
||||
)
|
||||
if next_segment is not None:
|
||||
hal_next_start = next_segment["words"][0]["start"]
|
||||
else:
|
||||
hal_next_start = time_offset + segment_duration
|
||||
silence_before = (
|
||||
segment["start"] - hal_last_end > threshold
|
||||
or segment["start"] < threshold
|
||||
or segment["start"] - time_offset < 2.0
|
||||
)
|
||||
silence_after = (
|
||||
hal_next_start - segment["end"] > threshold
|
||||
or is_segment_anomaly(next_segment)
|
||||
or window_end_time - segment["end"] < 2.0
|
||||
)
|
||||
if silence_before and silence_after:
|
||||
seek = round(
|
||||
max(time_offset + 1, segment["start"])
|
||||
* self.frames_per_second
|
||||
)
|
||||
if content_duration - segment["end"] < threshold:
|
||||
seek = content_frames
|
||||
current_segments[si:] = []
|
||||
break
|
||||
hal_last_end = segment["end"]
|
||||
|
||||
last_word_end = get_end(current_segments)
|
||||
if last_word_end is not None:
|
||||
last_speech_timestamp = last_word_end
|
||||
|
||||
for segment in current_segments:
|
||||
tokens = segment["tokens"]
|
||||
@@ -819,6 +949,7 @@ class WhisperModel:
|
||||
word_durations = np.array([word["end"] - word["start"] for word in alignment])
|
||||
word_durations = word_durations[word_durations.nonzero()]
|
||||
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
||||
median_duration = min(0.7, float(median_duration))
|
||||
max_duration = median_duration * 2
|
||||
|
||||
# hack: truncate long words at sentence boundaries.
|
||||
|
||||
@@ -146,3 +146,10 @@ class disabled_tqdm(tqdm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["disable"] = True
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
def get_end(segments: List[dict]) -> Optional[float]:
    """Return the end timestamp of the last word found in ``segments``.

    Segments are scanned from last to first and the "end" of the final word
    of the first segment that has any words is returned. Segments whose
    "words" entry is empty, None, or absent are skipped rather than raising.

    Args:
        segments: Segment mappings, each with an "end" entry and optionally
            a "words" list of mappings that carry their own "end".

    Returns:
        The last word's "end"; if no segment has words, the "end" of the
        last segment; None when ``segments`` is empty.
    """
    for segment in reversed(segments):
        # .get() tolerates segments that were built without word timestamps.
        words = segment.get("words")
        if words:
            return words[-1]["end"]
    # No word-level timing available: fall back to the segment-level end.
    return segments[-1]["end"] if segments else None
|
||||
|
||||
Reference in New Issue
Block a user