Compare commits
10 Commits
ba88b8e1b3
...
c09a7ae299
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c09a7ae299 | ||
|
|
b0022b3283 | ||
|
|
76c901ab8d | ||
|
|
43940fc978 | ||
|
|
255887f219 | ||
|
|
a151816b6b | ||
|
|
b5851c6c40 | ||
|
|
6dea21fd7f | ||
|
|
79c43e4859 | ||
|
|
5f9ac653b7 |
4
.github/workflows/python-publish.yml
vendored
4
.github/workflows/python-publish.yml
vendored
@@ -8,14 +8,14 @@ jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions-ecosystem/action-regex-match@v2
|
||||
id: regex-match
|
||||
with:
|
||||
text: ${{ github.event.head_commit.message }}
|
||||
regex: '^Release ([^ ]+)'
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.8'
|
||||
- name: Install dependencies
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
# CHANGELOG
|
||||
|
||||
## [v20230314](https://github.com/openai/whisper/releases/tag/v20230314)
|
||||
|
||||
* abort find_alignment on empty input ([#1090](https://github.com/openai/whisper/pull/1090))
|
||||
* Fix truncated words list when the replacement character is decoded ([#1089](https://github.com/openai/whisper/pull/1089))
|
||||
* fix github language stats getting dominated by jupyter notebook ([#1076](https://github.com/openai/whisper/pull/1076))
|
||||
* Fix alignment between the segments and the list of words ([#1087](https://github.com/openai/whisper/pull/1087))
|
||||
* Use tiktoken ([#1044](https://github.com/openai/whisper/pull/1044))
|
||||
|
||||
## [v20230308](https://github.com/openai/whisper/releases/tag/v20230308)
|
||||
|
||||
* kwargs in decode() for convenience ([#1061](https://github.com/openai/whisper/pull/1061))
|
||||
|
||||
@@ -17,7 +17,7 @@ A Transformer sequence-to-sequence model is trained on various speech processing
|
||||
|
||||
## Setup
|
||||
|
||||
We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.10 and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
|
||||
We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.10 and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [OpenAI's tiktoken](https://github.com/openai/tiktoken) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
|
||||
|
||||
pip install -U openai-whisper
|
||||
|
||||
@@ -48,7 +48,7 @@ choco install ffmpeg
|
||||
scoop install ffmpeg
|
||||
```
|
||||
|
||||
You may need [`rust`](http://rust-lang.org) installed as well, in case [tokenizers](https://pypi.org/project/tokenizers/) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install Rust development environment. Additionally, you may need to configure the `PATH` environment variable, e.g. `export PATH="$HOME/.cargo/bin:$PATH"`. If the installation fails with `No module named 'setuptools_rust'`, you need to install `setuptools_rust`, e.g. by running:
|
||||
You may need [`rust`](http://rust-lang.org) installed as well, in case [tiktoken](https://github.com/openai/tiktoken) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install Rust development environment. Additionally, you may need to configure the `PATH` environment variable, e.g. `export PATH="$HOME/.cargo/bin:$PATH"`. If the installation fails with `No module named 'setuptools_rust'`, you need to install `setuptools_rust`, e.g. by running:
|
||||
|
||||
```bash
|
||||
pip install setuptools-rust
|
||||
|
||||
@@ -12,3 +12,13 @@ def test_tokenizer():
|
||||
assert gpt2_tokenizer.decode(gpt2_tokens) == text
|
||||
assert multilingual_tokenizer.decode(multilingual_tokens) == text
|
||||
assert len(gpt2_tokens) > len(multilingual_tokens)
|
||||
|
||||
|
||||
def test_split_on_unicode():
|
||||
multilingual_tokenizer = get_tokenizer(multilingual=True)
|
||||
|
||||
tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
|
||||
words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
|
||||
|
||||
assert words == [" elle", " est", " l", "'", "<EFBFBD>", "é", "rit", "oire"]
|
||||
assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
|
||||
|
||||
@@ -469,7 +469,12 @@ class ApplyTimestampRules(LogitFilter):
|
||||
]
|
||||
if timestamps.numel() > 0:
|
||||
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
|
||||
logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
|
||||
# also force each segment to have a nonzero length, to prevent infinite looping
|
||||
if last_was_timestamp and not penultimate_was_timestamp:
|
||||
timestamp_last = timestamps[-1]
|
||||
else:
|
||||
timestamp_last = timestamps[-1] + 1
|
||||
logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
|
||||
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
# suppress generating non-timestamp tokens at the beginning
|
||||
|
||||
@@ -170,6 +170,9 @@ def find_alignment(
|
||||
medfilt_width: int = 7,
|
||||
qk_scale: float = 1.0,
|
||||
) -> List[WordTiming]:
|
||||
if len(text_tokens) == 0:
|
||||
return []
|
||||
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
@@ -222,17 +225,26 @@ def find_alignment(
|
||||
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
||||
]
|
||||
|
||||
# hack: ensure the first and second word is not longer than twice the median word duration.
|
||||
# hack: truncate long words at the start of a window and the start of a sentence.
|
||||
# a better segmentation algorithm based on VAD should be able to replace this.
|
||||
word_durations = end_times - start_times
|
||||
word_durations = word_durations[word_durations.nonzero()]
|
||||
if len(word_durations) > 0:
|
||||
median_duration = np.median(word_durations)
|
||||
max_duration = median_duration * 2
|
||||
if len(word_durations) >= 2 and word_durations[1] > max_duration:
|
||||
boundary = max(end_times[2] / 2, end_times[2] - max_duration)
|
||||
end_times[0] = start_times[1] = boundary
|
||||
if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
|
||||
sentence_end_marks = ".。!!??"
|
||||
# ensure words at sentence boundaries are not longer than twice the median word duration.
|
||||
for i in range(1, len(start_times)):
|
||||
if end_times[i] - start_times[i] > max_duration:
|
||||
if words[i] in sentence_end_marks:
|
||||
end_times[i] = start_times[i] + max_duration
|
||||
elif words[i - 1] in sentence_end_marks:
|
||||
start_times[i] = end_times[i] - max_duration
|
||||
# ensure the first and second word is not longer than twice the median word duration.
|
||||
if len(start_times) > 0 and end_times[0] - start_times[0] > max_duration:
|
||||
if len(start_times) > 1 and end_times[1] - start_times[1] > max_duration:
|
||||
boundary = max(end_times[1] / 2, end_times[1] - max_duration)
|
||||
end_times[0] = start_times[1] = boundary
|
||||
start_times[0] = max(0, end_times[0] - max_duration)
|
||||
|
||||
return [
|
||||
@@ -324,8 +336,17 @@ def add_word_timestamps(
|
||||
word_index += 1
|
||||
|
||||
if len(words) > 0:
|
||||
# adjust the segment-level timestamps based on the word-level timestamps
|
||||
segment["start"] = words[0]["start"]
|
||||
segment["end"] = words[-1]["end"]
|
||||
# hack: prefer the segment-level end timestamp if the last word is too long.
|
||||
# a better segmentation algorithm based on VAD should be able to replace this.
|
||||
if (
|
||||
segment["end"] > words[-1]["start"]
|
||||
and segment["end"] + 0.5 < words[-1]["end"]
|
||||
):
|
||||
# adjust the word-level timestamps based on the segment-level timestamps
|
||||
words[-1]["end"] = segment["end"]
|
||||
else:
|
||||
# adjust the segment-level timestamps based on the word-level timestamps
|
||||
segment["end"] = words[-1]["end"]
|
||||
|
||||
segment["words"] = words
|
||||
|
||||
@@ -6,7 +6,6 @@ from functools import cached_property, lru_cache
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import tiktoken
|
||||
from tiktoken_ext.openai_public import gpt2
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english",
|
||||
@@ -279,17 +278,27 @@ class Tokenizer:
|
||||
return self.split_tokens_on_spaces(tokens)
|
||||
|
||||
def split_tokens_on_unicode(self, tokens: List[int]):
|
||||
decoded_full = self.decode_with_timestamps(tokens)
|
||||
replacement_char = "\ufffd"
|
||||
|
||||
words = []
|
||||
word_tokens = []
|
||||
current_tokens = []
|
||||
unicode_offset = 0
|
||||
|
||||
for token in tokens:
|
||||
current_tokens.append(token)
|
||||
decoded = self.decode_with_timestamps(current_tokens)
|
||||
if "\ufffd" not in decoded:
|
||||
|
||||
if (
|
||||
replacement_char not in decoded
|
||||
or decoded_full[unicode_offset + decoded.index(replacement_char)]
|
||||
== replacement_char
|
||||
):
|
||||
words.append(decoded)
|
||||
word_tokens.append(current_tokens)
|
||||
current_tokens = []
|
||||
unicode_offset += len(decoded)
|
||||
|
||||
return words, word_tokens
|
||||
|
||||
@@ -342,7 +351,7 @@ def get_encoding(name: str = "gpt2"):
|
||||
return tiktoken.Encoding(
|
||||
name=os.path.basename(vocab_path),
|
||||
explicit_n_vocab=n_vocab,
|
||||
pat_str=gpt2()["pat_str"],
|
||||
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
||||
mergeable_ranks=ranks,
|
||||
special_tokens=special_tokens,
|
||||
)
|
||||
|
||||
@@ -401,6 +401,9 @@ def cli():
|
||||
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
||||
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
||||
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
# fmt: on
|
||||
|
||||
@@ -433,9 +436,17 @@ def cli():
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
word_options = ["highlight_words", "max_line_count", "max_line_width"]
|
||||
if not args["word_timestamps"]:
|
||||
for option in word_options:
|
||||
if args[option]:
|
||||
parser.error(f"--{option} requires --word_timestamps True")
|
||||
if args["max_line_count"] and not args["max_line_width"]:
|
||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||
for audio_path in args.pop("audio"):
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
writer(result, audio_path)
|
||||
writer(result, audio_path, writer_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
129
whisper/utils.py
129
whisper/utils.py
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import zlib
|
||||
from typing import Callable, TextIO
|
||||
from typing import Callable, Optional, TextIO
|
||||
|
||||
system_encoding = sys.getdefaultencoding()
|
||||
|
||||
@@ -73,7 +74,7 @@ class ResultWriter:
|
||||
def __init__(self, output_dir: str):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def __call__(self, result: dict, audio_path: str):
|
||||
def __call__(self, result: dict, audio_path: str, options: dict):
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
audio_basename = os.path.splitext(audio_basename)[0]
|
||||
output_path = os.path.join(
|
||||
@@ -81,16 +82,16 @@ class ResultWriter:
|
||||
)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
self.write_result(result, file=f)
|
||||
self.write_result(result, file=f, options=options)
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WriteTXT(ResultWriter):
|
||||
extension: str = "txt"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
for segment in result["segments"]:
|
||||
print(segment["text"].strip(), file=file, flush=True)
|
||||
|
||||
@@ -99,33 +100,81 @@ class SubtitlesWriter(ResultWriter):
|
||||
always_include_hours: bool
|
||||
decimal_marker: str
|
||||
|
||||
def iterate_result(self, result: dict):
|
||||
for segment in result["segments"]:
|
||||
segment_start = self.format_timestamp(segment["start"])
|
||||
segment_end = self.format_timestamp(segment["end"])
|
||||
segment_text = segment["text"].strip().replace("-->", "->")
|
||||
def iterate_result(self, result: dict, options: dict):
|
||||
raw_max_line_width: Optional[int] = options["max_line_width"]
|
||||
max_line_count: Optional[int] = options["max_line_count"]
|
||||
highlight_words: bool = options["highlight_words"]
|
||||
max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
|
||||
preserve_segments = max_line_count is None or raw_max_line_width is None
|
||||
|
||||
if word_timings := segment.get("words", None):
|
||||
all_words = [timing["word"] for timing in word_timings]
|
||||
all_words[0] = all_words[0].strip() # remove the leading space, if any
|
||||
last = segment_start
|
||||
for i, this_word in enumerate(word_timings):
|
||||
start = self.format_timestamp(this_word["start"])
|
||||
end = self.format_timestamp(this_word["end"])
|
||||
if last != start:
|
||||
yield last, start, segment_text
|
||||
def iterate_subtitles():
|
||||
line_len = 0
|
||||
line_count = 1
|
||||
# the next subtitle to yield (a list of word timings with whitespace)
|
||||
subtitle: list[dict] = []
|
||||
last = result["segments"][0]["words"][0]["start"]
|
||||
for segment in result["segments"]:
|
||||
for i, original_timing in enumerate(segment["words"]):
|
||||
timing = original_timing.copy()
|
||||
long_pause = not preserve_segments and timing["start"] - last > 3.0
|
||||
has_room = line_len + len(timing["word"]) <= max_line_width
|
||||
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
||||
if line_len > 0 and has_room and not long_pause and not seg_break:
|
||||
# line continuation
|
||||
line_len += len(timing["word"])
|
||||
else:
|
||||
# new line
|
||||
timing["word"] = timing["word"].strip()
|
||||
if (
|
||||
len(subtitle) > 0
|
||||
and max_line_count is not None
|
||||
and (long_pause or line_count >= max_line_count)
|
||||
or seg_break
|
||||
):
|
||||
# subtitle break
|
||||
yield subtitle
|
||||
subtitle = []
|
||||
line_count = 1
|
||||
elif line_len > 0:
|
||||
# line break
|
||||
line_count += 1
|
||||
timing["word"] = "\n" + timing["word"]
|
||||
line_len = len(timing["word"].strip())
|
||||
subtitle.append(timing)
|
||||
last = timing["start"]
|
||||
if len(subtitle) > 0:
|
||||
yield subtitle
|
||||
|
||||
yield start, end, "".join(
|
||||
[
|
||||
f"<u>{word}</u>" if j == i else word
|
||||
for j, word in enumerate(all_words)
|
||||
]
|
||||
)
|
||||
last = end
|
||||
if "words" in result["segments"][0]:
|
||||
for subtitle in iterate_subtitles():
|
||||
subtitle_start = self.format_timestamp(subtitle[0]["start"])
|
||||
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
|
||||
subtitle_text = "".join([word["word"] for word in subtitle])
|
||||
if highlight_words:
|
||||
last = subtitle_start
|
||||
all_words = [timing["word"] for timing in subtitle]
|
||||
for i, this_word in enumerate(subtitle):
|
||||
start = self.format_timestamp(this_word["start"])
|
||||
end = self.format_timestamp(this_word["end"])
|
||||
if last != start:
|
||||
yield last, start, subtitle_text
|
||||
|
||||
if last != segment_end:
|
||||
yield last, segment_end, segment_text
|
||||
else:
|
||||
yield start, end, "".join(
|
||||
[
|
||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||
if j == i
|
||||
else word
|
||||
for j, word in enumerate(all_words)
|
||||
]
|
||||
)
|
||||
last = end
|
||||
else:
|
||||
yield subtitle_start, subtitle_end, subtitle_text
|
||||
else:
|
||||
for segment in result["segments"]:
|
||||
segment_start = self.format_timestamp(segment["start"])
|
||||
segment_end = self.format_timestamp(segment["end"])
|
||||
segment_text = segment["text"].strip().replace("-->", "->")
|
||||
yield segment_start, segment_end, segment_text
|
||||
|
||||
def format_timestamp(self, seconds: float):
|
||||
@@ -141,9 +190,9 @@ class WriteVTT(SubtitlesWriter):
|
||||
always_include_hours: bool = False
|
||||
decimal_marker: str = "."
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
print("WEBVTT\n", file=file)
|
||||
for start, end, text in self.iterate_result(result):
|
||||
for start, end, text in self.iterate_result(result, options):
|
||||
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
|
||||
@@ -152,8 +201,10 @@ class WriteSRT(SubtitlesWriter):
|
||||
always_include_hours: bool = True
|
||||
decimal_marker: str = ","
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
for i, (start, end, text) in enumerate(
|
||||
self.iterate_result(result, options), start=1
|
||||
):
|
||||
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
|
||||
@@ -169,7 +220,7 @@ class WriteTSV(ResultWriter):
|
||||
|
||||
extension: str = "tsv"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
print("start", "end", "text", sep="\t", file=file)
|
||||
for segment in result["segments"]:
|
||||
print(round(1000 * segment["start"]), file=file, end="\t")
|
||||
@@ -180,11 +231,13 @@ class WriteTSV(ResultWriter):
|
||||
class WriteJSON(ResultWriter):
|
||||
extension: str = "json"
|
||||
|
||||
def write_result(self, result: dict, file: TextIO):
|
||||
def write_result(self, result: dict, file: TextIO, options: dict):
|
||||
json.dump(result, file)
|
||||
|
||||
|
||||
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
|
||||
def get_writer(
|
||||
output_format: str, output_dir: str
|
||||
) -> Callable[[dict, TextIO, dict], None]:
|
||||
writers = {
|
||||
"txt": WriteTXT,
|
||||
"vtt": WriteVTT,
|
||||
@@ -196,9 +249,9 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
|
||||
if output_format == "all":
|
||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||
|
||||
def write_all(result: dict, file: TextIO):
|
||||
def write_all(result: dict, file: TextIO, options: dict):
|
||||
for writer in all_writers:
|
||||
writer(result, file)
|
||||
writer(result, file, options)
|
||||
|
||||
return write_all
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "20230308"
|
||||
__version__ = "20230314"
|
||||
|
||||
Reference in New Issue
Block a user