@@ -286,6 +286,14 @@ class WhisperModel:
 
             tokens = result.sequences_ids[0]
 
+            current_segments = []
+
+            single_timestamp_ending = (
+                len(tokens) >= 2
+                and tokens[-2] < self.timestamp_begin_id
+                and tokens[-1] >= self.timestamp_begin_id
+            )
+
             consecutive_timestamps = [
                 i
                 for i in range(len(tokens))
@@ -295,17 +303,12 @@ class WhisperModel:
             ]
 
             if len(consecutive_timestamps) > 0:
-                ended_with_single_timestamp = (
-                    len(tokens) >= 2
-                    and tokens[-2] < self.timestamp_begin_id
-                    and tokens[-1] >= self.timestamp_begin_id
-                )
-
-                if ended_with_single_timestamp:
-                    consecutive_timestamps.append(len(tokens))
+                slices = list(consecutive_timestamps)
+                if single_timestamp_ending:
+                    slices.append(len(tokens))
 
                 last_slice = 0
-                for i, current_slice in enumerate(consecutive_timestamps):
+                for current_slice in slices:
                     sliced_tokens = tokens[last_slice:current_slice]
                     start_timestamp_position = (
                         sliced_tokens[0] - self.timestamp_begin_id
@@ -318,10 +321,10 @@ class WhisperModel:
                         time_offset + end_timestamp_position * self.time_precision
                     )
 
-                    yield start_time, end_time, sliced_tokens
+                    current_segments.append((start_time, end_time, sliced_tokens))
                     last_slice = current_slice
 
-                if ended_with_single_timestamp:
+                if single_timestamp_ending:
                     # single timestamp at the end means no speech after the last timestamp.
                     seek += segment_size
                 else:
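Note: taken together, the hunks above rework how a decoded window is cut into segments wherever two timestamp tokens appear back to back. A minimal, self-contained sketch of that slicing logic follows; timestamp_begin stands in for self.timestamp_begin_id, the boundary condition is assumed from the elided lines of the first hunk, and the token values in the example are made up for illustration:

# Sketch only, not the library's API.
def slice_by_timestamps(tokens, timestamp_begin):
    # The window ends with a lone timestamp when the last token is a
    # timestamp but the token before it is regular text.
    single_timestamp_ending = (
        len(tokens) >= 2
        and tokens[-2] < timestamp_begin
        and tokens[-1] >= timestamp_begin
    )

    # Two adjacent timestamp tokens mark a segment boundary
    # (assumed from the elided condition in the first hunk).
    consecutive_timestamps = [
        i
        for i in range(1, len(tokens))
        if tokens[i] >= timestamp_begin and tokens[i - 1] >= timestamp_begin
    ]

    segments = []
    if consecutive_timestamps:
        # Copy the boundary list instead of mutating it, as the new code does.
        slices = list(consecutive_timestamps)
        if single_timestamp_ending:
            slices.append(len(tokens))
        last_slice = 0
        for current_slice in slices:
            segments.append(tokens[last_slice:current_slice])
            last_slice = current_slice
    return segments

# With timestamp ids starting at 100:
# slice_by_timestamps([100, 7, 8, 101, 101, 9, 102], 100)
# -> [[100, 7, 8, 101], [101, 9, 102]]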
@@ -331,8 +334,6 @@ class WhisperModel:
                     )
                     seek += last_timestamp_position * self.input_stride
 
-                all_tokens.extend(tokens[: last_slice + 1])
-
             else:
                 duration = segment_duration
                 timestamps = [
@@ -342,14 +343,17 @@ class WhisperModel:
                     last_timestamp_position = timestamps[-1] - self.timestamp_begin_id
                     duration = last_timestamp_position * self.time_precision
 
-                yield time_offset, time_offset + duration, tokens
+                current_segments.append((time_offset, time_offset + duration, tokens))
 
                 seek += segment_size
-                all_tokens.extend(tokens)
 
             if not options.condition_on_previous_text or temperature > 0.5:
                 prompt_reset_since = len(all_tokens)
 
+            for start, end, tokens in current_segments:
+                yield start, end, tokens
+                all_tokens.extend(tokens)
+
     def encode_text(self, text):
         return self.tokenizer.encode(text, add_special_tokens=False).ids
 
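Note: one plausible reading of the last hunks is that segments are now buffered in current_segments and yielded only after the window's bookkeeping (seek, prompt_reset_since) is done, so all_tokens grows exactly with the segments actually handed to the caller. A minimal sketch of that buffer-then-yield pattern, with illustrative names and a stand-in for the real slicing:

# Sketch only: windows yields (time_offset, tokens) pairs; the real code
# slices each window's tokens into several segments.
def generate_segments(windows, condition_on_previous_text=True, temperature=0.0):
    all_tokens = []
    prompt_reset_since = 0
    for time_offset, tokens in windows:
        # Stand-in for the slicing above: one segment per window.
        current_segments = [(time_offset, time_offset + 1.0, tokens)]

        # In the real code prompt_reset_since feeds the next window's prompt.
        if not condition_on_previous_text or temperature > 0.5:
            prompt_reset_since = len(all_tokens)

        # Yield last: a caller that stops consuming mid-window leaves no
        # half-updated prompt state behind, and all_tokens only contains
        # tokens from segments that were actually yielded.
        for start, end, segment_tokens in current_segments:
            yield start, end, segment_tokens
            all_tokens.extend(segment_tokens)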