Compare commits


10 Commits

Author SHA1 Message Date
Jong Wook Kim
c09a7ae299 Update decoding.py (#1219) 2023-04-11 15:13:13 -07:00
Fernando O. Gallego
b0022b3283 Update decoding.py (#1155)
* Update decoding.py

Following the suggestions of @Jeronymous in https://github.com/openai/whisper/pull/914 and https://github.com/openai/whisper/discussions/924, this solves the endless-loop problem.

* Removed a blank line and trailing whitespace on empty lines.

* Applied changes suggested by the linter

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-04-11 15:06:03 -07:00
Arseniy Bushyn
76c901ab8d Update README.md to reference tiktoken (#1105)
Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-04-10 17:39:17 -07:00
ryanheise
43940fc978 Implement max line width and max line count, and make word highlighting optional (#1184)
* Add highlight_words, max_line_width, max_line_count

* Refactor subtitle generator

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-04-10 17:28:35 -07:00
ryanheise
255887f219 Squash long words at window and sentence boundaries. (#1114)
* Squash long words at window and sentence boundaries.

* Formatting requirements.

* Fix squashing logic to point to correct words.

---------

Co-authored-by: Jong Wook Kim <jongwook@openai.com>
2023-04-10 17:23:53 -07:00
K.B.Dharun Krishna
a151816b6b python-publish.yml: bump actions version to fix node warning (#1211) 2023-04-10 13:54:09 -07:00
Jong Wook Kim
b5851c6c40 Update tokenizer.py (#1163) 2023-03-29 13:12:36 -07:00
Jong Wook Kim
6dea21fd7f Release 20230314 2023-03-15 00:39:19 -07:00
Jong Wook Kim
79c43e4859 abort find_alignment on empty input (#1090) 2023-03-14 12:47:58 -07:00
Guillaume Klein
5f9ac653b7 Fix truncated words list when the replacement character is decoded (#1089) 2023-03-14 09:32:41 -07:00
10 changed files with 172 additions and 55 deletions

.github/workflows/python-publish.yml

@@ -8,14 +8,14 @@ jobs:
   deploy:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - uses: actions-ecosystem/action-regex-match@v2
         id: regex-match
         with:
           text: ${{ github.event.head_commit.message }}
           regex: '^Release ([^ ]+)'
       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: '3.8'
       - name: Install dependencies

CHANGELOG.md

@@ -1,5 +1,13 @@
 # CHANGELOG

+## [v20230314](https://github.com/openai/whisper/releases/tag/v20230314)
+
+* abort find_alignment on empty input ([#1090](https://github.com/openai/whisper/pull/1090))
+* Fix truncated words list when the replacement character is decoded ([#1089](https://github.com/openai/whisper/pull/1089))
+* fix github language stats getting dominated by jupyter notebook ([#1076](https://github.com/openai/whisper/pull/1076))
+* Fix alignment between the segments and the list of words ([#1087](https://github.com/openai/whisper/pull/1087))
+* Use tiktoken ([#1044](https://github.com/openai/whisper/pull/1044))
+
 ## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308)

 * kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061))

README.md

@@ -17,7 +17,7 @@ A Transformer sequence-to-sequence model is trained on various speech processing
 ## Setup

-We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.10 and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
+We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.10 and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [OpenAI's tiktoken](https://github.com/openai/tiktoken) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:

     pip install -U openai-whisper
@@ -48,7 +48,7 @@ choco install ffmpeg
 scoop install ffmpeg
 ```

-You may need [`rust`](http://rust-lang.org) installed as well, in case [tokenizers](https://pypi.org/project/tokenizers/) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install Rust development environment. Additionally, you may need to configure the `PATH` environment variable, e.g. `export PATH="$HOME/.cargo/bin:$PATH"`. If the installation fails with `No module named 'setuptools_rust'`, you need to install `setuptools_rust`, e.g. by running:
+You may need [`rust`](http://rust-lang.org) installed as well, in case [tiktoken](https://github.com/openai/tiktoken) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install Rust development environment. Additionally, you may need to configure the `PATH` environment variable, e.g. `export PATH="$HOME/.cargo/bin:$PATH"`. If the installation fails with `No module named 'setuptools_rust'`, you need to install `setuptools_rust`, e.g. by running:

 ```bash
 pip install setuptools-rust

tests/test_tokenizer.py

@@ -12,3 +12,13 @@ def test_tokenizer():
     assert gpt2_tokenizer.decode(gpt2_tokens) == text
     assert multilingual_tokenizer.decode(multilingual_tokens) == text
     assert len(gpt2_tokens) > len(multilingual_tokens)
+
+
+def test_split_on_unicode():
+    multilingual_tokenizer = get_tokenizer(multilingual=True)
+
+    tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
+    words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
+
+    assert words == [" elle", " est", " l", "'", "\ufffd", "é", "rit", "oire"]
+    assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]

whisper/decoding.py

@@ -469,7 +469,12 @@ class ApplyTimestampRules(LogitFilter):
             ]

             if timestamps.numel() > 0:
                 # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
-                logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
+                # also force each segment to have a nonzero length, to prevent infinite looping
+                if last_was_timestamp and not penultimate_was_timestamp:
+                    timestamp_last = timestamps[-1]
+                else:
+                    timestamp_last = timestamps[-1] + 1
+                logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf

         if tokens.shape[1] == self.sample_begin:
             # suppress generating non-timestamp tokens at the beginning
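
Note on the hunk above: the masking slice previously stopped at `timestamps[-1]`, so a segment could close at the very timestamp where it opened, and decoding could emit zero-length segments forever. The `+ 1` makes the next timestamp strictly greater except right after a segment close. A toy sketch of the rule (invented token ids and sizes, not the library code):

```python
import numpy as np

timestamp_begin = 100   # hypothetical id of the first timestamp token
last_timestamp = 105    # most recent timestamp token in the sequence
logits = np.zeros(110)  # toy logits over a 110-token vocabulary

# True when the last token closed a segment (last was a timestamp, the one
# before it was text): the next segment may then start at the same timestamp.
closing_segment = True
timestamp_last = last_timestamp if closing_segment else last_timestamp + 1
logits[timestamp_begin:timestamp_last] = -np.inf

# after closing a segment, re-emitting timestamp 105 stays allowed;
# with closing_segment = False it would be masked to -inf
assert logits[last_timestamp] == 0.0
```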

whisper/timing.py

@@ -170,6 +170,9 @@ def find_alignment(
     medfilt_width: int = 7,
     qk_scale: float = 1.0,
 ) -> List[WordTiming]:
+    if len(text_tokens) == 0:
+        return []
+
     tokens = torch.tensor(
         [
             *tokenizer.sot_sequence,
@@ -222,17 +225,26 @@ def find_alignment(
         for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
     ]

-    # hack: ensure the first and second word is not longer than twice the median word duration.
+    # hack: truncate long words at the start of a window and the start of a sentence.
     # a better segmentation algorithm based on VAD should be able to replace this.
     word_durations = end_times - start_times
     word_durations = word_durations[word_durations.nonzero()]
     if len(word_durations) > 0:
         median_duration = np.median(word_durations)
         max_duration = median_duration * 2
-        if len(word_durations) >= 2 and word_durations[1] > max_duration:
-            boundary = max(end_times[2] / 2, end_times[2] - max_duration)
-            end_times[0] = start_times[1] = boundary
-        if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
+        sentence_end_marks = ".。!?"
+        # ensure words at sentence boundaries are not longer than twice the median word duration.
+        for i in range(1, len(start_times)):
+            if end_times[i] - start_times[i] > max_duration:
+                if words[i] in sentence_end_marks:
+                    end_times[i] = start_times[i] + max_duration
+                elif words[i - 1] in sentence_end_marks:
+                    start_times[i] = end_times[i] - max_duration
+
+        # ensure the first and second word is not longer than twice the median word duration.
+        if len(start_times) > 0 and end_times[0] - start_times[0] > max_duration:
+            if len(start_times) > 1 and end_times[1] - start_times[1] > max_duration:
+                boundary = max(end_times[1] / 2, end_times[1] - max_duration)
+                end_times[0] = start_times[1] = boundary
             start_times[0] = max(0, end_times[0] - max_duration)
return [
@@ -324,8 +336,17 @@ def add_word_timestamps(
             word_index += 1

         if len(words) > 0:
-            # adjust the segment-level timestamps based on the word-level timestamps
             segment["start"] = words[0]["start"]
-            segment["end"] = words[-1]["end"]
+            # hack: prefer the segment-level end timestamp if the last word is too long.
+            # a better segmentation algorithm based on VAD should be able to replace this.
+            if (
+                segment["end"] > words[-1]["start"]
+                and segment["end"] + 0.5 < words[-1]["end"]
+            ):
+                # adjust the word-level timestamps based on the segment-level timestamps
+                words[-1]["end"] = segment["end"]
+            else:
+                # adjust the segment-level timestamps based on the word-level timestamps
+                segment["end"] = words[-1]["end"]

             segment["words"] = words

whisper/tokenizer.py

@@ -6,7 +6,6 @@ from functools import cached_property, lru_cache
 from typing import Dict, List, Optional, Tuple

 import tiktoken
-from tiktoken_ext.openai_public import gpt2

 LANGUAGES = {
     "en": "english",
@@ -279,17 +278,27 @@ class Tokenizer:
         return self.split_tokens_on_spaces(tokens)

     def split_tokens_on_unicode(self, tokens: List[int]):
+        decoded_full = self.decode_with_timestamps(tokens)
+        replacement_char = "\ufffd"
+
         words = []
         word_tokens = []
         current_tokens = []
+        unicode_offset = 0

         for token in tokens:
             current_tokens.append(token)
             decoded = self.decode_with_timestamps(current_tokens)
-            if "\ufffd" not in decoded:
+
+            if (
+                replacement_char not in decoded
+                or decoded_full[unicode_offset + decoded.index(replacement_char)]
+                == replacement_char
+            ):
                 words.append(decoded)
                 word_tokens.append(current_tokens)
                 current_tokens = []
+                unicode_offset += len(decoded)

         return words, word_tokens
@@ -342,7 +351,7 @@ def get_encoding(name: str = "gpt2"):
     return tiktoken.Encoding(
         name=os.path.basename(vocab_path),
         explicit_n_vocab=n_vocab,
-        pat_str=gpt2()["pat_str"],
+        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
         mergeable_ranks=ranks,
         special_tokens=special_tokens,
     )
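
To see why the `decoded_full` comparison in `split_tokens_on_unicode` fixes #1089: an incremental decode yields U+FFFD both for an incomplete multi-byte sequence and for a replacement character genuinely present in the text, and only the full decode can tell the two apart. A standalone sketch with toy byte-level tokens (not tiktoken; the data is invented for illustration):

```python
# "é" is split across two byte tokens; the third token is a real U+FFFD
token_bytes = [b"l", b"'", b"\xef\xbf\xbd", b"\xc3", b"\xa9"]

decoded_full = b"".join(token_bytes).decode("utf-8", errors="replace")  # "l'\ufffdé"
replacement_char = "\ufffd"

words, current, unicode_offset = [], b"", 0
for tok in token_bytes:
    current += tok
    decoded = current.decode("utf-8", errors="replace")
    if (
        replacement_char not in decoded
        # U+FFFD at the same offset of the full decode => genuinely in the text
        or decoded_full[unicode_offset + decoded.index(replacement_char)]
        == replacement_char
    ):
        words.append(decoded)
        current = b""
        unicode_offset += len(decoded)

# the real U+FFFD becomes its own word, while the half-decoded "é"
# keeps accumulating tokens until it is complete
print(words)  # ['l', "'", '\ufffd', 'é']
```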

whisper/transcribe.py

@@ -401,6 +401,9 @@ def cli():
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
# fmt: on
@@ -433,9 +436,17 @@ def cli():
     model = load_model(model_name, device=device, download_root=model_dir)

     writer = get_writer(output_format, output_dir)
+    word_options = ["highlight_words", "max_line_count", "max_line_width"]
+    if not args["word_timestamps"]:
+        for option in word_options:
+            if args[option]:
+                parser.error(f"--{option} requires --word_timestamps True")
+    if args["max_line_count"] and not args["max_line_width"]:
+        warnings.warn("--max_line_count has no effect without --max_line_width")
+    writer_args = {arg: args.pop(arg) for arg in word_options}
     for audio_path in args.pop("audio"):
         result = transcribe(model, audio_path, temperature=temperature, **args)
-        writer(result, audio_path)
+        writer(result, audio_path, writer_args)


 if __name__ == "__main__":
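
The three new flags travel to the writer as a plain dict, so the same options can be passed from Python. A hedged usage sketch (the `result` payload is fabricated; the call shape follows the `__call__(result, audio_path, options)` signature introduced in this diff):

```python
from whisper.utils import get_writer

writer = get_writer("srt", ".")
result = {
    "segments": [
        {
            "start": 0.0,
            "end": 1.2,
            "text": " Hello world",
            "words": [
                {"word": " Hello", "start": 0.0, "end": 0.5},
                {"word": " world", "start": 0.5, "end": 1.2},
            ],
        }
    ]
}
# mirrors --highlight_words / --max_line_width / --max_line_count
options = {"highlight_words": False, "max_line_width": 10, "max_line_count": 2}
writer(result, "audio.wav", options)  # writes ./audio.srt, wrapping lines at 10 chars
```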

whisper/utils.py

@@ -1,8 +1,9 @@
 import json
 import os
+import re
 import sys
 import zlib
-from typing import Callable, TextIO
+from typing import Callable, Optional, TextIO

 system_encoding = sys.getdefaultencoding()
@@ -73,7 +74,7 @@ class ResultWriter:
     def __init__(self, output_dir: str):
         self.output_dir = output_dir

-    def __call__(self, result: dict, audio_path: str):
+    def __call__(self, result: dict, audio_path: str, options: dict):
         audio_basename = os.path.basename(audio_path)
         audio_basename = os.path.splitext(audio_basename)[0]
         output_path = os.path.join(
@@ -81,16 +82,16 @@ class ResultWriter:
         )

         with open(output_path, "w", encoding="utf-8") as f:
-            self.write_result(result, file=f)
+            self.write_result(result, file=f, options=options)

-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: dict):
         raise NotImplementedError


 class WriteTXT(ResultWriter):
     extension: str = "txt"

-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: dict):
         for segment in result["segments"]:
             print(segment["text"].strip(), file=file, flush=True)
@@ -99,33 +100,81 @@ class SubtitlesWriter(ResultWriter):
     always_include_hours: bool
     decimal_marker: str

-    def iterate_result(self, result: dict):
-        for segment in result["segments"]:
-            segment_start = self.format_timestamp(segment["start"])
-            segment_end = self.format_timestamp(segment["end"])
-            segment_text = segment["text"].strip().replace("-->", "->")
+    def iterate_result(self, result: dict, options: dict):
+        raw_max_line_width: Optional[int] = options["max_line_width"]
+        max_line_count: Optional[int] = options["max_line_count"]
+        highlight_words: bool = options["highlight_words"]
+        max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
+        preserve_segments = max_line_count is None or raw_max_line_width is None

-            if word_timings := segment.get("words", None):
-                all_words = [timing["word"] for timing in word_timings]
-                all_words[0] = all_words[0].strip()  # remove the leading space, if any
-                last = segment_start
-                for i, this_word in enumerate(word_timings):
-                    start = self.format_timestamp(this_word["start"])
-                    end = self.format_timestamp(this_word["end"])
-                    if last != start:
-                        yield last, start, segment_text
+        def iterate_subtitles():
+            line_len = 0
+            line_count = 1
+            # the next subtitle to yield (a list of word timings with whitespace)
+            subtitle: list[dict] = []
+            last = result["segments"][0]["words"][0]["start"]
+            for segment in result["segments"]:
+                for i, original_timing in enumerate(segment["words"]):
+                    timing = original_timing.copy()
+                    long_pause = not preserve_segments and timing["start"] - last > 3.0
+                    has_room = line_len + len(timing["word"]) <= max_line_width
+                    seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
+                    if line_len > 0 and has_room and not long_pause and not seg_break:
+                        # line continuation
+                        line_len += len(timing["word"])
+                    else:
+                        # new line
+                        timing["word"] = timing["word"].strip()
+                        if (
+                            len(subtitle) > 0
+                            and max_line_count is not None
+                            and (long_pause or line_count >= max_line_count)
+                            or seg_break
+                        ):
+                            # subtitle break
+                            yield subtitle
+                            subtitle = []
+                            line_count = 1
+                        elif line_len > 0:
+                            # line break
+                            line_count += 1
+                            timing["word"] = "\n" + timing["word"]
+                        line_len = len(timing["word"].strip())
+                    subtitle.append(timing)
+                    last = timing["start"]
+            if len(subtitle) > 0:
+                yield subtitle

-                    yield start, end, "".join(
-                        [
-                            f"<u>{word}</u>" if j == i else word
-                            for j, word in enumerate(all_words)
-                        ]
-                    )
-                    last = end
+        if "words" in result["segments"][0]:
+            for subtitle in iterate_subtitles():
+                subtitle_start = self.format_timestamp(subtitle[0]["start"])
+                subtitle_end = self.format_timestamp(subtitle[-1]["end"])
+                subtitle_text = "".join([word["word"] for word in subtitle])
+                if highlight_words:
+                    last = subtitle_start
+                    all_words = [timing["word"] for timing in subtitle]
+                    for i, this_word in enumerate(subtitle):
+                        start = self.format_timestamp(this_word["start"])
+                        end = self.format_timestamp(this_word["end"])
+                        if last != start:
+                            yield last, start, subtitle_text

-                if last != segment_end:
-                    yield last, segment_end, segment_text
-            else:
-                yield segment_start, segment_end, segment_text
+                        yield start, end, "".join(
+                            [
+                                re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
+                                if j == i
+                                else word
+                                for j, word in enumerate(all_words)
+                            ]
+                        )
+                        last = end
+                else:
+                    yield subtitle_start, subtitle_end, subtitle_text
+        else:
+            for segment in result["segments"]:
+                segment_start = self.format_timestamp(segment["start"])
+                segment_end = self.format_timestamp(segment["end"])
+                segment_text = segment["text"].strip().replace("-->", "->")
+                yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float):
@@ -141,9 +190,9 @@ class WriteVTT(SubtitlesWriter):
     always_include_hours: bool = False
     decimal_marker: str = "."

-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: dict):
         print("WEBVTT\n", file=file)
-        for start, end, text in self.iterate_result(result):
+        for start, end, text in self.iterate_result(result, options):
             print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
@@ -152,8 +201,10 @@ class WriteSRT(SubtitlesWriter):
     always_include_hours: bool = True
     decimal_marker: str = ","

-    def write_result(self, result: dict, file: TextIO):
-        for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
+    def write_result(self, result: dict, file: TextIO, options: dict):
+        for i, (start, end, text) in enumerate(
+            self.iterate_result(result, options), start=1
+        ):
             print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
@@ -169,7 +220,7 @@ class WriteTSV(ResultWriter):
extension: str = "tsv"
def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: dict):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t")
@@ -180,11 +231,13 @@ class WriteTSV(ResultWriter):
 class WriteJSON(ResultWriter):
     extension: str = "json"

-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: dict):
         json.dump(result, file)


-def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
+def get_writer(
+    output_format: str, output_dir: str
+) -> Callable[[dict, TextIO, dict], None]:
     writers = {
         "txt": WriteTXT,
         "vtt": WriteVTT,
@@ -196,9 +249,9 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(result: dict, file: TextIO):
def write_all(result: dict, file: TextIO, options: dict):
for writer in all_writers:
writer(result, file)
writer(result, file, options)
return write_all
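
Every writer now receives the options dict, so third-party writers must adopt the same signature. A hedged sketch of a custom writer against this interface (the class name and CSV format are invented for illustration):

```python
from typing import TextIO

from whisper.utils import ResultWriter


class WriteCSV(ResultWriter):
    extension: str = "csv"

    def write_result(self, result: dict, file: TextIO, options: dict):
        # options is accepted for interface compatibility; unused here
        print("start,end,text", file=file)
        for segment in result["segments"]:
            text = segment["text"].strip().replace('"', '""')
            print(f'{segment["start"]},{segment["end"]},"{text}"', file=file)


# WriteCSV("out_dir")(result, "audio.wav", options) writes out_dir/audio.csv
```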

whisper/version.py

@@ -1 +1 @@
__version__ = "20230308"
__version__ = "20230314"