Use a dict to represent intermediate segments
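In short, the intermediate segments that generate_segments builds and consumes are switched from positional tuples to dicts keyed by name. An illustrative sketch (not part of the diff) of the two representations:

# Before: a positional tuple; every consumer has to unpack all fields in order.
segment = (start_time, end_time, sliced_tokens)
start, end, tokens = segment

# After: a dict keyed by field name; fields can be read selectively,
# and new keys can be added later without breaking existing consumers.
segment = dict(start=start_time, end=end_time, tokens=sliced_tokens)
tokens = segment["tokens"]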
@@ -223,20 +223,6 @@ class WhisperModel:
         return segments, audio_info

     def generate_segments(self, features, options):
-        tokenized_segments = self.generate_tokenized_segments(features, options)
-
-        for start, end, tokens in tokenized_segments:
-            text = self.decode_text_tokens(tokens)
-            if not text.strip():
-                continue
-
-            yield Segment(
-                start=start,
-                end=end,
-                text=text,
-            )
-
-    def generate_tokenized_segments(self, features, options):
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
         seek = 0
         all_tokens = []
@@ -321,7 +307,9 @@ class WhisperModel:
                         time_offset + end_timestamp_position * self.time_precision
                     )

-                    current_segments.append((start_time, end_time, sliced_tokens))
+                    current_segments.append(
+                        dict(start=start_time, end=end_time, tokens=sliced_tokens)
+                    )
                     last_slice = current_slice

                 if single_timestamp_ending:
@@ -343,17 +331,29 @@ class WhisperModel:
                     last_timestamp_position = timestamps[-1] - self.timestamp_begin_id
                     duration = last_timestamp_position * self.time_precision

-                current_segments.append((time_offset, time_offset + duration, tokens))
+                current_segments.append(
+                    dict(start=time_offset, end=time_offset + duration, tokens=tokens)
+                )

             seek += segment_size

             if not options.condition_on_previous_text or temperature > 0.5:
                 prompt_reset_since = len(all_tokens)

-            for start, end, tokens in current_segments:
-                yield start, end, tokens
+            for segment in current_segments:
+                tokens = segment["tokens"]
                 all_tokens.extend(tokens)
+
+                text = self.decode_text_tokens(tokens)
+                if not text.strip():
+                    continue
+
+                yield Segment(
+                    start=segment["start"],
+                    end=segment["end"],
+                    text=text,
+                )

     def encode_text(self, text):
         return self.tokenizer.encode(text, add_special_tokens=False).ids
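The public behavior is unchanged: generate_segments still yields Segment values with start, end, and text fields. A minimal usage sketch under that assumption (the model, features, and options objects are hypothetical stand-ins for the surrounding transcription pipeline, which this diff does not show):

# Hypothetical caller; `model` is a WhisperModel instance and
# `features`/`options` come from earlier pipeline steps not shown here.
for segment in model.generate_segments(features, options):
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))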