diff --git a/README.md b/README.md
index 07a5005..e9d3fa6 100644
--- a/README.md
+++ b/README.md
@@ -99,6 +99,16 @@ for segment in segments:
     print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
 ```
 
+#### Word-level timestamps
+
+```python
+segments, _ = model.transcribe("audio.mp3", word_timestamps=True)
+
+for segment in segments:
+    for word in segment.words:
+        print("[%.2fs -> %.2fs] %s" % (word.start, word.end, word.word))
+```
+
 See more model and transcription options in the [`WhisperModel`](https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py) class implementation.
 
 ## Comparing performance against other implementations
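Note (not part of the diff): each entry of `segment.words` is the new `Word` namedtuple added in `transcribe.py` below, so the per-word probability is available alongside the timestamps. A minimal sketch building on the README example; the model path is hypothetical and should point to a converted model directory, as elsewhere in the README:

```python
from faster_whisper import WhisperModel

# Hypothetical path to a converted CTranslate2 Whisper model (adjust to your setup).
model = WhisperModel("whisper-base-ct2/")

segments, _ = model.transcribe("audio.mp3", word_timestamps=True)

for segment in segments:
    for word in segment.words:
        # Word is a namedtuple: (start, end, word, probability).
        print(
            "[%.2fs -> %.2fs] %s (probability %.2f)"
            % (word.start, word.end, word.word, word.probability)
        )
```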
diff --git a/faster_whisper/tokenizer.py b/faster_whisper/tokenizer.py
index 32c5a39..bcc7e6c 100644
--- a/faster_whisper/tokenizer.py
+++ b/faster_whisper/tokenizer.py
@@ -1,5 +1,7 @@
+import string
+
 from functools import cached_property
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import tokenizers
 
@@ -21,6 +23,7 @@ class Tokenizer:
             if self.task is None:
                 raise ValueError("%s is not a valid task" % task)
 
+            self.language_code = language
             self.language = self.tokenizer.token_to_id("<|%s|>" % language)
             if self.language is None:
                 raise ValueError("%s is not a valid language code" % language)
@@ -67,3 +70,76 @@ class Tokenizer:
     def decode(self, tokens: List[int]) -> str:
         text_tokens = [token for token in tokens if token < self.eot]
         return self.tokenizer.decode(text_tokens)
+
+    def decode_with_timestamps(self, tokens: List[int]) -> str:
+        outputs = [[]]
+
+        for token in tokens:
+            if token >= self.timestamp_begin:
+                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
+                outputs.append(timestamp)
+                outputs.append([])
+            else:
+                outputs[-1].append(token)
+
+        return "".join(
+            [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
+        )
+
+    def split_to_word_tokens(
+        self, tokens: List[int]
+    ) -> Tuple[List[str], List[List[int]]]:
+        if self.language_code in {"zh", "ja", "th", "lo", "my"}:
+            # These languages don't typically use spaces, so it is difficult to split words
+            # without morpheme analysis. Here, we instead split words at any
+            # position where the tokens are decoded as valid unicode points
+            return self.split_tokens_on_unicode(tokens)
+
+        return self.split_tokens_on_spaces(tokens)
+
+    def split_tokens_on_unicode(
+        self, tokens: List[int]
+    ) -> Tuple[List[str], List[List[int]]]:
+        decoded_full = self.decode_with_timestamps(tokens)
+        replacement_char = "\ufffd"
+
+        words = []
+        word_tokens = []
+        current_tokens = []
+        unicode_offset = 0
+
+        for token in tokens:
+            current_tokens.append(token)
+            decoded = self.decode_with_timestamps(current_tokens)
+
+            if (
+                replacement_char not in decoded
+                or decoded_full[unicode_offset + decoded.index(replacement_char)]
+                == replacement_char
+            ):
+                words.append(decoded)
+                word_tokens.append(current_tokens)
+                current_tokens = []
+                unicode_offset += len(decoded)
+
+        return words, word_tokens
+
+    def split_tokens_on_spaces(
+        self, tokens: List[int]
+    ) -> Tuple[List[str], List[List[int]]]:
+        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
+        words = []
+        word_tokens = []
+
+        for subword, subword_tokens in zip(subwords, subword_tokens_list):
+            special = subword_tokens[0] >= self.eot
+            with_space = subword.startswith(" ")
+            punctuation = subword.strip() in string.punctuation
+            if special or with_space or punctuation or len(words) == 0:
+                words.append(subword)
+                word_tokens.append(subword_tokens)
+            else:
+                words[-1] = words[-1] + subword
+                word_tokens[-1].extend(subword_tokens)
+
+        return words, word_tokens
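Note (not part of the diff): the grouping rule in `split_tokens_on_spaces` starts a new word whenever a subword is a special token, begins with a space, or is bare punctuation, and otherwise glues the subword onto the previous word. A simplified standalone sketch of that rule, operating on already-decoded subword strings rather than real token IDs (so no tokenizer or special tokens are involved):

```python
import string


def group_subwords(subwords):
    # Simplified version of the rule in Tokenizer.split_tokens_on_spaces: start a
    # new word on a leading space or bare punctuation, otherwise append the
    # subword to the previous word. Special tokens are left out of this toy version.
    words = []
    for subword in subwords:
        starts_new_word = (
            subword.startswith(" ") or subword.strip() in string.punctuation
        )
        if starts_new_word or not words:
            words.append(subword)
        else:
            words[-1] += subword
    return words


print(group_subwords([" Hello", " wor", "ld", "!", " I", "'m", " transcri", "bing"]))
# [' Hello', ' world', '!', " I'm", ' transcribing']
```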
diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 419111b..dba3402 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -1,4 +1,5 @@
 import collections
+import itertools
 import os
 import zlib
 
@@ -13,7 +14,11 @@ from faster_whisper.feature_extractor import FeatureExtractor
 from faster_whisper.tokenizer import Tokenizer
 
 
-class Segment(collections.namedtuple("Segment", ("start", "end", "text"))):
+class Segment(collections.namedtuple("Segment", ("start", "end", "text", "words"))):
+    pass
+
+
+class Word(collections.namedtuple("Word", ("start", "end", "word", "probability"))):
     pass
 
 
@@ -42,6 +47,9 @@ class TranscriptionOptions(
             "suppress_tokens",
             "without_timestamps",
             "max_initial_timestamp",
+            "word_timestamps",
+            "prepend_punctuations",
+            "append_punctuations",
         ),
     )
 ):
@@ -94,6 +102,13 @@ class WhisperModel:
         )
 
         self.feature_extractor = FeatureExtractor()
+        self.num_samples_per_token = self.feature_extractor.hop_length * 2
+        self.frames_per_second = (
+            self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
+        )
+        self.tokens_per_second = (
+            self.feature_extractor.sampling_rate // self.num_samples_per_token
+        )
         self.input_stride = 2
         self.time_precision = 0.02
         self.max_length = 448
@@ -125,6 +140,9 @@ class WhisperModel:
         suppress_tokens: Optional[List[int]] = [-1],
         without_timestamps: bool = False,
         max_initial_timestamp: float = 1.0,
+        word_timestamps: bool = False,
+        prepend_punctuations: str = "\"'“¿([{-",
+        append_punctuations: str = "\"'.。,,!!??::”)]}、",
     ):
         """Transcribes an input file.
 
@@ -159,6 +177,12 @@
             of symbols as defined in the model config.json file.
           without_timestamps: Only sample text tokens.
           max_initial_timestamp: The initial timestamp cannot be later than this.
+          word_timestamps: Extract word-level timestamps using the cross-attention pattern
+            and dynamic time warping, and include the timestamps for each word in each segment.
+          prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
+            with the next word
+          append_punctuations: If word_timestamps is True, merge these punctuation symbols
+            with the previous word
 
         Returns:
           A tuple with:
@@ -211,6 +235,9 @@ class WhisperModel:
             suppress_tokens=suppress_tokens,
             without_timestamps=without_timestamps,
             max_initial_timestamp=max_initial_timestamp,
+            word_timestamps=word_timestamps,
+            prepend_punctuations=prepend_punctuations,
+            append_punctuations=append_punctuations,
         )
 
         segments = self.generate_segments(features, tokenizer, options)
@@ -271,6 +298,7 @@
 
             tokens = result.sequences_ids[0]
 
+            previous_seek = seek
             current_segments = []
 
             single_timestamp_ending = (
@@ -309,7 +337,12 @@
                     )
 
                     current_segments.append(
-                        dict(start=start_time, end=end_time, tokens=sliced_tokens)
+                        dict(
+                            seek=seek,
+                            start=start_time,
+                            end=end_time,
+                            tokens=sliced_tokens,
+                        )
                     )
 
                     last_slice = current_slice
@@ -333,7 +366,12 @@
                     duration = last_timestamp_position * self.time_precision
 
                 current_segments.append(
-                    dict(start=time_offset, end=time_offset + duration, tokens=tokens)
+                    dict(
+                        seek=seek,
+                        start=time_offset,
+                        end=time_offset + duration,
+                        tokens=tokens,
+                    )
                 )
 
                 seek += segment_size
@@ -341,18 +379,46 @@
             if not options.condition_on_previous_text or temperature > 0.5:
                 prompt_reset_since = len(all_tokens)
 
+            if options.word_timestamps:
+                self.add_word_timestamps(
+                    current_segments,
+                    tokenizer,
+                    segment,
+                    segment_size,
+                    options.prepend_punctuations,
+                    options.append_punctuations,
+                )
+
+                word_end_timestamps = [
+                    w["end"] for s in current_segments for w in s["words"]
+                ]
+
+                if not single_timestamp_ending and len(word_end_timestamps) > 0:
+                    seek_shift = round(
+                        (word_end_timestamps[-1] - time_offset) * self.frames_per_second
+                    )
+
+                    if seek_shift > 0:
+                        seek = previous_seek + seek_shift
+
             for segment in current_segments:
                 tokens = segment["tokens"]
-                all_tokens.extend(tokens)
-
                 text = tokenizer.decode(tokens)
-                if not text.strip():
+
+                if segment["start"] == segment["end"] or not text.strip():
                     continue
 
+                all_tokens.extend(tokens)
+
                 yield Segment(
                     start=segment["start"],
                     end=segment["end"],
                     text=text,
+                    words=(
+                        [Word(**word) for word in segment["words"]]
+                        if options.word_timestamps
+                        else None
+                    ),
                 )
 
     def generate_with_fallback(self, segment, prompt, tokenizer, options):
@@ -448,6 +514,126 @@ class WhisperModel:
 
         return prompt
 
+    def add_word_timestamps(
+        self,
+        segments: List[dict],
+        tokenizer: Tokenizer,
+        mel: np.ndarray,
+        num_frames: int,
+        prepend_punctuations: str,
+        append_punctuations: str,
+    ):
+        if len(segments) == 0:
+            return
+
+        text_tokens_per_segment = [
+            [token for token in segment["tokens"] if token < tokenizer.eot]
+            for segment in segments
+        ]
+
+        text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
+        alignment = self.find_alignment(tokenizer, text_tokens, mel, num_frames)
+        merge_punctuations(alignment, prepend_punctuations, append_punctuations)
+
+        time_offset = (
+            segments[0]["seek"]
+            * self.feature_extractor.hop_length
+            / self.feature_extractor.sampling_rate
+        )
+
+        word_index = 0
+
+        for segment, text_tokens in zip(segments, text_tokens_per_segment):
+            saved_tokens = 0
+            words = []
+
+            while word_index < len(alignment) and saved_tokens < len(text_tokens):
+                timing = alignment[word_index]
+
+                if timing["word"]:
+                    words.append(
+                        dict(
+                            word=timing["word"],
+                            start=round(time_offset + timing["start"], 2),
+                            end=round(time_offset + timing["end"], 2),
timing["end"], 2), + probability=timing["probability"], + ) + ) + + saved_tokens += len(timing["tokens"]) + word_index += 1 + + if len(words) > 0: + # adjust the segment-level timestamps based on the word-level timestamps + segment["start"] = words[0]["start"] + segment["end"] = words[-1]["end"] + + segment["words"] = words + + def find_alignment( + self, + tokenizer: Tokenizer, + text_tokens: List[int], + mel: np.ndarray, + num_frames: int, + median_filter_width: int = 7, + ) -> List[dict]: + if len(text_tokens) == 0: + return [] + + result = self.model.align( + get_input(mel), + tokenizer.sot_sequence, + [text_tokens], + num_frames, + median_filter_width=median_filter_width, + )[0] + + text_token_probs = result.text_token_probs + + alignments = result.alignments + text_indices = np.array([pair[0] for pair in alignments]) + time_indices = np.array([pair[1] for pair in alignments]) + + words, word_tokens = tokenizer.split_to_word_tokens( + text_tokens + [tokenizer.eot] + ) + word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) + + jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) + jump_times = time_indices[jumps] / self.tokens_per_second + start_times = jump_times[word_boundaries[:-1]] + end_times = jump_times[word_boundaries[1:]] + word_probabilities = [ + np.mean(text_token_probs[i:j]) + for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) + ] + + # hack: ensure the first and second word is not longer than twice the median word duration. + # a better segmentation algorithm based on VAD should be able to replace this. + word_durations = end_times - start_times + word_durations = word_durations[word_durations.nonzero()] + if len(word_durations) > 0: + median_duration = np.median(word_durations) + max_duration = median_duration * 2 + if len(word_durations) >= 2 and word_durations[1] > max_duration: + boundary = max(end_times[2] / 2, end_times[2] - max_duration) + end_times[0] = start_times[1] = boundary + if ( + len(word_durations) >= 1 + and end_times[0] - start_times[0] > max_duration + ): + start_times[0] = max(0, end_times[0] - max_duration) + + return [ + dict( + word=word, tokens=tokens, start=start, end=end, probability=probability + ) + for word, tokens, start, end, probability in zip( + words, word_tokens, start_times, end_times, word_probabilities + ) + ] + def get_input(segment): segment = np.ascontiguousarray(segment) @@ -459,3 +645,37 @@ def get_input(segment): def get_compression_ratio(text): text_bytes = text.encode("utf-8") return len(text_bytes) / len(zlib.compress(text_bytes)) + + +def merge_punctuations(alignment: List[dict], prepended: str, appended: str): + # merge prepended punctuations + i = len(alignment) - 2 + j = len(alignment) - 1 + while i >= 0: + previous = alignment[i] + following = alignment[j] + if previous["word"].startswith(" ") and previous["word"].strip() in prepended: + # prepend it to the following word + following["word"] = previous.word + following.word + following["tokens"] = previous.tokens + following.tokens + previous["word"] = "" + previous["tokens"] = [] + else: + j = i + i -= 1 + + # merge appended punctuations + i = 0 + j = 1 + while j < len(alignment): + previous = alignment[i] + following = alignment[j] + if not previous["word"].endswith(" ") and following["word"] in appended: + # append it to the previous word + previous["word"] = previous["word"] + following["word"] + previous["tokens"] = previous["tokens"] + following["tokens"] + following["word"] = "" + following["tokens"] = 
diff --git a/requirements.txt b/requirements.txt
index 3f1fdfd..4f981a0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
 av==10.*
-ctranslate2>=3.8,<4
+ctranslate2>=3.9,<4
 tokenizers==0.13.*
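Note (not part of the diff): a small illustration of how the new `merge_punctuations` helper folds punctuation into neighbouring words before the word timings are assigned; it assumes a faster-whisper build that already includes the change above, and the alignment entries and token IDs are made up for the example:

```python
from faster_whisper.transcribe import merge_punctuations

# Minimal fake alignment entries; only the "word" and "tokens" keys matter here,
# and the token IDs are invented for illustration.
alignment = [
    {"word": " Hello", "tokens": [1]},
    {"word": ",", "tokens": [2]},
    {"word": " world", "tokens": [3]},
    {"word": "!", "tokens": [4]},
]

merge_punctuations(alignment, prepended="\"'“¿([{-", appended="\"'.。,,!!??::”)]}、")

# The comma and exclamation mark are folded into the preceding words; the emptied
# entries are skipped later by add_word_timestamps.
print([entry["word"] for entry in alignment if entry["word"]])
# [' Hello,', ' world!']
```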