Add word-level timestamps (#43)

* Add word-level timestamps

* Fix alignment between the segments and the lists of words

* Fix truncated words list when the replacement character is decoded

* Check for empty text_tokens

* Add usage example in the readme

* Update ctranslate2 to 3.9

* Skip empty segment

* Add type annotations for the new methods
Author: Guillaume Klein
Date: 2023-03-15 15:02:28 +01:00 (committed by GitHub)
Parent: b41fd05948
Commit: 8bd013ea99
4 changed files with 314 additions and 8 deletions
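For reference, the usage example added to the readme looks roughly like the sketch below. The model path, audio file name, and loading options are assumptions; the word_timestamps argument and the Segment/Word fields come from this commit:

from faster_whisper import WhisperModel

model = WhisperModel("whisper-small-ct2")  # assumed path to a converted CTranslate2 model

segments, info = model.transcribe("audio.mp3", word_timestamps=True)

for segment in segments:
    for word in segment.words:
        print("[%.2fs -> %.2fs] %s" % (word.start, word.end, word.word))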


@@ -1,4 +1,5 @@
import collections
import itertools
import os
import zlib
@@ -13,7 +14,11 @@ from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import Tokenizer
-class Segment(collections.namedtuple("Segment", ("start", "end", "text"))):
+class Segment(collections.namedtuple("Segment", ("start", "end", "text", "words"))):
pass
class Word(collections.namedtuple("Word", ("start", "end", "word", "probability"))):
pass
@@ -42,6 +47,9 @@ class TranscriptionOptions(
"suppress_tokens",
"without_timestamps",
"max_initial_timestamp",
"word_timestamps",
"prepend_punctuations",
"append_punctuations",
),
)
):
@@ -94,6 +102,13 @@ class WhisperModel:
)
self.feature_extractor = FeatureExtractor()
self.num_samples_per_token = self.feature_extractor.hop_length * 2
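# with the default feature extractor (16000 Hz sampling rate, hop_length 160),
# frames_per_second is 100 mel frames per second and tokens_per_second is 50,
# i.e. each alignment time index corresponds to 0.02 s of audio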
self.frames_per_second = (
self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
)
self.tokens_per_second = (
self.feature_extractor.sampling_rate // self.num_samples_per_token
)
self.input_stride = 2
self.time_precision = 0.02
self.max_length = 448
@@ -125,6 +140,9 @@ class WhisperModel:
suppress_tokens: Optional[List[int]] = [-1],
without_timestamps: bool = False,
max_initial_timestamp: float = 1.0,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
):
"""Transcribes an input file.
@@ -159,6 +177,12 @@ class WhisperModel:
of symbols as defined in the model config.json file.
without_timestamps: Only sample text tokens.
max_initial_timestamp: The initial timestamp cannot be later than this.
word_timestamps: Extract word-level timestamps using the cross-attention pattern
and dynamic time warping, and include the timestamps for each word in each segment.
prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
with the next word
append_punctuations: If word_timestamps is True, merge these punctuation symbols
with the previous word
Returns:
A tuple with:
@@ -211,6 +235,9 @@ class WhisperModel:
suppress_tokens=suppress_tokens,
without_timestamps=without_timestamps,
max_initial_timestamp=max_initial_timestamp,
word_timestamps=word_timestamps,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
)
segments = self.generate_segments(features, tokenizer, options)
@@ -271,6 +298,7 @@ class WhisperModel:
tokens = result.sequences_ids[0]
previous_seek = seek
current_segments = []
single_timestamp_ending = (
@@ -309,7 +337,12 @@ class WhisperModel:
)
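# record the window offset (seek) with each segment so that
# add_word_timestamps can later convert window-relative word times
# to absolute times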
current_segments.append(
-dict(start=start_time, end=end_time, tokens=sliced_tokens)
+dict(
+    seek=seek,
+    start=start_time,
+    end=end_time,
+    tokens=sliced_tokens,
+)
)
last_slice = current_slice
@@ -333,7 +366,12 @@ class WhisperModel:
duration = last_timestamp_position * self.time_precision
current_segments.append(
-dict(start=time_offset, end=time_offset + duration, tokens=tokens)
+dict(
+    seek=seek,
+    start=time_offset,
+    end=time_offset + duration,
+    tokens=tokens,
+)
)
seek += segment_size
@@ -341,18 +379,46 @@ class WhisperModel:
if not options.condition_on_previous_text or temperature > 0.5:
prompt_reset_since = len(all_tokens)
if options.word_timestamps:
self.add_word_timestamps(
current_segments,
tokenizer,
segment,
segment_size,
options.prepend_punctuations,
options.append_punctuations,
)
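# unless the window ended exactly on a timestamp token, move the next
# seek position to just after the last aligned word so the following
# window does not start in the middle of a word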
word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * self.frames_per_second
)
if seek_shift > 0:
seek = previous_seek + seek_shift
for segment in current_segments:
tokens = segment["tokens"]
-all_tokens.extend(tokens)
text = tokenizer.decode(tokens)
-if not text.strip():
+if segment["start"] == segment["end"] or not text.strip():
continue
+all_tokens.extend(tokens)
yield Segment(
start=segment["start"],
end=segment["end"],
text=text,
words=(
[Word(**word) for word in segment["words"]]
if options.word_timestamps
else None
),
)
def generate_with_fallback(self, segment, prompt, tokenizer, options):
@@ -448,6 +514,126 @@ class WhisperModel:
return prompt
def add_word_timestamps(
self,
segments: List[dict],
tokenizer: Tokenizer,
mel: np.ndarray,
num_frames: int,
prepend_punctuations: str,
append_punctuations: str,
):
if len(segments) == 0:
return
text_tokens_per_segment = [
[token for token in segment["tokens"] if token < tokenizer.eot]
for segment in segments
]
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
alignment = self.find_alignment(tokenizer, text_tokens, mel, num_frames)
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
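# alignment times are relative to the current window; convert its start
# (seek, counted in mel frames) to seconds to get absolute timestamps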
time_offset = (
segments[0]["seek"]
* self.feature_extractor.hop_length
/ self.feature_extractor.sampling_rate
)
word_index = 0
for segment, text_tokens in zip(segments, text_tokens_per_segment):
saved_tokens = 0
words = []
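# the alignment covers the concatenated tokens of all segments: walk it
# once, consuming each segment's token count to hand its words back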
while word_index < len(alignment) and saved_tokens < len(text_tokens):
timing = alignment[word_index]
if timing["word"]:
words.append(
dict(
word=timing["word"],
start=round(time_offset + timing["start"], 2),
end=round(time_offset + timing["end"], 2),
probability=timing["probability"],
)
)
saved_tokens += len(timing["tokens"])
word_index += 1
if len(words) > 0:
# adjust the segment-level timestamps based on the word-level timestamps
segment["start"] = words[0]["start"]
segment["end"] = words[-1]["end"]
segment["words"] = words
def find_alignment(
self,
tokenizer: Tokenizer,
text_tokens: List[int],
mel: np.ndarray,
num_frames: int,
median_filter_width: int = 7,
) -> List[dict]:
if len(text_tokens) == 0:
return []
result = self.model.align(
get_input(mel),
tokenizer.sot_sequence,
[text_tokens],
num_frames,
median_filter_width=median_filter_width,
)[0]
text_token_probs = result.text_token_probs
alignments = result.alignments
text_indices = np.array([pair[0] for pair in alignments])
time_indices = np.array([pair[1] for pair in alignments])
words, word_tokens = tokenizer.split_to_word_tokens(
text_tokens + [tokenizer.eot]
)
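# cumulative token counts give each word's first token index within
# text_tokens (padded with a leading zero)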
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
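# the DTW path is monotonic; each "jump" is the frame where the path
# advances to the next text token, taken as that token's timestamp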
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] / self.tokens_per_second
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
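# score each word by the mean probability of its text tokens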
word_probabilities = [
np.mean(text_token_probs[i:j])
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]
# hack: ensure the first and second word are not longer than twice the median word duration.
# a better segmentation algorithm based on VAD should be able to replace this.
word_durations = end_times - start_times
word_durations = word_durations[word_durations.nonzero()]
if len(word_durations) > 0:
median_duration = np.median(word_durations)
max_duration = median_duration * 2
if len(word_durations) >= 2 and word_durations[1] > max_duration:
boundary = max(end_times[2] / 2, end_times[2] - max_duration)
end_times[0] = start_times[1] = boundary
if (
len(word_durations) >= 1
and end_times[0] - start_times[0] > max_duration
):
start_times[0] = max(0, end_times[0] - max_duration)
return [
dict(
word=word, tokens=tokens, start=start, end=end, probability=probability
)
for word, tokens, start, end, probability in zip(
words, word_tokens, start_times, end_times, word_probabilities
)
]
def get_input(segment):
segment = np.ascontiguousarray(segment)
@@ -459,3 +645,37 @@ def get_input(segment):
def get_compression_ratio(text):
text_bytes = text.encode("utf-8")
return len(text_bytes) / len(zlib.compress(text_bytes))
def merge_punctuations(alignment: List[dict], prepended: str, appended: str):
# merge prepended punctuations
i = len(alignment) - 2
j = len(alignment) - 1
while i >= 0:
previous = alignment[i]
following = alignment[j]
if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
# prepend it to the following word
following["word"] = previous["word"] + following["word"]
following["tokens"] = previous["tokens"] + following["tokens"]
previous["word"] = ""
previous["tokens"] = []
else:
j = i
i -= 1
# merge appended punctuations
i = 0
j = 1
while j < len(alignment):
previous = alignment[i]
following = alignment[j]
if not previous["word"].endswith(" ") and following["word"] in appended:
# append it to the previous word
previous["word"] = previous["word"] + following["word"]
previous["tokens"] = previous["tokens"] + following["tokens"]
following["word"] = ""
following["tokens"] = []
else:
i = j
j += 1
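As an illustration, a minimal sketch of what merge_punctuations does to an alignment (hypothetical words, token ids, and timings):

alignment = [
    dict(word=" Hello", tokens=[1], start=0.00, end=0.40, probability=0.95),
    dict(word=",", tokens=[2], start=0.40, end=0.45, probability=0.90),
    dict(word=" world", tokens=[3], start=0.45, end=0.90, probability=0.97),
]
merge_punctuations(alignment, prepended="\"'“¿([{-", appended="\"'.。,,!!??::”)]}、")
# alignment[0] now holds word=" Hello," with tokens=[1, 2]; alignment[1] is left
# with an empty word and no tokens, which add_word_timestamps then skips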