Fix all_tokens handling

See 38f2f4d99d
This commit is contained in:
Guillaume Klein
2023-03-09 10:02:25 +01:00
parent 4176da0d68
commit 6a84df400f

View File

@@ -286,6 +286,14 @@ class WhisperModel:
tokens = result.sequences_ids[0] tokens = result.sequences_ids[0]
current_segments = []
single_timestamp_ending = (
len(tokens) >= 2
and tokens[-2] < self.timestamp_begin_id
and tokens[-1] >= self.timestamp_begin_id
)
consecutive_timestamps = [ consecutive_timestamps = [
i i
for i in range(len(tokens)) for i in range(len(tokens))
@@ -295,17 +303,12 @@ class WhisperModel:
] ]
if len(consecutive_timestamps) > 0: if len(consecutive_timestamps) > 0:
ended_with_single_timestamp = ( slices = list(consecutive_timestamps)
len(tokens) >= 2 if single_timestamp_ending:
and tokens[-2] < self.timestamp_begin_id slices.append(len(tokens))
and tokens[-1] >= self.timestamp_begin_id
)
if ended_with_single_timestamp:
consecutive_timestamps.append(len(tokens))
last_slice = 0 last_slice = 0
for i, current_slice in enumerate(consecutive_timestamps): for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice] sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_position = ( start_timestamp_position = (
sliced_tokens[0] - self.timestamp_begin_id sliced_tokens[0] - self.timestamp_begin_id
@@ -318,10 +321,10 @@ class WhisperModel:
time_offset + end_timestamp_position * self.time_precision time_offset + end_timestamp_position * self.time_precision
) )
yield start_time, end_time, sliced_tokens current_segments.append((start_time, end_time, sliced_tokens))
last_slice = current_slice last_slice = current_slice
if ended_with_single_timestamp: if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp. # single timestamp at the end means no speech after the last timestamp.
seek += segment_size seek += segment_size
else: else:
@@ -331,8 +334,6 @@ class WhisperModel:
) )
seek += last_timestamp_position * self.input_stride seek += last_timestamp_position * self.input_stride
all_tokens.extend(tokens[: last_slice + 1])
else: else:
duration = segment_duration duration = segment_duration
timestamps = [ timestamps = [
@@ -342,14 +343,17 @@ class WhisperModel:
last_timestamp_position = timestamps[-1] - self.timestamp_begin_id last_timestamp_position = timestamps[-1] - self.timestamp_begin_id
duration = last_timestamp_position * self.time_precision duration = last_timestamp_position * self.time_precision
yield time_offset, time_offset + duration, tokens current_segments.append((time_offset, time_offset + duration, tokens))
seek += segment_size seek += segment_size
all_tokens.extend(tokens)
if not options.condition_on_previous_text or temperature > 0.5: if not options.condition_on_previous_text or temperature > 0.5:
prompt_reset_since = len(all_tokens) prompt_reset_since = len(all_tokens)
for start, end, tokens in current_segments:
yield start, end, tokens
all_tokens.extend(tokens)
def encode_text(self, text): def encode_text(self, text):
return self.tokenizer.encode(text, add_special_tokens=False).ids return self.tokenizer.encode(text, add_special_tokens=False).ids