Align segment structure with openai/whisper (#154)
* Align segment structure with openai/whisper

* Update code to apply requested changes

* Move increment below the segment filtering

---------

Co-authored-by: Guillaume Klein <guillaumekln@users.noreply.github.com>
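For orientation, the change makes each yielded Segment carry the same per-segment fields as the entries of openai/whisper's result["segments"] (id, seek, start, end, text, tokens, temperature, avg_logprob, compression_ratio, no_speech_prob), plus the existing optional words. A minimal sketch of consuming the realigned structure, assuming a local model and audio file (the "base" size and "audio.wav" path are placeholders, not part of this commit):

from faster_whisper import WhisperModel

model = WhisperModel("base")                    # placeholder model size
segments, info = model.transcribe("audio.wav")  # placeholder audio path

for segment in segments:
    # Segment is a NamedTuple, so _asdict() gives a dict shaped like an
    # entry of openai/whisper's result["segments"].
    data = segment._asdict()
    print(data["id"], data["seek"], data["avg_logprob"], data["compression_ratio"])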
@@ -28,12 +28,17 @@ class Word(NamedTuple):
 
 
 class Segment(NamedTuple):
+    id: int
+    seek: int
     start: float
     end: float
     text: str
-    words: Optional[List[Word]]
-    avg_log_prob: float
+    tokens: List[int]
+    temperature: float
+    avg_logprob: float
+    compression_ratio: float
     no_speech_prob: float
+    words: Optional[List[Word]]
 
 
 class TranscriptionOptions(NamedTuple):
@@ -335,6 +340,7 @@ class WhisperModel:
         encoder_output: Optional[ctranslate2.StorageView] = None,
     ) -> Iterable[Segment]:
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
+        idx = 0
         seek = 0
         all_tokens = []
         prompt_reset_since = 0
@@ -368,9 +374,12 @@ class WhisperModel:
             if encoder_output is None:
                 encoder_output = self.encode(segment)
 
-            result, avg_log_prob, temperature = self.generate_with_fallback(
-                encoder_output, prompt, tokenizer, options
-            )
+            (
+                result,
+                avg_logprob,
+                temperature,
+                compression_ratio,
+            ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
 
             if options.no_speech_threshold is not None:
                 # no voice activity check
@@ -378,7 +387,7 @@ class WhisperModel:
 
                 if (
                     options.log_prob_threshold is not None
-                    and avg_log_prob > options.log_prob_threshold
+                    and avg_logprob > options.log_prob_threshold
                 ):
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     should_skip = False
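For context, the check above feeds the no-speech skip decision inherited from openai/whisper: a window flagged as silence is still kept if decoding produced confident text. An illustrative paraphrase as a standalone function (not code from the repository):

def should_skip_window(no_speech_prob, avg_logprob, no_speech_threshold, log_prob_threshold):
    # Skip the window when the model thinks it is non-speech...
    if no_speech_threshold is None:
        return False
    should_skip = no_speech_prob > no_speech_threshold
    # ...unless the decoded text is confident enough, despite the no_speech_prob.
    if log_prob_threshold is not None and avg_logprob > log_prob_threshold:
        should_skip = False
    return should_skip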
@@ -509,18 +518,24 @@ class WhisperModel:
                     continue
 
                 all_tokens.extend(tokens)
+                idx += 1
 
                 yield Segment(
+                    id=idx,
+                    seek=seek,
                     start=segment["start"],
                     end=segment["end"],
                     text=text,
+                    tokens=tokens,
+                    temperature=temperature,
+                    avg_logprob=avg_logprob,
+                    compression_ratio=compression_ratio,
+                    no_speech_prob=result.no_speech_prob,
                     words=(
                         [Word(**word) for word in segment["words"]]
                         if options.word_timestamps
                         else None
                     ),
-                    avg_log_prob=avg_log_prob,
-                    no_speech_prob=result.no_speech_prob,
                 )
 
     def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
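Note the placement of idx += 1 after the empty-segment filter (the "Move increment below the segment filtering" item in the commit message): only yielded segments consume an id, so the ids come out consecutive starting at 1. A hypothetical sanity check illustrating that property (not part of the repository):

def assert_consecutive_ids(segments):
    # Filtered-out (empty or zero-length) segments never reach the yield,
    # so the ids of the yielded segments run 1, 2, 3, ...
    for expected, segment in enumerate(segments, start=1):
        assert segment.id == expected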
@@ -539,10 +554,11 @@ class WhisperModel:
         prompt: List[int],
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
-    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
+    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
         result = None
-        avg_log_prob = None
+        avg_logprob = None
         final_temperature = None
+        compression_ratio = None
 
         max_initial_timestamp_index = int(
             round(options.max_initial_timestamp / self.time_precision)
@@ -580,8 +596,8 @@ class WhisperModel:
 
             # Recover the average log prob from the returned score.
             seq_len = len(tokens)
-            cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
-            avg_log_prob = cum_log_prob / (seq_len + 1)
+            cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
+            avg_logprob = cum_logprob / (seq_len + 1)
 
             text = tokenizer.decode(tokens).strip()
             compression_ratio = get_compression_ratio(text)
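Two quality signals are computed here. The score returned by CTranslate2 is already divided by seq_len**options.length_penalty, so multiplying that factor back in recovers the cumulative log probability, and dividing by seq_len + 1 averages it over the tokens plus the end-of-sequence token. get_compression_ratio is not shown in this diff; openai/whisper defines the compression ratio with zlib, and a sketch along those lines (assumed, not the project's exact helper) looks like this:

import zlib


def recover_avg_logprob(score: float, seq_len: int, length_penalty: float) -> float:
    # Undo the length-penalty normalization, then average per token.
    cum_logprob = score * (seq_len**length_penalty)
    return cum_logprob / (seq_len + 1)


def compression_ratio(text: str) -> float:
    # Highly repetitive output compresses well and therefore scores high,
    # which the fallback logic treats as a failure signal.
    data = text.encode("utf-8")
    return len(data) / len(zlib.compress(data))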
@@ -603,21 +619,21 @@ class WhisperModel:
 
             if (
                 options.log_prob_threshold is not None
-                and avg_log_prob < options.log_prob_threshold
+                and avg_logprob < options.log_prob_threshold
             ):
                 needs_fallback = True  # average log probability is too low
 
                 self.logger.debug(
                     "Log probability threshold is not met with temperature %.1f (%f < %f)",
                     temperature,
-                    avg_log_prob,
+                    avg_logprob,
                     options.log_prob_threshold,
                 )
 
             if not needs_fallback:
                 break
 
-        return result, avg_log_prob, final_temperature
+        return result, avg_logprob, final_temperature, compression_ratio
 
     def get_prompt(
         self,
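These checks sit inside the temperature fallback loop that generate_with_fallback implements, mirroring openai/whisper: decode at increasing temperatures until the quality thresholds pass, then return the last attempt. A simplified sketch of that control flow (the decode callable and parameter names are illustrative, not the function's real signature):

def fallback_loop(decode, temperatures, compression_ratio_threshold, log_prob_threshold):
    for temperature in temperatures:
        result, avg_logprob, compression_ratio = decode(temperature)
        needs_fallback = False

        if (
            compression_ratio_threshold is not None
            and compression_ratio > compression_ratio_threshold
        ):
            needs_fallback = True  # the transcription is too repetitive

        if log_prob_threshold is not None and avg_logprob < log_prob_threshold:
            needs_fallback = True  # average log probability is too low

        if not needs_fallback:
            break

    return result, avg_logprob, temperature, compression_ratio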