diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index d3d5deb..1c002ed 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -220,6 +220,8 @@ class WhisperModel: chunk_length: Optional[int] = None, clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + language_detection_threshold: Optional[float] = None, + language_detection_segments: int = 1, ) -> Tuple[Iterable[Segment], TranscriptionInfo]: """Transcribes an input file. @@ -281,6 +283,9 @@ class WhisperModel: hallucination_silence_threshold: Optional[float] When word_timestamps is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected + language_detection_threshold: If the maximum probability of the language tokens is higher + than this value, the language is detected. + language_detection_segments: Number of segments to consider for the language detection. Returns: A tuple with: @@ -340,15 +345,51 @@ class WhisperModel: language = "en" language_probability = 1 else: - segment = features[:, : self.feature_extractor.nb_max_frames] - encoder_output = self.encode(segment) - # results is a list of tuple[str, float] with language names and - # probabilities. - results = self.model.detect_language(encoder_output)[0] - # Parse language names to strip out markers - all_language_probs = [(token[2:-2], prob) for (token, prob) in results] - # Get top language token and probability - language, language_probability = all_language_probs[0] + if ( + language_detection_segments is None + or language_detection_segments < 1 + ): + language_detection_segments = 1 + seek = 0 + detected_language_info = {} + content_frames = ( + features.shape[-1] - self.feature_extractor.nb_max_frames + ) + while ( + seek < content_frames + and seek + < self.feature_extractor.nb_max_frames * language_detection_segments + ): + segment = features[ + :, seek : seek + self.feature_extractor.nb_max_frames + ] + encoder_output = self.encode(segment) + # results is a list of tuple[str, float] with language names and + # probabilities. + results = self.model.detect_language(encoder_output)[0] + # Parse language names to strip out markers + all_language_probs = [ + (token[2:-2], prob) for (token, prob) in results + ] + # Get top language token and probability + language, language_probability = all_language_probs[0] + if ( + language_detection_threshold is None + or language_probability > language_detection_threshold + ): + break + detected_language_info.setdefault(language, []).append( + language_probability + ) + seek += segment.shape[-1] + else: + # If no language detected for all segments, the majority vote of the highest + # projected languages for all segments is used to determine the language. + language = max( + detected_language_info, + key=lambda lang: len(detected_language_info[lang]), + ) + language_probability = max(detected_language_info[language]) self.logger.info( "Detected language '%s' with probability %.2f",