diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 06154f3..8f0b354 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -67,6 +67,7 @@ class TranscriptionInfo(NamedTuple): language: str language_probability: float duration: float + all_language_probs: Optional[List[Tuple[str, float]]] transcription_options: TranscriptionOptions vad_options: VadOptions @@ -275,6 +276,7 @@ class WhisperModel: features = self.feature_extractor(audio) encoder_output = None + all_language_probs = None if language is None: if not self.model.is_multilingual: @@ -283,9 +285,13 @@ class WhisperModel: else: segment = features[:, : self.feature_extractor.nb_max_frames] encoder_output = self.encode(segment) - results = self.model.detect_language(encoder_output) - language_token, language_probability = results[0][0] - language = language_token[2:-2] + # results is a list of tuple[str, float] with language names and + # probabilities. + results = self.model.detect_language(encoder_output)[0] + # Parse language names to strip out markers + all_language_probs = [(token[2:-2], prob) for (token, prob) in results] + # Get top language token and probability + language, language_probability = all_language_probs[0] self.logger.info( "Detected language '%s' with probability %.2f", @@ -336,6 +342,7 @@ class WhisperModel: duration=duration, transcription_options=options, vad_options=vad_parameters, + all_language_probs=all_language_probs, ) return segments, info diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index f1c9572..6ecf2c4 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -6,11 +6,18 @@ from faster_whisper import WhisperModel, decode_audio def test_transcribe(jfk_path): model = WhisperModel("tiny") segments, info = model.transcribe(jfk_path, word_timestamps=True) + assert info.all_language_probs is not None assert info.language == "en" assert info.language_probability > 0.9 assert info.duration == 11 + # Get top language info from all results, which should match the + # already existing metadata + top_lang, top_lang_score = info.all_language_probs[0] + assert info.language == top_lang + assert abs(info.language_probability - top_lang_score) < 1e-16 + segments = list(segments) assert len(segments) == 1