Align segment structure with openai/whisper (#154)

* Align segment structure with openai/whisper

* Update code to apply requested changes

* Move increment below the segment filtering

---------

Co-authored-by: Guillaume Klein <guillaumekln@users.noreply.github.com>
This commit is contained in:
Amar Sood
2023-04-24 09:04:42 -04:00
committed by GitHub
parent 2b51a97e61
commit f893113759

View File

@@ -28,12 +28,17 @@ class Word(NamedTuple):
 class Segment(NamedTuple):
+    id: int
+    seek: int
     start: float
     end: float
     text: str
-    words: Optional[List[Word]]
-    avg_log_prob: float
+    tokens: List[int]
+    temperature: float
+    avg_logprob: float
+    compression_ratio: float
     no_speech_prob: float
+    words: Optional[List[Word]]

 class TranscriptionOptions(NamedTuple):
@@ -335,6 +340,7 @@ class WhisperModel:
         encoder_output: Optional[ctranslate2.StorageView] = None,
     ) -> Iterable[Segment]:
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
+        idx = 0
         seek = 0
         all_tokens = []
         prompt_reset_since = 0
@@ -368,9 +374,12 @@ class WhisperModel:
             if encoder_output is None:
                 encoder_output = self.encode(segment)

-            result, avg_log_prob, temperature = self.generate_with_fallback(
-                encoder_output, prompt, tokenizer, options
-            )
+            (
+                result,
+                avg_logprob,
+                temperature,
+                compression_ratio,
+            ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)

             if options.no_speech_threshold is not None:
                 # no voice activity check
@@ -378,7 +387,7 @@ class WhisperModel:
                 if (
                     options.log_prob_threshold is not None
-                    and avg_log_prob > options.log_prob_threshold
+                    and avg_logprob > options.log_prob_threshold
                 ):
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     should_skip = False
@@ -509,18 +518,24 @@ class WhisperModel:
continue continue
all_tokens.extend(tokens) all_tokens.extend(tokens)
idx += 1
yield Segment( yield Segment(
id=idx,
seek=seek,
start=segment["start"], start=segment["start"],
end=segment["end"], end=segment["end"],
text=text, text=text,
tokens=tokens,
temperature=temperature,
avg_logprob=avg_logprob,
compression_ratio=compression_ratio,
no_speech_prob=result.no_speech_prob,
words=( words=(
[Word(**word) for word in segment["words"]] [Word(**word) for word in segment["words"]]
if options.word_timestamps if options.word_timestamps
else None else None
), ),
avg_log_prob=avg_log_prob,
no_speech_prob=result.no_speech_prob,
) )
def encode(self, features: np.ndarray) -> ctranslate2.StorageView: def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
@@ -539,10 +554,11 @@ class WhisperModel:
         prompt: List[int],
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
-    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
+    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
         result = None
-        avg_log_prob = None
+        avg_logprob = None
         final_temperature = None
+        compression_ratio = None

         max_initial_timestamp_index = int(
             round(options.max_initial_timestamp / self.time_precision)
@@ -580,8 +596,8 @@ class WhisperModel:
             # Recover the average log prob from the returned score.
             seq_len = len(tokens)
-            cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
-            avg_log_prob = cum_log_prob / (seq_len + 1)
+            cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
+            avg_logprob = cum_logprob / (seq_len + 1)

             text = tokenizer.decode(tokens).strip()
             compression_ratio = get_compression_ratio(text)
@@ -603,21 +619,21 @@ class WhisperModel:
if ( if (
options.log_prob_threshold is not None options.log_prob_threshold is not None
and avg_log_prob < options.log_prob_threshold and avg_logprob < options.log_prob_threshold
): ):
needs_fallback = True # average log probability is too low needs_fallback = True # average log probability is too low
self.logger.debug( self.logger.debug(
"Log probability threshold is not met with temperature %.1f (%f < %f)", "Log probability threshold is not met with temperature %.1f (%f < %f)",
temperature, temperature,
avg_log_prob, avg_logprob,
options.log_prob_threshold, options.log_prob_threshold,
) )
if not needs_fallback: if not needs_fallback:
break break
return result, avg_log_prob, final_temperature return result, avg_logprob, final_temperature, compression_ratio
def get_prompt( def get_prompt(
self, self,