Align segment structure with openai/whisper (#154)
* Align segment structure with openai/whisper * Update code to apply requested changes * Move increment below the segment filtering --------- Co-authored-by: Guillaume Klein <guillaumekln@users.noreply.github.com>
This commit is contained in:
@@ -28,12 +28,17 @@ class Word(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
class Segment(NamedTuple):
|
class Segment(NamedTuple):
|
||||||
|
id: int
|
||||||
|
seek: int
|
||||||
start: float
|
start: float
|
||||||
end: float
|
end: float
|
||||||
text: str
|
text: str
|
||||||
words: Optional[List[Word]]
|
tokens: List[int]
|
||||||
avg_log_prob: float
|
temperature: float
|
||||||
|
avg_logprob: float
|
||||||
|
compression_ratio: float
|
||||||
no_speech_prob: float
|
no_speech_prob: float
|
||||||
|
words: Optional[List[Word]]
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionOptions(NamedTuple):
|
class TranscriptionOptions(NamedTuple):
|
||||||
@@ -335,6 +340,7 @@ class WhisperModel:
|
|||||||
encoder_output: Optional[ctranslate2.StorageView] = None,
|
encoder_output: Optional[ctranslate2.StorageView] = None,
|
||||||
) -> Iterable[Segment]:
|
) -> Iterable[Segment]:
|
||||||
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
||||||
|
idx = 0
|
||||||
seek = 0
|
seek = 0
|
||||||
all_tokens = []
|
all_tokens = []
|
||||||
prompt_reset_since = 0
|
prompt_reset_since = 0
|
||||||
@@ -368,9 +374,12 @@ class WhisperModel:
|
|||||||
if encoder_output is None:
|
if encoder_output is None:
|
||||||
encoder_output = self.encode(segment)
|
encoder_output = self.encode(segment)
|
||||||
|
|
||||||
result, avg_log_prob, temperature = self.generate_with_fallback(
|
(
|
||||||
encoder_output, prompt, tokenizer, options
|
result,
|
||||||
)
|
avg_logprob,
|
||||||
|
temperature,
|
||||||
|
compression_ratio,
|
||||||
|
) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
|
||||||
|
|
||||||
if options.no_speech_threshold is not None:
|
if options.no_speech_threshold is not None:
|
||||||
# no voice activity check
|
# no voice activity check
|
||||||
@@ -378,7 +387,7 @@ class WhisperModel:
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
options.log_prob_threshold is not None
|
options.log_prob_threshold is not None
|
||||||
and avg_log_prob > options.log_prob_threshold
|
and avg_logprob > options.log_prob_threshold
|
||||||
):
|
):
|
||||||
# don't skip if the logprob is high enough, despite the no_speech_prob
|
# don't skip if the logprob is high enough, despite the no_speech_prob
|
||||||
should_skip = False
|
should_skip = False
|
||||||
@@ -509,18 +518,24 @@ class WhisperModel:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
all_tokens.extend(tokens)
|
all_tokens.extend(tokens)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
yield Segment(
|
yield Segment(
|
||||||
|
id=idx,
|
||||||
|
seek=seek,
|
||||||
start=segment["start"],
|
start=segment["start"],
|
||||||
end=segment["end"],
|
end=segment["end"],
|
||||||
text=text,
|
text=text,
|
||||||
|
tokens=tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
avg_logprob=avg_logprob,
|
||||||
|
compression_ratio=compression_ratio,
|
||||||
|
no_speech_prob=result.no_speech_prob,
|
||||||
words=(
|
words=(
|
||||||
[Word(**word) for word in segment["words"]]
|
[Word(**word) for word in segment["words"]]
|
||||||
if options.word_timestamps
|
if options.word_timestamps
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
avg_log_prob=avg_log_prob,
|
|
||||||
no_speech_prob=result.no_speech_prob,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
|
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
|
||||||
@@ -539,10 +554,11 @@ class WhisperModel:
|
|||||||
prompt: List[int],
|
prompt: List[int],
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
options: TranscriptionOptions,
|
options: TranscriptionOptions,
|
||||||
) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
|
) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
|
||||||
result = None
|
result = None
|
||||||
avg_log_prob = None
|
avg_logprob = None
|
||||||
final_temperature = None
|
final_temperature = None
|
||||||
|
compression_ratio = None
|
||||||
|
|
||||||
max_initial_timestamp_index = int(
|
max_initial_timestamp_index = int(
|
||||||
round(options.max_initial_timestamp / self.time_precision)
|
round(options.max_initial_timestamp / self.time_precision)
|
||||||
@@ -580,8 +596,8 @@ class WhisperModel:
|
|||||||
|
|
||||||
# Recover the average log prob from the returned score.
|
# Recover the average log prob from the returned score.
|
||||||
seq_len = len(tokens)
|
seq_len = len(tokens)
|
||||||
cum_log_prob = result.scores[0] * (seq_len**options.length_penalty)
|
cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
|
||||||
avg_log_prob = cum_log_prob / (seq_len + 1)
|
avg_logprob = cum_logprob / (seq_len + 1)
|
||||||
|
|
||||||
text = tokenizer.decode(tokens).strip()
|
text = tokenizer.decode(tokens).strip()
|
||||||
compression_ratio = get_compression_ratio(text)
|
compression_ratio = get_compression_ratio(text)
|
||||||
@@ -603,21 +619,21 @@ class WhisperModel:
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
options.log_prob_threshold is not None
|
options.log_prob_threshold is not None
|
||||||
and avg_log_prob < options.log_prob_threshold
|
and avg_logprob < options.log_prob_threshold
|
||||||
):
|
):
|
||||||
needs_fallback = True # average log probability is too low
|
needs_fallback = True # average log probability is too low
|
||||||
|
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
"Log probability threshold is not met with temperature %.1f (%f < %f)",
|
"Log probability threshold is not met with temperature %.1f (%f < %f)",
|
||||||
temperature,
|
temperature,
|
||||||
avg_log_prob,
|
avg_logprob,
|
||||||
options.log_prob_threshold,
|
options.log_prob_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not needs_fallback:
|
if not needs_fallback:
|
||||||
break
|
break
|
||||||
|
|
||||||
return result, avg_log_prob, final_temperature
|
return result, avg_logprob, final_temperature, compression_ratio
|
||||||
|
|
||||||
def get_prompt(
|
def get_prompt(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user