Add word-level timestamps (#43)
* Add word-level timestamps
* Fix alignment between the segments and the lists of words
* Fix truncated words list when the replacement character is decoded
* Check for empty text_tokens
* Add usage example in the readme
* Update ctranslate2 to 3.9
* Skip empty segment
* Set typing for the new methods

README.md
@@ -99,6 +99,16 @@ for segment in segments:
     print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
 ```
 
+#### Word-level timestamps
+
+```python
+segments, _ = model.transcribe("audio.mp3", word_timestamps=True)
+
+for segment in segments:
+    for word in segment.words:
+        print("[%.2fs -> %.2fs] %s" % (word.start, word.end, word.word))
+```
+
 See more model and transcription options in the [`WhisperModel`](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py) class implementation.
 
 ## Comparing performance against other implementations
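For anyone trying the new option, here is a minimal sketch of consuming the per-word results beyond the README snippet; the model path, audio file, and JSON layout are illustrative assumptions, not part of the change.

```python
import json

from faster_whisper import WhisperModel

# Placeholder path to a converted CTranslate2 model, as elsewhere in the README.
model = WhisperModel("whisper-large-v2-ct2/")

# With word_timestamps=True, each Segment carries a list of Word namedtuples
# exposing start, end, word, and probability.
segments, _ = model.transcribe("audio.mp3", word_timestamps=True)

words = [
    {"start": word.start, "end": word.end, "text": word.word, "probability": word.probability}
    for segment in segments
    for word in segment.words
]

# Flat word list, e.g. for subtitles or karaoke-style highlighting.
print(json.dumps(words, ensure_ascii=False, indent=2))
```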

faster_whisper/tokenizer.py
@@ -1,5 +1,7 @@
+import string
+
 from functools import cached_property
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import tokenizers
@@ -21,6 +23,7 @@ class Tokenizer:
             if self.task is None:
                 raise ValueError("%s is not a valid task" % task)
 
+            self.language_code = language
             self.language = self.tokenizer.token_to_id("<|%s|>" % language)
             if self.language is None:
                 raise ValueError("%s is not a valid language code" % language)
@@ -67,3 +70,76 @@ class Tokenizer:
     def decode(self, tokens: List[int]) -> str:
         text_tokens = [token for token in tokens if token < self.eot]
         return self.tokenizer.decode(text_tokens)
+
+    def decode_with_timestamps(self, tokens: List[int]) -> str:
+        outputs = [[]]
+
+        for token in tokens:
+            if token >= self.timestamp_begin:
+                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
+                outputs.append(timestamp)
+                outputs.append([])
+            else:
+                outputs[-1].append(token)
+
+        return "".join(
+            [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
+        )
+
+    def split_to_word_tokens(
+        self, tokens: List[int]
+    ) -> Tuple[List[str], List[List[int]]]:
+        if self.language_code in {"zh", "ja", "th", "lo", "my"}:
+            # These languages don't typically use spaces, so it is difficult to split words
+            # without morpheme analysis. Here, we instead split words at any
+            # position where the tokens are decoded as valid unicode points
+            return self.split_tokens_on_unicode(tokens)
+
+        return self.split_tokens_on_spaces(tokens)
+
+    def split_tokens_on_unicode(
+        self, tokens: List[int]
+    ) -> Tuple[List[str], List[List[int]]]:
+        decoded_full = self.decode_with_timestamps(tokens)
+        replacement_char = "\ufffd"
+
+        words = []
+        word_tokens = []
+        current_tokens = []
+        unicode_offset = 0
+
+        for token in tokens:
+            current_tokens.append(token)
+            decoded = self.decode_with_timestamps(current_tokens)
+
+            if (
+                replacement_char not in decoded
+                or decoded_full[unicode_offset + decoded.index(replacement_char)]
+                == replacement_char
+            ):
+                words.append(decoded)
+                word_tokens.append(current_tokens)
+                current_tokens = []
+                unicode_offset += len(decoded)
+
+        return words, word_tokens
+
+    def split_tokens_on_spaces(
+        self, tokens: List[int]
+    ) -> Tuple[List[str], List[List[int]]]:
+        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
+        words = []
+        word_tokens = []
+
+        for subword, subword_tokens in zip(subwords, subword_tokens_list):
+            special = subword_tokens[0] >= self.eot
+            with_space = subword.startswith(" ")
+            punctuation = subword.strip() in string.punctuation
+            if special or with_space or punctuation or len(words) == 0:
+                words.append(subword)
+                word_tokens.append(subword_tokens)
+            else:
+                words[-1] = words[-1] + subword
+                word_tokens[-1].extend(subword_tokens)
+
+        return words, word_tokens
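The `split_tokens_on_unicode` method above is what the commit message refers to as the replacement-character fix: a piece is only emitted once the accumulated tokens decode without a spurious U+FFFD (the extra `decoded_full` check handles text that genuinely contains that character). A rough standalone sketch of the same accumulation idea, using raw UTF-8 bytes as a stand-in for BPE tokens, is shown below for illustration only.

```python
from typing import List

REPLACEMENT_CHAR = "\ufffd"


def split_bytes_on_unicode(data: bytes) -> List[str]:
    # Accumulate bytes until they decode without a replacement character,
    # mirroring how split_tokens_on_unicode accumulates BPE tokens.
    pieces = []
    current = bytearray()
    for byte in data:
        current.append(byte)
        decoded = current.decode("utf-8", errors="replace")
        if REPLACEMENT_CHAR not in decoded:
            pieces.append(decoded)
            current = bytearray()
    return pieces


# "é" is two bytes (0xC3 0xA9); the lone 0xC3 decodes to U+FFFD, so the split
# waits for the second byte instead of emitting a truncated piece.
print(split_bytes_on_unicode("né!".encode("utf-8")))  # ['n', 'é', '!']
```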

faster_whisper/transcribe.py
@@ -1,4 +1,5 @@
 import collections
+import itertools
 import os
 import zlib
 
@@ -13,7 +14,11 @@ from faster_whisper.feature_extractor import FeatureExtractor
 from faster_whisper.tokenizer import Tokenizer
 
 
-class Segment(collections.namedtuple("Segment", ("start", "end", "text"))):
+class Segment(collections.namedtuple("Segment", ("start", "end", "text", "words"))):
+    pass
+
+
+class Word(collections.namedtuple("Word", ("start", "end", "word", "probability"))):
     pass
 
 
@@ -42,6 +47,9 @@ class TranscriptionOptions(
             "suppress_tokens",
             "without_timestamps",
             "max_initial_timestamp",
+            "word_timestamps",
+            "prepend_punctuations",
+            "append_punctuations",
         ),
     )
 ):
@@ -94,6 +102,13 @@ class WhisperModel:
         )
 
         self.feature_extractor = FeatureExtractor()
+        self.num_samples_per_token = self.feature_extractor.hop_length * 2
+        self.frames_per_second = (
+            self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
+        )
+        self.tokens_per_second = (
+            self.feature_extractor.sampling_rate // self.num_samples_per_token
+        )
         self.input_stride = 2
         self.time_precision = 0.02
         self.max_length = 448
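For reference, with the usual Whisper front end (16 kHz audio and a 160-sample hop, which is what the bundled FeatureExtractor is expected to use), the new attributes come out to 100 mel frames per second and 50 tokens per second, matching the existing `time_precision` of 0.02 s. A quick check of that arithmetic, with the 16000/160 figures as assumptions:

```python
# Assumed feature-extraction defaults: 16 kHz audio, 160-sample hop length.
sampling_rate = 16000
hop_length = 160

num_samples_per_token = hop_length * 2                      # 320 samples
frames_per_second = sampling_rate // hop_length             # 100 mel frames / s
tokens_per_second = sampling_rate // num_samples_per_token  # 50 timestamp ticks / s

print(num_samples_per_token, frames_per_second, tokens_per_second)
print(1 / tokens_per_second)  # 0.02 -> the existing time_precision
```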
@@ -125,6 +140,9 @@ class WhisperModel:
         suppress_tokens: Optional[List[int]] = [-1],
         without_timestamps: bool = False,
         max_initial_timestamp: float = 1.0,
+        word_timestamps: bool = False,
+        prepend_punctuations: str = "\"'“¿([{-",
+        append_punctuations: str = "\"'.。,,!!??::”)]}、",
     ):
         """Transcribes an input file.
 
@@ -159,6 +177,12 @@ class WhisperModel:
             of symbols as defined in the model config.json file.
           without_timestamps: Only sample text tokens.
           max_initial_timestamp: The initial timestamp cannot be later than this.
+          word_timestamps: Extract word-level timestamps using the cross-attention pattern
+            and dynamic time warping, and include the timestamps for each word in each segment.
+          prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
+            with the next word
+          append_punctuations: If word_timestamps is True, merge these punctuation symbols
+            with the previous word
 
         Returns:
           A tuple with:
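A usage sketch of the new arguments follows; it assumes a `WhisperModel` instance named `model`, and the shortened punctuation strings are examples rather than the defaults listed above.

```python
# Hypothetical call exercising the new transcribe() options.
segments, _ = model.transcribe(
    "audio.mp3",
    word_timestamps=True,
    prepend_punctuations="\"'(",    # attach opening punctuation to the next word
    append_punctuations="\"'.,!?",  # attach closing punctuation to the previous word
)

for segment in segments:
    for word in segment.words:
        print("%.2f -> %.2f  %s (p=%.2f)" % (word.start, word.end, word.word, word.probability))
```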
@@ -211,6 +235,9 @@ class WhisperModel:
             suppress_tokens=suppress_tokens,
             without_timestamps=without_timestamps,
             max_initial_timestamp=max_initial_timestamp,
+            word_timestamps=word_timestamps,
+            prepend_punctuations=prepend_punctuations,
+            append_punctuations=append_punctuations,
         )
 
         segments = self.generate_segments(features, tokenizer, options)
@@ -271,6 +298,7 @@ class WhisperModel:
 
             tokens = result.sequences_ids[0]
 
+            previous_seek = seek
             current_segments = []
 
             single_timestamp_ending = (
@@ -309,7 +337,12 @@ class WhisperModel:
                     )
 
                     current_segments.append(
-                        dict(start=start_time, end=end_time, tokens=sliced_tokens)
+                        dict(
+                            seek=seek,
+                            start=start_time,
+                            end=end_time,
+                            tokens=sliced_tokens,
+                        )
                     )
                     last_slice = current_slice
 
@@ -333,7 +366,12 @@ class WhisperModel:
                     duration = last_timestamp_position * self.time_precision
 
                 current_segments.append(
-                    dict(start=time_offset, end=time_offset + duration, tokens=tokens)
+                    dict(
+                        seek=seek,
+                        start=time_offset,
+                        end=time_offset + duration,
+                        tokens=tokens,
+                    )
                 )
 
                 seek += segment_size
@@ -341,18 +379,46 @@ class WhisperModel:
             if not options.condition_on_previous_text or temperature > 0.5:
                 prompt_reset_since = len(all_tokens)
 
+            if options.word_timestamps:
+                self.add_word_timestamps(
+                    current_segments,
+                    tokenizer,
+                    segment,
+                    segment_size,
+                    options.prepend_punctuations,
+                    options.append_punctuations,
+                )
+
+                word_end_timestamps = [
+                    w["end"] for s in current_segments for w in s["words"]
+                ]
+
+                if not single_timestamp_ending and len(word_end_timestamps) > 0:
+                    seek_shift = round(
+                        (word_end_timestamps[-1] - time_offset) * self.frames_per_second
+                    )
+
+                    if seek_shift > 0:
+                        seek = previous_seek + seek_shift
+
             for segment in current_segments:
                 tokens = segment["tokens"]
-                all_tokens.extend(tokens)
-
                 text = tokenizer.decode(tokens)
-                if not text.strip():
+
+                if segment["start"] == segment["end"] or not text.strip():
                     continue
 
+                all_tokens.extend(tokens)
+
                 yield Segment(
                     start=segment["start"],
                     end=segment["end"],
                     text=text,
+                    words=(
+                        [Word(**word) for word in segment["words"]]
+                        if options.word_timestamps
+                        else None
+                    ),
                 )
 
     def generate_with_fallback(self, segment, prompt, tokenizer, options):
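The new `seek` adjustment re-anchors the next window right after the last word that received a timestamp; a toy recomputation of that shift, with made-up times and the assumed 100 frames per second:

```python
# Illustrative values only.
frames_per_second = 100
time_offset = 4.00      # seconds at the start of the current window
last_word_end = 7.44    # end time of the final aligned word in the window

seek_shift = round((last_word_end - time_offset) * frames_per_second)
print(seek_shift)  # 344 frames: the next window starts just after the last word
```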
@@ -448,6 +514,126 @@ class WhisperModel:
 
         return prompt
 
+    def add_word_timestamps(
+        self,
+        segments: List[dict],
+        tokenizer: Tokenizer,
+        mel: np.ndarray,
+        num_frames: int,
+        prepend_punctuations: str,
+        append_punctuations: str,
+    ):
+        if len(segments) == 0:
+            return
+
+        text_tokens_per_segment = [
+            [token for token in segment["tokens"] if token < tokenizer.eot]
+            for segment in segments
+        ]
+
+        text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
+        alignment = self.find_alignment(tokenizer, text_tokens, mel, num_frames)
+        merge_punctuations(alignment, prepend_punctuations, append_punctuations)
+
+        time_offset = (
+            segments[0]["seek"]
+            * self.feature_extractor.hop_length
+            / self.feature_extractor.sampling_rate
+        )
+
+        word_index = 0
+
+        for segment, text_tokens in zip(segments, text_tokens_per_segment):
+            saved_tokens = 0
+            words = []
+
+            while word_index < len(alignment) and saved_tokens < len(text_tokens):
+                timing = alignment[word_index]
+
+                if timing["word"]:
+                    words.append(
+                        dict(
+                            word=timing["word"],
+                            start=round(time_offset + timing["start"], 2),
+                            end=round(time_offset + timing["end"], 2),
+                            probability=timing["probability"],
+                        )
+                    )
+
+                saved_tokens += len(timing["tokens"])
+                word_index += 1
+
+            if len(words) > 0:
+                # adjust the segment-level timestamps based on the word-level timestamps
+                segment["start"] = words[0]["start"]
+                segment["end"] = words[-1]["end"]
+
+            segment["words"] = words
+
+    def find_alignment(
+        self,
+        tokenizer: Tokenizer,
+        text_tokens: List[int],
+        mel: np.ndarray,
+        num_frames: int,
+        median_filter_width: int = 7,
+    ) -> List[dict]:
+        if len(text_tokens) == 0:
+            return []
+
+        result = self.model.align(
+            get_input(mel),
+            tokenizer.sot_sequence,
+            [text_tokens],
+            num_frames,
+            median_filter_width=median_filter_width,
+        )[0]
+
+        text_token_probs = result.text_token_probs
+
+        alignments = result.alignments
+        text_indices = np.array([pair[0] for pair in alignments])
+        time_indices = np.array([pair[1] for pair in alignments])
+
+        words, word_tokens = tokenizer.split_to_word_tokens(
+            text_tokens + [tokenizer.eot]
+        )
+        word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
+
+        jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
+        jump_times = time_indices[jumps] / self.tokens_per_second
+        start_times = jump_times[word_boundaries[:-1]]
+        end_times = jump_times[word_boundaries[1:]]
+        word_probabilities = [
+            np.mean(text_token_probs[i:j])
+            for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
+        ]
+
+        # hack: ensure the first and second word is not longer than twice the median word duration.
+        # a better segmentation algorithm based on VAD should be able to replace this.
+        word_durations = end_times - start_times
+        word_durations = word_durations[word_durations.nonzero()]
+        if len(word_durations) > 0:
+            median_duration = np.median(word_durations)
+            max_duration = median_duration * 2
+            if len(word_durations) >= 2 and word_durations[1] > max_duration:
+                boundary = max(end_times[2] / 2, end_times[2] - max_duration)
+                end_times[0] = start_times[1] = boundary
+            if (
+                len(word_durations) >= 1
+                and end_times[0] - start_times[0] > max_duration
+            ):
+                start_times[0] = max(0, end_times[0] - max_duration)
+
+        return [
+            dict(
+                word=word, tokens=tokens, start=start, end=end, probability=probability
+            )
+            for word, tokens, start, end, probability in zip(
+                words, word_tokens, start_times, end_times, word_probabilities
+            )
+        ]
 
 
 def get_input(segment):
     segment = np.ascontiguousarray(segment)
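`find_alignment` converts the text side of the alignment path into word times by looking at the steps where the text index jumps forward; the sketch below replays that bookkeeping on a made-up path (token counts, times, and the trailing end-of-transcript token are all invented for illustration):

```python
import numpy as np

tokens_per_second = 50  # assumed, consistent with the 0.02 s time precision

# Made-up DTW path as (text_index, time_index) pairs; five text positions,
# the last one standing in for the appended end-of-transcript token.
text_indices = np.array([0, 0, 1, 2, 2, 3, 4])
time_indices = np.array([0, 1, 2, 3, 4, 5, 6])

# A jump is a step where the path advances to the next text token; the time of
# that step is taken as the moment the token starts.
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] / tokens_per_second

# Pretend two words of two tokens each, as split_to_word_tokens might return.
word_boundaries = np.pad(np.cumsum([2, 2]), (1, 0))
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
print(start_times, end_times)  # [0.   0.06] [0.06 0.12]
```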
@@ -459,3 +645,37 @@ def get_input(segment):
 def get_compression_ratio(text):
     text_bytes = text.encode("utf-8")
     return len(text_bytes) / len(zlib.compress(text_bytes))
+
+
+def merge_punctuations(alignment: List[dict], prepended: str, appended: str):
+    # merge prepended punctuations
+    i = len(alignment) - 2
+    j = len(alignment) - 1
+    while i >= 0:
+        previous = alignment[i]
+        following = alignment[j]
+        if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
+            # prepend it to the following word
+            following["word"] = previous["word"] + following["word"]
+            following["tokens"] = previous["tokens"] + following["tokens"]
+            previous["word"] = ""
+            previous["tokens"] = []
+        else:
+            j = i
+        i -= 1
+
+    # merge appended punctuations
+    i = 0
+    j = 1
+    while j < len(alignment):
+        previous = alignment[i]
+        following = alignment[j]
+        if not previous["word"].endswith(" ") and following["word"] in appended:
+            # append it to the previous word
+            previous["word"] = previous["word"] + following["word"]
+            previous["tokens"] = previous["tokens"] + following["tokens"]
+            following["word"] = ""
+            following["tokens"] = []
+        else:
+            i = j
+        j += 1

requirements.txt
@@ -1,3 +1,3 @@
 av==10.*
-ctranslate2>=3.8,<4
+ctranslate2>=3.9,<4
 tokenizers==0.13.*