Add word-level timestamps (#43)
* Add word-level timestamps
* Fix alignment between the segments and the lists of words
* Fix truncated words list when the replacement character is decoded
* Check for empty text_tokens
* Add usage example in the readme
* Update ctranslate2 to 3.9
* Skip empty segment
* Set typing for the new methods
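With word_timestamps=True, each transcribed segment exposes a words list. A minimal usage sketch, mirroring the example this commit adds to the README (the model path and audio file name are placeholders):

from faster_whisper import WhisperModel

model = WhisperModel("path/to/ct2-model")
segments, info = model.transcribe("audio.mp3", word_timestamps=True)

for segment in segments:
    for word in segment.words:
        print("[%.2fs -> %.2fs] %s" % (word.start, word.end, word.word))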
faster_whisper/transcribe.py
@@ -1,4 +1,5 @@
 import collections
+import itertools
 import os
 import zlib
 
@@ -13,7 +14,11 @@ from faster_whisper.feature_extractor import FeatureExtractor
 from faster_whisper.tokenizer import Tokenizer
 
 
-class Segment(collections.namedtuple("Segment", ("start", "end", "text"))):
+class Segment(collections.namedtuple("Segment", ("start", "end", "text", "words"))):
+    pass
+
+
+class Word(collections.namedtuple("Word", ("start", "end", "word", "probability"))):
     pass
 
 
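For illustration, the new result shapes look like this (all values invented); segment.words is None when transcribe() is called without word_timestamps=True:

word = Word(start=0.00, end=0.42, word=" Hello", probability=0.93)
segment = Segment(start=0.00, end=1.30, text=" Hello world", words=[word])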
@@ -42,6 +47,9 @@ class TranscriptionOptions(
             "suppress_tokens",
             "without_timestamps",
             "max_initial_timestamp",
+            "word_timestamps",
+            "prepend_punctuations",
+            "append_punctuations",
         ),
     )
 ):
@@ -94,6 +102,13 @@ class WhisperModel:
         )
 
         self.feature_extractor = FeatureExtractor()
+        self.num_samples_per_token = self.feature_extractor.hop_length * 2
+        self.frames_per_second = (
+            self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
+        )
+        self.tokens_per_second = (
+            self.feature_extractor.sampling_rate // self.num_samples_per_token
+        )
         self.input_stride = 2
         self.time_precision = 0.02
         self.max_length = 448
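For intuition, with the feature settings used by the bundled FeatureExtractor (16 kHz audio, hop length of 160 samples), these rates work out as follows:

sampling_rate = 16000                                        # Hz
hop_length = 160                                             # samples per mel frame
num_samples_per_token = hop_length * 2                       # 320 samples
frames_per_second = sampling_rate // hop_length              # 100 frames per second
tokens_per_second = sampling_rate // num_samples_per_token   # 50 positions per second

# One timestamp position spans 1 / 50 = 0.02 s, matching time_precision above.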
@@ -125,6 +140,9 @@ class WhisperModel:
         suppress_tokens: Optional[List[int]] = [-1],
         without_timestamps: bool = False,
         max_initial_timestamp: float = 1.0,
+        word_timestamps: bool = False,
+        prepend_punctuations: str = "\"'“¿([{-",
+        append_punctuations: str = "\"'.。,,!!??::”)]}、",
     ):
         """Transcribes an input file.
 
@@ -159,6 +177,12 @@
             of symbols as defined in the model config.json file.
           without_timestamps: Only sample text tokens.
           max_initial_timestamp: The initial timestamp cannot be later than this.
+          word_timestamps: Extract word-level timestamps using the cross-attention pattern
+            and dynamic time warping, and include the timestamps for each word in each segment.
+          prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
+            with the next word.
+          append_punctuations: If word_timestamps is True, merge these punctuation symbols
+            with the previous word.
 
         Returns:
           A tuple with:
@@ -211,6 +235,9 @@
             suppress_tokens=suppress_tokens,
             without_timestamps=without_timestamps,
             max_initial_timestamp=max_initial_timestamp,
+            word_timestamps=word_timestamps,
+            prepend_punctuations=prepend_punctuations,
+            append_punctuations=append_punctuations,
         )
 
         segments = self.generate_segments(features, tokenizer, options)
@@ -271,6 +298,7 @@
 
             tokens = result.sequences_ids[0]
 
+            previous_seek = seek
             current_segments = []
 
             single_timestamp_ending = (
@@ -309,7 +337,12 @@
                     )
 
                     current_segments.append(
-                        dict(start=start_time, end=end_time, tokens=sliced_tokens)
+                        dict(
+                            seek=seek,
+                            start=start_time,
+                            end=end_time,
+                            tokens=sliced_tokens,
+                        )
                     )
                     last_slice = current_slice
 
@@ -333,7 +366,12 @@
                     duration = last_timestamp_position * self.time_precision
 
                 current_segments.append(
-                    dict(start=time_offset, end=time_offset + duration, tokens=tokens)
+                    dict(
+                        seek=seek,
+                        start=time_offset,
+                        end=time_offset + duration,
+                        tokens=tokens,
+                    )
                 )
 
             seek += segment_size
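The seek stored in each segment dict is what later converts frame positions back to seconds in add_word_timestamps, via time_offset = seek * hop_length / sampling_rate. With the defaults assumed above and a hypothetical window start:

seek = 3000                                        # hypothetical window start, in mel frames
hop_length, sampling_rate = 160, 16000
time_offset = seek * hop_length / sampling_rate    # 30.0 seconds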
@@ -341,18 +379,46 @@
             if not options.condition_on_previous_text or temperature > 0.5:
                 prompt_reset_since = len(all_tokens)
 
+            if options.word_timestamps:
+                self.add_word_timestamps(
+                    current_segments,
+                    tokenizer,
+                    segment,
+                    segment_size,
+                    options.prepend_punctuations,
+                    options.append_punctuations,
+                )
+
+                word_end_timestamps = [
+                    w["end"] for s in current_segments for w in s["words"]
+                ]
+
+                if not single_timestamp_ending and len(word_end_timestamps) > 0:
+                    seek_shift = round(
+                        (word_end_timestamps[-1] - time_offset) * self.frames_per_second
+                    )
+
+                    if seek_shift > 0:
+                        seek = previous_seek + seek_shift
+
             for segment in current_segments:
                 tokens = segment["tokens"]
-                all_tokens.extend(tokens)
 
                 text = tokenizer.decode(tokens)
-                if not text.strip():
+                if segment["start"] == segment["end"] or not text.strip():
                     continue
 
+                all_tokens.extend(tokens)
+
                 yield Segment(
                     start=segment["start"],
                     end=segment["end"],
                     text=text,
+                    words=(
+                        [Word(**word) for word in segment["words"]]
+                        if options.word_timestamps
+                        else None
+                    ),
                 )
 
     def generate_with_fallback(self, segment, prompt, tokenizer, options):
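The previous_seek bookkeeping enables the rewind above: rather than always advancing a full window, the window is moved to just past the last aligned word, so audio the model rushed through gets decoded again. A rough sketch of the frame arithmetic, with made-up values:

frames_per_second = 100    # mel frames per second
previous_seek = 3000       # window started at frame 3000 (t = 30.0 s)
time_offset = 30.0         # seconds at previous_seek
last_word_end = 52.4       # end time of the last aligned word, in seconds

seek_shift = round((last_word_end - time_offset) * frames_per_second)  # 2240
if seek_shift > 0:
    seek = previous_seek + seek_shift   # 5240 frames, i.e. t = 52.4 s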
@@ -448,6 +514,126 @@ class WhisperModel:
 
         return prompt
 
+    def add_word_timestamps(
+        self,
+        segments: List[dict],
+        tokenizer: Tokenizer,
+        mel: np.ndarray,
+        num_frames: int,
+        prepend_punctuations: str,
+        append_punctuations: str,
+    ):
+        if len(segments) == 0:
+            return
+
+        text_tokens_per_segment = [
+            [token for token in segment["tokens"] if token < tokenizer.eot]
+            for segment in segments
+        ]
+
+        text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
+        alignment = self.find_alignment(tokenizer, text_tokens, mel, num_frames)
+        merge_punctuations(alignment, prepend_punctuations, append_punctuations)
+
+        time_offset = (
+            segments[0]["seek"]
+            * self.feature_extractor.hop_length
+            / self.feature_extractor.sampling_rate
+        )
+
+        word_index = 0
+
+        for segment, text_tokens in zip(segments, text_tokens_per_segment):
+            saved_tokens = 0
+            words = []
+
+            while word_index < len(alignment) and saved_tokens < len(text_tokens):
+                timing = alignment[word_index]
+
+                if timing["word"]:
+                    words.append(
+                        dict(
+                            word=timing["word"],
+                            start=round(time_offset + timing["start"], 2),
+                            end=round(time_offset + timing["end"], 2),
+                            probability=timing["probability"],
+                        )
+                    )
+
+                saved_tokens += len(timing["tokens"])
+                word_index += 1
+
+            if len(words) > 0:
+                # adjust the segment-level timestamps based on the word-level timestamps
+                segment["start"] = words[0]["start"]
+                segment["end"] = words[-1]["end"]
+
+            segment["words"] = words
+
+    def find_alignment(
+        self,
+        tokenizer: Tokenizer,
+        text_tokens: List[int],
+        mel: np.ndarray,
+        num_frames: int,
+        median_filter_width: int = 7,
+    ) -> List[dict]:
+        if len(text_tokens) == 0:
+            return []
+
+        result = self.model.align(
+            get_input(mel),
+            tokenizer.sot_sequence,
+            [text_tokens],
+            num_frames,
+            median_filter_width=median_filter_width,
+        )[0]
+
+        text_token_probs = result.text_token_probs
+
+        alignments = result.alignments
+        text_indices = np.array([pair[0] for pair in alignments])
+        time_indices = np.array([pair[1] for pair in alignments])
+
+        words, word_tokens = tokenizer.split_to_word_tokens(
+            text_tokens + [tokenizer.eot]
+        )
+        word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
+
+        jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
+        jump_times = time_indices[jumps] / self.tokens_per_second
+        start_times = jump_times[word_boundaries[:-1]]
+        end_times = jump_times[word_boundaries[1:]]
+        word_probabilities = [
+            np.mean(text_token_probs[i:j])
+            for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
+        ]
+
+        # hack: ensure the first and second word are not longer than twice the median word duration.
+        # a better segmentation algorithm based on VAD should be able to replace this.
+        word_durations = end_times - start_times
+        word_durations = word_durations[word_durations.nonzero()]
+        if len(word_durations) > 0:
+            median_duration = np.median(word_durations)
+            max_duration = median_duration * 2
+            if len(word_durations) >= 2 and word_durations[1] > max_duration:
+                boundary = max(end_times[2] / 2, end_times[2] - max_duration)
+                end_times[0] = start_times[1] = boundary
+            if (
+                len(word_durations) >= 1
+                and end_times[0] - start_times[0] > max_duration
+            ):
+                start_times[0] = max(0, end_times[0] - max_duration)
+
+        return [
+            dict(
+                word=word, tokens=tokens, start=start, end=end, probability=probability
+            )
+            for word, tokens, start, end, probability in zip(
+                words, word_tokens, start_times, end_times, word_probabilities
+            )
+        ]
+
 
 def get_input(segment):
     segment = np.ascontiguousarray(segment)
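The boundary extraction in find_alignment is dense; this self-contained sketch shows how the jumps mask turns a DTW path into per-token start times (the arrays are invented, the real ones come from model.align):

import numpy as np

# DTW path: text token index at each step, and the matching time index
# (one unit = one token position, 50 per second).
text_indices = np.array([0, 0, 1, 1, 1, 2, 3, 3])
time_indices = np.array([0, 1, 2, 3, 4, 5, 6, 7])
tokens_per_second = 50

# A jump is a step where the path advances to the next text token; the time
# index at that step is when the token starts being attended to.
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] / tokens_per_second
print(jump_times)  # [0.   0.04 0.1  0.12] -> start times of tokens 0..3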
@@ -459,3 +645,37 @@ def get_input(segment):
 def get_compression_ratio(text):
     text_bytes = text.encode("utf-8")
     return len(text_bytes) / len(zlib.compress(text_bytes))
+
+
+def merge_punctuations(alignment: List[dict], prepended: str, appended: str):
+    # merge prepended punctuations
+    i = len(alignment) - 2
+    j = len(alignment) - 1
+    while i >= 0:
+        previous = alignment[i]
+        following = alignment[j]
+        if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
+            # prepend it to the following word
+            following["word"] = previous["word"] + following["word"]
+            following["tokens"] = previous["tokens"] + following["tokens"]
+            previous["word"] = ""
+            previous["tokens"] = []
+        else:
+            j = i
+        i -= 1
+
+    # merge appended punctuations
+    i = 0
+    j = 1
+    while j < len(alignment):
+        previous = alignment[i]
+        following = alignment[j]
+        if not previous["word"].endswith(" ") and following["word"] in appended:
+            # append it to the previous word
+            previous["word"] = previous["word"] + following["word"]
+            previous["tokens"] = previous["tokens"] + following["tokens"]
+            following["word"] = ""
+            following["tokens"] = []
+        else:
+            i = j
+        j += 1
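To illustrate the merging behaviour, a toy call with invented alignment entries (real entries also carry start, end and probability):

alignment = [
    dict(word=' "', tokens=[1]),
    dict(word="Hello", tokens=[2]),
    dict(word=".", tokens=[3]),
]

merge_punctuations(alignment, prepended="\"'“¿([{-", appended="\"'.。,,!!??::”)]}、")

# The opening quote is folded into the following word and the period into the
# preceding one; the emptied entries keep word == "" and are skipped later,
# since add_word_timestamps only emits timings whose word is truthy.
# alignment[1]["word"] == ' "Hello.'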