Always run the encoder at the beginning of the loop (#468)

This commit is contained in:
Guillaume Klein
2023-09-12 14:44:37 +02:00
committed by GitHub
parent f697945691
commit 81086f6d33

View File

@@ -417,7 +417,7 @@ class WhisperModel:
 prefix=options.prefix if seek == 0 else None,
 )
-if encoder_output is None:
+if seek > 0 or encoder_output is None:
 encoder_output = self.encode(segment)
 (
@@ -447,7 +447,6 @@ class WhisperModel:
 # fast-forward to the next segment boundary
 seek += segment_size
-encoder_output = None
 continue
 tokens = result.sequences_ids[0]
@@ -554,8 +553,6 @@ class WhisperModel:
 if seek_shift > 0:
 seek = previous_seek + seek_shift
-encoder_output = None
 for segment in current_segments:
 tokens = segment["tokens"]
 text = tokenizer.decode(tokens)