Compare commits

4 Commits
main ... prompt

Author SHA1 Message Date
62e34f0c03 format code 2023-04-20 02:02:11 +08:00
67fac2a4ce Check multiple prompts 2023-04-18 19:08:43 +08:00
6fc4e6f230 fix id start 2023-04-18 15:39:14 +08:00
55756284ac Ignore repeated prompt 2023-04-18 12:15:21 +08:00

View File

@@ -188,12 +188,15 @@ def transcribe(
         input_stride * HOP_LENGTH / SAMPLE_RATE
     )  # time per output token: 0.02 (seconds)
     all_tokens = []
+    all_prompts_tokens = []
+    all_prompts_segments = []
     all_segments = []
     prompt_reset_since = 0
     if initial_prompt is not None:
         initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
         all_tokens.extend(initial_prompt_tokens)
+        all_prompts_tokens.extend(initial_prompt_tokens)
     else:
         initial_prompt_tokens = []
@@ -225,7 +228,7 @@ def transcribe(
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
-            decode_options["prompt"] = all_tokens[prompt_reset_since:]
+            decode_options["prompt"] = all_prompts_tokens[prompt_reset_since:]
             result: DecodingResult = decode_with_fallback(mel_segment)
             tokens = torch.tensor(result.tokens)
@@ -310,7 +313,7 @@ def transcribe(
             if not condition_on_previous_text or result.temperature > 0.5:
                 # do not feed the prompt tokens if a high temperature was used
-                prompt_reset_since = len(all_tokens)
+                prompt_reset_since = len(all_prompts_tokens)
             if word_timestamps:
                 add_word_timestamps(
@@ -357,6 +360,20 @@ def transcribe(
                 [token for segment in current_segments for token in segment["tokens"]]
             )
+            for i, segment in enumerate(current_segments, start=len(all_segments)):
+                if not segment["text"].strip():
+                    continue
+                check_prompt_num = 1
+                if any(
+                    [
+                        prev["text"].strip() == segment["text"].strip()
+                        for prev in all_prompts_segments[-check_prompt_num:]
+                    ]
+                ):
+                    continue
+                all_prompts_tokens.extend(segment["tokens"])
+                all_prompts_segments.append({"id": i, **segment})
             # update progress bar
             pbar.update(min(content_frames, seek) - previous_seek)