Fix alignment between the segments and the list of words (#1087)

* Fix alignment between the segments and the list of words

* Ensure the word index does not overflow
Guillaume Klein
2023-03-14 00:34:09 +01:00
committed by GitHub
parent 839639a223
commit 671ac5a4ce

@@ -1,3 +1,4 @@
+import itertools
 import subprocess
 import warnings
 from dataclasses import dataclass
@@ -290,34 +291,41 @@ def add_word_timestamps(
     if len(segments) == 0:
         return
 
-    text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
+    text_tokens_per_segment = [
+        [token for token in segment["tokens"] if token < tokenizer.eot]
+        for segment in segments
+    ]
+
+    text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
 
     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
-    segment_lengths = [len(s["tokens"]) for s in segments]
-    token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
+    word_index = 0
 
-    for segment in segments:
-        segment["words"] = []
-
-    word_boundaries = np.pad(np.cumsum([len(w.tokens) for w in alignment]), (1, 0))
-    for i, timing in enumerate(alignment):
-        if timing.word:
-            segment = segments[token_sources[word_boundaries[i]]]
-            start = round(time_offset + timing.start, 2)
-            end = round(time_offset + timing.end, 2)
-            segment["words"].append(
-                dict(
-                    word=timing.word,
-                    start=start,
-                    end=end,
-                    probability=timing.probability,
-                )
-            )
+    for segment, text_tokens in zip(segments, text_tokens_per_segment):
+        saved_tokens = 0
+        words = []
+
+        while word_index < len(alignment) and saved_tokens < len(text_tokens):
+            timing = alignment[word_index]
+
+            if timing.word:
+                words.append(
+                    dict(
+                        word=timing.word,
+                        start=round(time_offset + timing.start, 2),
+                        end=round(time_offset + timing.end, 2),
+                        probability=timing.probability,
+                    )
+                )
+
+            saved_tokens += len(timing.tokens)
+            word_index += 1
 
-    for segment in segments:
-        if len(words := segment["words"]) > 0:
+        if len(words) > 0:
             # adjust the segment-level timestamps based on the word-level timestamps
             segment["start"] = words[0]["start"]
             segment["end"] = words[-1]["end"]
+
+        segment["words"] = words