Add some info and debug logs (#113)

This commit is contained in:
Guillaume Klein
2023-04-05 16:57:59 +02:00
committed by GitHub
parent 746f2698db
commit 051b3350e5
2 changed files with 70 additions and 9 deletions

View File

@@ -1,4 +1,5 @@
import itertools
import logging
import os
import zlib
@@ -11,7 +12,7 @@ import tokenizers
from faster_whisper.audio import decode_audio
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.utils import download_model
from faster_whisper.utils import download_model, format_timestamp, get_logger
from faster_whisper.vad import (
SpeechTimestampsMap,
collect_chunks,
@@ -93,6 +94,8 @@ class WhisperModel:
(concurrent calls to self.model.generate() will run in parallel).
This can improve the global throughput at the cost of increased memory usage.
"""
self.logger = get_logger()
if os.path.isdir(model_size_or_path):
model_path = model_size_or_path
else:
@@ -211,17 +214,40 @@ class WhisperModel:
- a generator over transcribed segments
- an instance of AudioInfo
"""
if not isinstance(audio, np.ndarray):
audio = decode_audio(
audio, sampling_rate=self.feature_extractor.sampling_rate
)
sampling_rate = self.feature_extractor.sampling_rate
duration = audio.shape[0] / self.feature_extractor.sampling_rate
if not isinstance(audio, np.ndarray):
audio = decode_audio(audio, sampling_rate=sampling_rate)
duration = audio.shape[0] / sampling_rate
self.logger.info(
"Processing audio with duration %s", format_timestamp(duration)
)
if vad_filter:
vad_parameters = {} if vad_parameters is None else vad_parameters
speech_chunks = get_speech_timestamps(audio, **vad_parameters)
audio = collect_chunks(audio, speech_chunks)
self.logger.info(
"VAD filter removed %s of audio",
format_timestamp(duration - (audio.shape[0] / sampling_rate)),
)
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
"VAD filter kept the following audio segments: %s",
", ".join(
"[%s -> %s]"
% (
format_timestamp(chunk["start"] / sampling_rate),
format_timestamp(chunk["end"] / sampling_rate),
)
for chunk in speech_chunks
),
)
else:
speech_chunks = None
@@ -239,6 +265,12 @@ class WhisperModel:
results = self.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
self.logger.info(
"Detected language '%s' with probability %.2f",
language,
language_probability,
)
else:
language_probability = 1
@@ -275,9 +307,7 @@ class WhisperModel:
segments = self.generate_segments(features, tokenizer, options, encoder_output)
if speech_chunks:
segments = restore_speech_timestamps(
segments, speech_chunks, self.feature_extractor.sampling_rate
)
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
audio_info = AudioInfo(
language=language,
@@ -312,6 +342,11 @@ class WhisperModel:
)
segment_duration = segment_size * self.feature_extractor.time_per_frame
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
"Processing segment at %s", format_timestamp(time_offset)
)
previous_tokens = all_tokens[prompt_reset_since:]
prompt = self.get_prompt(
tokenizer,
@@ -339,6 +374,12 @@ class WhisperModel:
should_skip = False
if should_skip:
self.logger.debug(
"No speech threshold is met (%f > %f)",
result.no_speech_prob,
options.no_speech_threshold,
)
# fast-forward to the next segment boundary
seek += segment_size
continue
@@ -543,12 +584,26 @@ class WhisperModel:
):
needs_fallback = True # too repetitive
self.logger.debug(
"Compression ratio threshold is not met with temperature %.1f (%f > %f)",
temperature,
compression_ratio,
options.compression_ratio_threshold,
)
if (
options.log_prob_threshold is not None
and avg_log_prob < options.log_prob_threshold
):
needs_fallback = True # average log probability is too low
self.logger.debug(
"Log probability threshold is not met with temperature %.1f (%f < %f)",
temperature,
avg_log_prob,
options.log_prob_threshold,
)
if not needs_fallback:
break