Rename `offset` to `seek` to match the OpenAI implementation

This commit is contained in:
Guillaume Klein
2023-03-09 09:58:58 +01:00
parent 6b16b8a69c
commit 4176da0d68

View File

@@ -238,7 +238,7 @@ class WhisperModel:
def generate_tokenized_segments(self, features, options): def generate_tokenized_segments(self, features, options):
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
offset = 0 seek = 0
all_tokens = [] all_tokens = []
prompt_reset_since = 0 prompt_reset_since = 0
@@ -247,13 +247,11 @@ class WhisperModel:
initial_prompt_tokens = self.encode_text(initial_prompt) initial_prompt_tokens = self.encode_text(initial_prompt)
all_tokens.extend(initial_prompt_tokens) all_tokens.extend(initial_prompt_tokens)
while offset < content_frames: while seek < content_frames:
time_offset = offset * self.feature_extractor.time_per_frame time_offset = seek * self.feature_extractor.time_per_frame
segment = features[ segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]
:, offset : offset + self.feature_extractor.nb_max_frames
]
segment_size = min( segment_size = min(
self.feature_extractor.nb_max_frames, content_frames - offset self.feature_extractor.nb_max_frames, content_frames - seek
) )
segment_duration = segment_size * self.feature_extractor.time_per_frame segment_duration = segment_size * self.feature_extractor.time_per_frame
@@ -283,7 +281,7 @@ class WhisperModel:
if should_skip: if should_skip:
# fast-forward to the next segment boundary # fast-forward to the next segment boundary
offset += segment_size seek += segment_size
continue continue
tokens = result.sequences_ids[0] tokens = result.sequences_ids[0]
@@ -325,13 +323,13 @@ class WhisperModel:
if ended_with_single_timestamp: if ended_with_single_timestamp:
# single timestamp at the end means no speech after the last timestamp. # single timestamp at the end means no speech after the last timestamp.
offset += segment_size seek += segment_size
else: else:
# otherwise, ignore the unfinished segment and seek to the last timestamp # otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_position = ( last_timestamp_position = (
tokens[last_slice - 1] - self.timestamp_begin_id tokens[last_slice - 1] - self.timestamp_begin_id
) )
offset += last_timestamp_position * self.input_stride seek += last_timestamp_position * self.input_stride
all_tokens.extend(tokens[: last_slice + 1]) all_tokens.extend(tokens[: last_slice + 1])
@@ -346,7 +344,7 @@ class WhisperModel:
yield time_offset, time_offset + duration, tokens yield time_offset, time_offset + duration, tokens
offset += segment_size seek += segment_size
all_tokens.extend(tokens) all_tokens.extend(tokens)
if not options.condition_on_previous_text or temperature > 0.5: if not options.condition_on_previous_text or temperature > 0.5: