diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 5da7048..39d25d5 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -28,12 +28,17 @@ class Word(NamedTuple):


 class Segment(NamedTuple):
+    id: int
+    seek: int
     start: float
     end: float
     text: str
-    words: Optional[List[Word]]
-    avg_log_prob: float
+    tokens: List[int]
+    temperature: float
+    avg_logprob: float
+    compression_ratio: float
     no_speech_prob: float
+    words: Optional[List[Word]]


 class TranscriptionOptions(NamedTuple):
@@ -335,6 +340,7 @@ class WhisperModel:
         encoder_output: Optional[ctranslate2.StorageView] = None,
     ) -> Iterable[Segment]:
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
+        idx = 0
         seek = 0
         all_tokens = []
         prompt_reset_since = 0
@@ -368,9 +374,12 @@ class WhisperModel:
             if encoder_output is None:
                 encoder_output = self.encode(segment)

-            result, avg_log_prob, temperature = self.generate_with_fallback(
-                encoder_output, prompt, tokenizer, options
-            )
+            (
+                result,
+                avg_logprob,
+                temperature,
+                compression_ratio,
+            ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)

             if options.no_speech_threshold is not None:
                 # no voice activity check
@@ -378,7 +387,7 @@ class WhisperModel:

                 if (
                     options.log_prob_threshold is not None
-                    and avg_log_prob > options.log_prob_threshold
+                    and avg_logprob > options.log_prob_threshold
                 ):
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     should_skip = False
@@ -509,18 +518,24 @@ class WhisperModel:
                     continue

                 all_tokens.extend(tokens)
+                idx += 1

                 yield Segment(
+                    id=idx,
+                    seek=seek,
                     start=segment["start"],
                     end=segment["end"],
                     text=text,
+                    tokens=tokens,
+                    temperature=temperature,
+                    avg_logprob=avg_logprob,
+                    compression_ratio=compression_ratio,
+                    no_speech_prob=result.no_speech_prob,
                     words=(
                         [Word(**word) for word in segment["words"]]
                         if options.word_timestamps
                         else None
                     ),
-                    avg_log_prob=avg_log_prob,
-                    no_speech_prob=result.no_speech_prob,
                 )

     def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
@@ -539,10 +554,11 @@ class WhisperModel:
         prompt: List[int],
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
-    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
+    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
         result = None
-        avg_log_prob = None
+        avg_logprob = None
         final_temperature = None
+        compression_ratio = None
         max_initial_timestamp_index = int(
             round(options.max_initial_timestamp / self.time_precision)
         )
@@ -580,8 +596,8 @@ class WhisperModel:

             # Recover the average log prob from the returned score.
             seq_len = len(tokens)
-            cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
-            avg_log_prob = cum_log_prob / (seq_len + 1)
+            cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
+            avg_logprob = cum_logprob / (seq_len + 1)

             text = tokenizer.decode(tokens).strip()
             compression_ratio = get_compression_ratio(text)
@@ -603,21 +619,21 @@ class WhisperModel:

             if (
                 options.log_prob_threshold is not None
-                and avg_log_prob < options.log_prob_threshold
+                and avg_logprob < options.log_prob_threshold
             ):
                 needs_fallback = True  # average log probability is too low

                 self.logger.debug(
                     "Log probability threshold is not met with temperature %.1f (%f < %f)",
                     temperature,
-                    avg_log_prob,
+                    avg_logprob,
                     options.log_prob_threshold,
                 )

             if not needs_fallback:
                 break

-        return result, avg_log_prob, final_temperature
+        return result, avg_logprob, final_temperature, compression_ratio

     def get_prompt(
         self,
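
A minimal consumer-side sketch (not part of the patch) of the enriched Segment this diff produces. It assumes the public faster-whisper API (WhisperModel.transcribe); the model size "base" and "audio.wav" are placeholder inputs.

# Minimal sketch, not part of the patch: reading the Segment fields added above.
# "base" and "audio.wav" are placeholders for a real model size and audio file.
from faster_whisper import WhisperModel

model = WhisperModel("base")
segments, info = model.transcribe("audio.wav")

for segment in segments:
    # id/seek locate the segment in the output and in the feature frames;
    # tokens, temperature, avg_logprob, and compression_ratio describe how it
    # was decoded, mirroring the per-segment dicts produced by openai-whisper.
    print(
        f"[{segment.id}] {segment.start:.2f}s -> {segment.end:.2f}s "
        f"avg_logprob={segment.avg_logprob:.3f} "
        f"compression_ratio={segment.compression_ratio:.2f} "
        f"no_speech_prob={segment.no_speech_prob:.2f}:{segment.text}"
    )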