diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 32eeeb2..beede24 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -286,6 +286,14 @@ class WhisperModel: tokens = result.sequences_ids[0] + current_segments = [] + + single_timestamp_ending = ( + len(tokens) >= 2 + and tokens[-2] < self.timestamp_begin_id + and tokens[-1] >= self.timestamp_begin_id + ) + consecutive_timestamps = [ i for i in range(len(tokens)) @@ -295,17 +303,12 @@ class WhisperModel: ] if len(consecutive_timestamps) > 0: - ended_with_single_timestamp = ( - len(tokens) >= 2 - and tokens[-2] < self.timestamp_begin_id - and tokens[-1] >= self.timestamp_begin_id - ) - - if ended_with_single_timestamp: - consecutive_timestamps.append(len(tokens)) + slices = list(consecutive_timestamps) + if single_timestamp_ending: + slices.append(len(tokens)) last_slice = 0 - for i, current_slice in enumerate(consecutive_timestamps): + for current_slice in slices: sliced_tokens = tokens[last_slice:current_slice] start_timestamp_position = ( sliced_tokens[0] - self.timestamp_begin_id @@ -318,10 +321,10 @@ class WhisperModel: time_offset + end_timestamp_position * self.time_precision ) - yield start_time, end_time, sliced_tokens + current_segments.append((start_time, end_time, sliced_tokens)) last_slice = current_slice - if ended_with_single_timestamp: + if single_timestamp_ending: # single timestamp at the end means no speech after the last timestamp. seek += segment_size else: @@ -331,8 +334,6 @@ class WhisperModel: ) seek += last_timestamp_position * self.input_stride - all_tokens.extend(tokens[: last_slice + 1]) - else: duration = segment_duration timestamps = [ @@ -342,14 +343,17 @@ class WhisperModel: last_timestamp_position = timestamps[-1] - self.timestamp_begin_id duration = last_timestamp_position * self.time_precision - yield time_offset, time_offset + duration, tokens + current_segments.append((time_offset, time_offset + duration, tokens)) seek += segment_size - all_tokens.extend(tokens) if not options.condition_on_previous_text or temperature > 0.5: prompt_reset_since = len(all_tokens) + for start, end, tokens in current_segments: + yield start, end, tokens + all_tokens.extend(tokens) + def encode_text(self, text): return self.tokenizer.encode(text, add_special_tokens=False).ids