Use a dict to represent intermediate segments
This commit is contained in:
@@ -223,20 +223,6 @@ class WhisperModel:
|
||||
return segments, audio_info
|
||||
|
||||
def generate_segments(self, features, options):
|
||||
tokenized_segments = self.generate_tokenized_segments(features, options)
|
||||
|
||||
for start, end, tokens in tokenized_segments:
|
||||
text = self.decode_text_tokens(tokens)
|
||||
if not text.strip():
|
||||
continue
|
||||
|
||||
yield Segment(
|
||||
start=start,
|
||||
end=end,
|
||||
text=text,
|
||||
)
|
||||
|
||||
def generate_tokenized_segments(self, features, options):
|
||||
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
||||
seek = 0
|
||||
all_tokens = []
|
||||
@@ -321,7 +307,9 @@ class WhisperModel:
|
||||
time_offset + end_timestamp_position * self.time_precision
|
||||
)
|
||||
|
||||
current_segments.append((start_time, end_time, sliced_tokens))
|
||||
current_segments.append(
|
||||
dict(start=start_time, end=end_time, tokens=sliced_tokens)
|
||||
)
|
||||
last_slice = current_slice
|
||||
|
||||
if single_timestamp_ending:
|
||||
@@ -343,17 +331,29 @@ class WhisperModel:
|
||||
last_timestamp_position = timestamps[-1] - self.timestamp_begin_id
|
||||
duration = last_timestamp_position * self.time_precision
|
||||
|
||||
current_segments.append((time_offset, time_offset + duration, tokens))
|
||||
current_segments.append(
|
||||
dict(start=time_offset, end=time_offset + duration, tokens=tokens)
|
||||
)
|
||||
|
||||
seek += segment_size
|
||||
|
||||
if not options.condition_on_previous_text or temperature > 0.5:
|
||||
prompt_reset_since = len(all_tokens)
|
||||
|
||||
for start, end, tokens in current_segments:
|
||||
yield start, end, tokens
|
||||
for segment in current_segments:
|
||||
tokens = segment["tokens"]
|
||||
all_tokens.extend(tokens)
|
||||
|
||||
text = self.decode_text_tokens(tokens)
|
||||
if not text.strip():
|
||||
continue
|
||||
|
||||
yield Segment(
|
||||
start=segment["start"],
|
||||
end=segment["end"],
|
||||
text=text,
|
||||
)
|
||||
|
||||
def encode_text(self, text):
|
||||
return self.tokenizer.encode(text, add_special_tokens=False).ids
|
||||
|
||||
|
||||
Reference in New Issue
Block a user