Use a dict to represent intermediate segments

This commit is contained in:
Guillaume Klein
2023-03-09 11:53:55 +01:00
parent 6a84df400f
commit f0a21ea916

View File

@@ -223,20 +223,6 @@ class WhisperModel:
return segments, audio_info
def generate_segments(self, features, options):
tokenized_segments = self.generate_tokenized_segments(features, options)
for start, end, tokens in tokenized_segments:
text = self.decode_text_tokens(tokens)
if not text.strip():
continue
yield Segment(
start=start,
end=end,
text=text,
)
def generate_tokenized_segments(self, features, options):
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
seek = 0
all_tokens = []
@@ -321,7 +307,9 @@ class WhisperModel:
time_offset + end_timestamp_position * self.time_precision
)
current_segments.append((start_time, end_time, sliced_tokens))
current_segments.append(
dict(start=start_time, end=end_time, tokens=sliced_tokens)
)
last_slice = current_slice
if single_timestamp_ending:
@@ -343,17 +331,29 @@ class WhisperModel:
last_timestamp_position = timestamps[-1] - self.timestamp_begin_id
duration = last_timestamp_position * self.time_precision
current_segments.append((time_offset, time_offset + duration, tokens))
current_segments.append(
dict(start=time_offset, end=time_offset + duration, tokens=tokens)
)
seek += segment_size
if not options.condition_on_previous_text or temperature > 0.5:
prompt_reset_since = len(all_tokens)
for start, end, tokens in current_segments:
yield start, end, tokens
for segment in current_segments:
tokens = segment["tokens"]
all_tokens.extend(tokens)
text = self.decode_text_tokens(tokens)
if not text.strip():
continue
yield Segment(
start=segment["start"],
end=segment["end"],
text=text,
)
def encode_text(self, text):
return self.tokenizer.encode(text, add_special_tokens=False).ids