transcribe: return all language probabilities if requested (#210)
* transcribe: return all language probabilities if requested If return_all_language_probs is True, TranscriptionInfo structure will have a list of tuples reflecting all language probabilities as returned by the model. * transcribe: fix docstring * transcribe: remove return_all_lang_probs parameter
This commit is contained in:
@@ -67,6 +67,7 @@ class TranscriptionInfo(NamedTuple):
|
|||||||
language: str
|
language: str
|
||||||
language_probability: float
|
language_probability: float
|
||||||
duration: float
|
duration: float
|
||||||
|
all_language_probs: Optional[List[Tuple[str, float]]]
|
||||||
transcription_options: TranscriptionOptions
|
transcription_options: TranscriptionOptions
|
||||||
vad_options: VadOptions
|
vad_options: VadOptions
|
||||||
|
|
||||||
@@ -275,6 +276,7 @@ class WhisperModel:
|
|||||||
features = self.feature_extractor(audio)
|
features = self.feature_extractor(audio)
|
||||||
|
|
||||||
encoder_output = None
|
encoder_output = None
|
||||||
|
all_language_probs = None
|
||||||
|
|
||||||
if language is None:
|
if language is None:
|
||||||
if not self.model.is_multilingual:
|
if not self.model.is_multilingual:
|
||||||
@@ -283,9 +285,13 @@ class WhisperModel:
|
|||||||
else:
|
else:
|
||||||
segment = features[:, : self.feature_extractor.nb_max_frames]
|
segment = features[:, : self.feature_extractor.nb_max_frames]
|
||||||
encoder_output = self.encode(segment)
|
encoder_output = self.encode(segment)
|
||||||
results = self.model.detect_language(encoder_output)
|
# results is a list of tuple[str, float] with language names and
|
||||||
language_token, language_probability = results[0][0]
|
# probabilities.
|
||||||
language = language_token[2:-2]
|
results = self.model.detect_language(encoder_output)[0]
|
||||||
|
# Parse language names to strip out markers
|
||||||
|
all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
|
||||||
|
# Get top language token and probability
|
||||||
|
language, language_probability = all_language_probs[0]
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"Detected language '%s' with probability %.2f",
|
"Detected language '%s' with probability %.2f",
|
||||||
@@ -336,6 +342,7 @@ class WhisperModel:
|
|||||||
duration=duration,
|
duration=duration,
|
||||||
transcription_options=options,
|
transcription_options=options,
|
||||||
vad_options=vad_parameters,
|
vad_options=vad_parameters,
|
||||||
|
all_language_probs=all_language_probs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return segments, info
|
return segments, info
|
||||||
|
|||||||
@@ -6,11 +6,18 @@ from faster_whisper import WhisperModel, decode_audio
|
|||||||
def test_transcribe(jfk_path):
|
def test_transcribe(jfk_path):
|
||||||
model = WhisperModel("tiny")
|
model = WhisperModel("tiny")
|
||||||
segments, info = model.transcribe(jfk_path, word_timestamps=True)
|
segments, info = model.transcribe(jfk_path, word_timestamps=True)
|
||||||
|
assert info.all_language_probs is not None
|
||||||
|
|
||||||
assert info.language == "en"
|
assert info.language == "en"
|
||||||
assert info.language_probability > 0.9
|
assert info.language_probability > 0.9
|
||||||
assert info.duration == 11
|
assert info.duration == 11
|
||||||
|
|
||||||
|
# Get top language info from all results, which should match the
|
||||||
|
# already existing metadata
|
||||||
|
top_lang, top_lang_score = info.all_language_probs[0]
|
||||||
|
assert info.language == top_lang
|
||||||
|
assert abs(info.language_probability - top_lang_score) < 1e-16
|
||||||
|
|
||||||
segments = list(segments)
|
segments = list(segments)
|
||||||
|
|
||||||
assert len(segments) == 1
|
assert len(segments) == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user