diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index beede24..b149792 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -223,20 +223,6 @@ class WhisperModel: return segments, audio_info def generate_segments(self, features, options): - tokenized_segments = self.generate_tokenized_segments(features, options) - - for start, end, tokens in tokenized_segments: - text = self.decode_text_tokens(tokens) - if not text.strip(): - continue - - yield Segment( - start=start, - end=end, - text=text, - ) - - def generate_tokenized_segments(self, features, options): content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames seek = 0 all_tokens = [] @@ -321,7 +307,9 @@ class WhisperModel: time_offset + end_timestamp_position * self.time_precision ) - current_segments.append((start_time, end_time, sliced_tokens)) + current_segments.append( + dict(start=start_time, end=end_time, tokens=sliced_tokens) + ) last_slice = current_slice if single_timestamp_ending: @@ -343,17 +331,29 @@ class WhisperModel: last_timestamp_position = timestamps[-1] - self.timestamp_begin_id duration = last_timestamp_position * self.time_precision - current_segments.append((time_offset, time_offset + duration, tokens)) + current_segments.append( + dict(start=time_offset, end=time_offset + duration, tokens=tokens) + ) seek += segment_size if not options.condition_on_previous_text or temperature > 0.5: prompt_reset_since = len(all_tokens) - for start, end, tokens in current_segments: - yield start, end, tokens + for segment in current_segments: + tokens = segment["tokens"] all_tokens.extend(tokens) + text = self.decode_text_tokens(tokens) + if not text.strip(): + continue + + yield Segment( + start=segment["start"], + end=segment["end"], + text=text, + ) + def encode_text(self, text): return self.tokenizer.encode(text, add_special_tokens=False).ids