Do not ignore last segment ending with one timestamp

See eab8d920ed
This commit is contained in:
Guillaume Klein
2023-03-07 10:05:04 +01:00
parent 469244a57d
commit 01ef12a6a0

View File

@@ -292,6 +292,15 @@ class WhisperModel:
]
if len(consecutive_timestamps) > 0:
ended_with_single_timestamp = (
len(tokens) >= 2
and tokens[-2] < self.timestamp_begin_id
and tokens[-1] >= self.timestamp_begin_id
)
if ended_with_single_timestamp:
consecutive_timestamps.append(len(tokens))
last_slice = 0
for i, current_slice in enumerate(consecutive_timestamps):
sliced_tokens = tokens[last_slice:current_slice]
@@ -306,19 +315,19 @@ class WhisperModel:
time_offset + end_timestamp_position * self.time_precision
)
last_in_window = i + 1 == len(consecutive_timestamps)
# Include the last timestamp so that all tokens are included in a segment.
if last_in_window:
sliced_tokens.append(tokens[current_slice])
yield start_time, end_time, sliced_tokens
last_slice = current_slice
last_timestamp_position = (
tokens[last_slice - 1] - self.timestamp_begin_id
)
offset += last_timestamp_position * self.input_stride
if ended_with_single_timestamp:
# single timestamp at the end means no speech after the last timestamp.
offset += segment.shape[-1]
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_position = (
tokens[last_slice - 1] - self.timestamp_begin_id
)
offset += last_timestamp_position * self.input_stride
all_tokens.extend(tokens[: last_slice + 1])
else: