Align segment structure with openai/whisper (#154)

* Align segment structure with openai/whisper

* Update code to apply requested changes

* Move increment below the segment filtering

---------

Co-authored-by: Guillaume Klein <guillaumekln@users.noreply.github.com>
Author: Amar Sood
Date: 2023-04-24 09:04:42 -04:00
Committed by: GitHub
Parent: 2b51a97e61
Commit: f893113759


@@ -28,12 +28,17 @@ class Word(NamedTuple):
 class Segment(NamedTuple):
+    id: int
+    seek: int
     start: float
     end: float
     text: str
-    words: Optional[List[Word]]
-    avg_log_prob: float
+    tokens: List[int]
+    temperature: float
+    avg_logprob: float
+    compression_ratio: float
     no_speech_prob: float
+    words: Optional[List[Word]]
 
 
 class TranscriptionOptions(NamedTuple):
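For callers, the realigned Segment now exposes the same per-segment fields as openai/whisper's result dicts. A minimal consumption sketch (the model size and audio path below are placeholders, not part of this commit):

from faster_whisper import WhisperModel

model = WhisperModel("base")
segments, info = model.transcribe("audio.wav")

for segment in segments:
    # Fields mirror openai/whisper: id, seek, start, end, text, tokens,
    # temperature, avg_logprob, compression_ratio, no_speech_prob, words.
    print(segment.id, segment.start, segment.end, segment.text)
    print(segment.avg_logprob, segment.compression_ratio, segment.no_speech_prob)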
@@ -335,6 +340,7 @@ class WhisperModel:
         encoder_output: Optional[ctranslate2.StorageView] = None,
     ) -> Iterable[Segment]:
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
+        idx = 0
         seek = 0
         all_tokens = []
         prompt_reset_since = 0
@@ -368,9 +374,12 @@
             if encoder_output is None:
                 encoder_output = self.encode(segment)
 
-            result, avg_log_prob, temperature = self.generate_with_fallback(
-                encoder_output, prompt, tokenizer, options
-            )
+            (
+                result,
+                avg_logprob,
+                temperature,
+                compression_ratio,
+            ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
 
             if options.no_speech_threshold is not None:
                 # no voice activity check
@@ -378,7 +387,7 @@
                 if (
                     options.log_prob_threshold is not None
-                    and avg_log_prob > options.log_prob_threshold
+                    and avg_logprob > options.log_prob_threshold
                 ):
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     should_skip = False
@@ -509,18 +518,24 @@
                     continue
 
                 all_tokens.extend(tokens)
+                idx += 1
 
                 yield Segment(
+                    id=idx,
+                    seek=seek,
                     start=segment["start"],
                     end=segment["end"],
                     text=text,
+                    tokens=tokens,
+                    temperature=temperature,
+                    avg_logprob=avg_logprob,
+                    compression_ratio=compression_ratio,
+                    no_speech_prob=result.no_speech_prob,
                     words=(
                         [Word(**word) for word in segment["words"]]
                         if options.word_timestamps
                         else None
                     ),
-                    avg_log_prob=avg_log_prob,
-                    no_speech_prob=result.no_speech_prob,
                 )
 
     def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
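Because Segment is a NamedTuple whose field names now match openai/whisper's segment dicts, a whisper-style result can be rebuilt from the generator. A sketch under the same placeholder assumptions as above (note that _asdict() is shallow, so the words entries stay Word tuples):

segments, info = model.transcribe("audio.wav")
segments = list(segments)  # the generator is consumed once, so materialize it first

result = {
    "text": "".join(segment.text for segment in segments),
    "segments": [segment._asdict() for segment in segments],
    "language": info.language,
}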
@@ -539,10 +554,11 @@
         prompt: List[int],
         tokenizer: Tokenizer,
         options: TranscriptionOptions,
-    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
+    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
         result = None
-        avg_log_prob = None
+        avg_logprob = None
         final_temperature = None
+        compression_ratio = None
         max_initial_timestamp_index = int(
             round(options.max_initial_timestamp / self.time_precision)
@@ -580,8 +596,8 @@
             # Recover the average log prob from the returned score.
             seq_len = len(tokens)
-            cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
-            avg_log_prob = cum_log_prob / (seq_len + 1)
+            cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
+            avg_logprob = cum_logprob / (seq_len + 1)
 
             text = tokenizer.decode(tokens).strip()
             compression_ratio = get_compression_ratio(text)
@@ -603,21 +619,21 @@
             if (
                 options.log_prob_threshold is not None
-                and avg_log_prob < options.log_prob_threshold
+                and avg_logprob < options.log_prob_threshold
             ):
                 needs_fallback = True  # average log probability is too low
 
                 self.logger.debug(
                     "Log probability threshold is not met with temperature %.1f (%f < %f)",
                     temperature,
-                    avg_log_prob,
+                    avg_logprob,
                     options.log_prob_threshold,
                 )
 
             if not needs_fallback:
                 break
 
-        return result, avg_log_prob, final_temperature
+        return result, avg_logprob, final_temperature, compression_ratio
 
     def get_prompt(
         self,
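The extra compression_ratio returned here feeds the same temperature-fallback rule as openai/whisper: retry at a higher temperature when the decoded text is too repetitive or its average log probability is too low. A simplified, standalone sketch of that decision (the default thresholds shown are the usual openai/whisper values, not values taken from this diff):

def needs_fallback(avg_logprob, compression_ratio,
                   log_prob_threshold=-1.0, compression_ratio_threshold=2.4):
    # Too repetitive: the compression ratio of the decoded text is too high.
    if compression_ratio_threshold is not None and compression_ratio > compression_ratio_threshold:
        return True
    # Average log probability of the generated tokens is too low.
    if log_prob_threshold is not None and avg_logprob < log_prob_threshold:
        return True
    return False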