From 55756284aca2c0ed6497fd0dff2c1c8e057f4709 Mon Sep 17 00:00:00 2001
From: heimoshuiyu
Date: Tue, 18 Apr 2023 12:15:21 +0800
Subject: [PATCH] Ignore repeated prompt

Build the decoder prompt from a separate token list that skips empty
segments and segments whose text repeats the previous segment, so that
repeated output is not fed back as conditioning for later windows.

---
 whisper/transcribe.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 84feb12..80f6dab 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -188,12 +188,15 @@ def transcribe(
         input_stride * HOP_LENGTH / SAMPLE_RATE
     )  # time per output token: 0.02 (seconds)
     all_tokens = []
+    all_prompts_tokens = []  # tokens eligible to be fed back as the prompt
+    all_prompts_segments = []  # segments whose tokens were kept for the prompt
     all_segments = []
     prompt_reset_since = 0
 
     if initial_prompt is not None:
         initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
         all_tokens.extend(initial_prompt_tokens)
+        all_prompts_tokens.extend(initial_prompt_tokens)
     else:
         initial_prompt_tokens = []
 
@@ -225,7 +228,7 @@ def transcribe(
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
 
-            decode_options["prompt"] = all_tokens[prompt_reset_since:]
+            decode_options["prompt"] = all_prompts_tokens[prompt_reset_since:]
             result: DecodingResult = decode_with_fallback(mel_segment)
             tokens = torch.tensor(result.tokens)
 
@@ -310,7 +313,7 @@ def transcribe(
 
             if not condition_on_previous_text or result.temperature > 0.5:
                 # do not feed the prompt tokens if a high temperature was used
-                prompt_reset_since = len(all_tokens)
+                prompt_reset_since = len(all_prompts_tokens)
 
             if word_timestamps:
                 add_word_timestamps(
@@ -357,6 +360,15 @@ def transcribe(
                 [token for segment in current_segments for token in segment["tokens"]]
             )
 
+            # keep prompt tokens only from non-empty, non-repeated segments
+            for i, segment in enumerate(current_segments):
+                if not segment["text"].strip():
+                    continue
+                if all_prompts_segments and all_prompts_segments[-1]["text"].strip() == segment["text"].strip():
+                    continue
+                all_prompts_tokens.extend(segment["tokens"])
+                all_prompts_segments.append({"id": i, **segment})
+
             # update progress bar
             pbar.update(min(content_frames, seek) - previous_seek)
 
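
Not part of the patch: a minimal, self-contained sketch of the
deduplication rule the last hunk adds, run against toy segment dicts
shaped like transcribe.py's current_segments entries ("text" plus
"tokens" keys). The helper name collect_prompt_segments and the sample
values are illustrative, not from the Whisper codebase.

def collect_prompt_segments(current_segments, all_prompts_tokens, all_prompts_segments):
    """Append only non-empty, non-repeated segments to the prompt state."""
    for i, segment in enumerate(current_segments):
        if not segment["text"].strip():
            continue  # segment carries no text
        if (
            all_prompts_segments
            and all_prompts_segments[-1]["text"].strip() == segment["text"].strip()
        ):
            continue  # segment repeats the previously kept one verbatim
        all_prompts_tokens.extend(segment["tokens"])
        all_prompts_segments.append({"id": i, **segment})


# Toy run: the repeated middle segment is dropped from the prompt state.
tokens, segments = [], []
collect_prompt_segments(
    [
        {"text": " Hello world.", "tokens": [1, 2]},
        {"text": " Hello world.", "tokens": [1, 2]},
        {"text": " Goodbye.", "tokens": [3]},
    ],
    tokens,
    segments,
)
assert tokens == [1, 2, 3]
assert [s["text"] for s in segments] == [" Hello world.", " Goodbye."]

Note that the rule only compares against the most recently kept
segment, so a phrase that recurs after an intervening segment is still
fed back as part of the prompt; only consecutive verbatim repeats are
suppressed.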