From 38f2f4d99d297c2fc09f9f5f28eaa06b017e24a3 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Wed, 8 Mar 2023 18:34:07 -0500 Subject: [PATCH] fix all_tokens handling that caused more repetitions and discrepancy in JSON (#1060) --- tests/test_transcribe.py | 1 + whisper/timing.py | 2 +- whisper/transcribe.py | 22 ++++++++++++---------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index e5d5307..96d04eb 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -17,6 +17,7 @@ def test_transcribe(model_name: str): audio_path, language=language, temperature=0.0, word_timestamps=True ) assert result["language"] == "en" + assert result["text"] == "".join([s["text"] for s in result["segments"]]) transcription = result["text"].lower() assert "my fellow americans" in transcription diff --git a/whisper/timing.py b/whisper/timing.py index 39f3872..7bc2b9a 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -290,7 +290,7 @@ def add_word_timestamps( if len(segments) == 0: return - text_tokens = [t for segment in segments for t in segment["tokens"]] + text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot] alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs) merge_punctuations(alignment, prepend_punctuations, append_punctuations) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 773e636..ed6d820 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -200,14 +200,14 @@ def transcribe( def new_segment( *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult ): - text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot] + tokens = tokens.tolist() + text_tokens = [token for token in tokens if token < tokenizer.eot] return { - "id": len(all_segments), "seek": seek, "start": start, "end": end, "text": tokenizer.decode(text_tokens), - "tokens": text_tokens, + "tokens": tokens, "temperature": result.temperature, "avg_logprob": result.avg_logprob, "compression_ratio": result.compression_ratio, @@ -245,7 +245,6 @@ def transcribe( previous_seek = seek current_segments = [] - current_tokens = [] timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] @@ -275,7 +274,6 @@ def transcribe( result=result, ) ) - current_tokens.append(sliced_tokens.tolist()) last_slice = current_slice if single_timestamp_ending: @@ -287,7 +285,6 @@ def transcribe( tokens[last_slice - 1].item() - tokenizer.timestamp_begin ) seek += last_timestamp_pos * input_stride - all_tokens.extend(tokens[: last_slice + 1].tolist()) else: duration = segment_duration timestamps = tokens[timestamp_tokens.nonzero().flatten()] @@ -309,7 +306,6 @@ def transcribe( result=result, ) ) - current_tokens.append(tokens.tolist()) seek += segment_size if not condition_on_previous_text or result.temperature > 0.5: @@ -348,11 +344,17 @@ def transcribe( segment["text"] = "" segment["tokens"] = [] segment["words"] = [] - current_tokens[i] = [] - all_segments.extend(current_segments) + all_segments.extend( + [ + {"id": i, **segment} + for i, segment in enumerate( + current_segments, start=len(all_segments) + ) + ] + ) all_tokens.extend( - [token for segment in current_tokens for token in segment] + [token for segment in current_segments for token in segment["tokens"]] ) # update progress bar