apply formatting with black (#1038)

* applying black (with the default 88-column limit)

* add flake8

* add isort

* fix isort
Jong Wook Kim authored on 2023-03-06 18:50:37 -05:00, committed by GitHub
parent 500d0fe966
commit b80bcf610d
21 changed files with 533 additions and 227 deletions
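
For reference, the commands below sketch how to reproduce this commit locally. This is an assumption, not something recorded in the commit: it simply re-runs the three tools with the same flags the CI steps below pass to them, minus the check-only flags, so that files are rewritten in place.

    pip install black flake8 isort
    black -t py38 --include '(\.pyi?)$' .
    isort .
    flake8 --ignore E203,W503,W504,E501,E731,E741 .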

.flake8 (new file)

@@ -0,0 +1,4 @@
+[flake8]
+per-file-ignores =
+    */__init__.py: F401
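
The per-file F401 ignore is needed because package `__init__.py` files import names only to re-export them, which flake8 otherwise reports as "imported but unused". A minimal sketch of the exempted pattern, with hypothetical module names:

    # mypkg/__init__.py
    # without the ignore above, flake8 flags both lines as F401 ("imported but unused")
    from .core import run
    from .helpers import Helper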

.github/workflows/test.yml

@@ -22,4 +22,7 @@ jobs:
       - uses: actions/checkout@v2
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
       - run: pip install .["dev"]
+      - run: black --check --diff -t py38 --include '(\.pyi?)$' .
+      - run: isort --check --diff .
+      - run: flake8 --ignore E203,W503,W504,E501,E731,E741 .
       - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'

pyproject.toml (new file)

@@ -0,0 +1,8 @@
+[tool.black]
+
+[tool.isort]
+profile = "black"
+include_trailing_comma = true
+line_length = 88
+multi_line_output = 3
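
With `profile = "black"` these isort settings match black's output: 88-column lines, trailing commas, and `multi_line_output = 3` ("vertical hanging indent"). The effect shows up throughout this diff, e.g. the rewrapped import in tests/test_normalizer.py:

    from whisper.normalizers.english import (
        EnglishNumberNormalizer,
        EnglishSpellingNormalizer,
    )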

setup.py

@@ -2,7 +2,7 @@ import os
 import sys
 
 import pkg_resources
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 
 def read_version(fname="whisper/version.py"):

@@ -16,7 +16,10 @@ if sys.platform.startswith("linux"):
     try:
         import re
         import subprocess
-        version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
+
+        version_line = (
+            subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
+        )
         major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
         if (int(major), int(minor)) < (11, 4):
             # the last version supporting CUDA < 11.4

@@ -38,7 +41,8 @@ setup(
     url="https://github.com/openai/whisper",
     license="MIT",
     packages=find_packages(exclude=["tests*"]),
-    install_requires=requirements + [
+    install_requires=requirements
+    + [
         str(r)
         for r in pkg_resources.parse_requirements(
             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))

@@ -48,5 +52,5 @@ setup(
         "console_scripts": ["whisper=whisper.transcribe:cli"],
     },
     include_package_data=True,
-    extras_require={"dev": ["pytest", "scipy"]},
+    extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
 )

tests/test_audio.py

@@ -2,7 +2,7 @@ import os.path
 
 import numpy as np
 
-from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE
+from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram
 
 
 def test_audio():

tests/test_normalizer.py

@@ -1,7 +1,10 @@
 import pytest
 
 from whisper.normalizers import EnglishTextNormalizer
-from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer
+from whisper.normalizers.english import (
+    EnglishNumberNormalizer,
+    EnglishSpellingNormalizer,
+)
 
 
 @pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()])

tests/test_timing.py

@@ -1,16 +1,21 @@
-import pytest
 import numpy as np
+import pytest
 import scipy.ndimage
 import torch
 
 from whisper.timing import dtw_cpu, dtw_cuda, median_filter
 
 sizes = [
-    (10, 20), (32, 16), (123, 1500), (234, 189),
+    (10, 20),
+    (32, 16),
+    (123, 1500),
+    (234, 189),
 ]
 shapes = [
-    (10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
+    (10,),
+    (1, 15),
+    (4, 5, 345),
+    (6, 12, 240, 512),
 ]

@@ -68,8 +73,12 @@ def test_median_filter(shape):
     # using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
     pad_width = filter_width // 2
-    padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
-    scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
+    padded_x = np.pad(
+        x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect"
+    )
+    scipy_filtered = scipy.ndimage.median_filter(
+        padded_x, [1] * (x.ndim - 1) + [filter_width]
+    )
     scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
 
     assert np.allclose(filtered, scipy_filtered)

tests/test_transcribe.py

@@ -13,7 +13,9 @@ def test_transcribe(model_name: str):
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
 
     language = "en" if model_name.endswith(".en") else None
-    result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True)
+    result = model.transcribe(
+        audio_path, language=language, temperature=0.0, word_timestamps=True
+    )
     assert result["language"] == "en"
 
     transcription = result["text"].lower()

whisper/__init__.py

@@ -10,11 +10,10 @@ from tqdm import tqdm
 
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult, decode, detect_language
-from .model import Whisper, ModelDimensions
+from .model import ModelDimensions, Whisper
 from .transcribe import transcribe
 from .version import __version__
 
-
 _MODELS = {
     "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
     "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",

@@ -41,12 +40,11 @@ _ALIGNMENT_HEADS = {
     "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
     "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
     "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
-    "large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
-    "large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
+    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
+    "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
 }
 
 
 def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)

@@ -62,10 +60,18 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
         if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
             return model_bytes if in_memory else download_target
         else:
-            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+            warnings.warn(
+                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
+            )
 
     with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
-        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
+        with tqdm(
+            total=int(source.info().get("Content-Length")),
+            ncols=80,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as loop:
             while True:
                 buffer = source.read(8192)
                 if not buffer:

@@ -76,7 +82,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     model_bytes = open(download_target, "rb").read()
     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
-        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
+        raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
+        )
 
     return model_bytes if in_memory else download_target

@@ -86,7 +94,12 @@ def available_models() -> List[str]:
     return list(_MODELS.keys())
 
 
-def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
+def load_model(
+    name: str,
+    device: Optional[Union[str, torch.device]] = None,
+    download_root: str = None,
+    in_memory: bool = False,
+) -> Whisper:
     """
     Load a Whisper ASR model

@@ -111,15 +124,8 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
     if download_root is None:
-        download_root = os.path.join(
-            os.getenv(
-                "XDG_CACHE_HOME",
-                os.path.join(
-                    os.path.expanduser("~"), ".cache"
-                )
-            ),
-            "whisper"
-        )
+        default = os.path.join(os.path.expanduser("~"), ".cache")
+        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
 
     if name in _MODELS:
         checkpoint_file = _download(_MODELS[name], download_root, in_memory)

@@ -128,9 +134,13 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
         checkpoint_file = open(name, "rb").read() if in_memory else name
         alignment_heads = None
     else:
-        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+        raise RuntimeError(
+            f"Model {name} not found; available models = {available_models()}"
+        )
 
-    with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
+    with (
+        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
+    ) as fp:
         checkpoint = torch.load(fp, map_location=device)
     del checkpoint_file

whisper/__main__.py

@@ -1,4 +1,3 @@
 from .transcribe import cli
 
-
 cli()

whisper/audio.py

@@ -16,11 +16,13 @@ N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
-N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input
+N_FRAMES = exact_div(
+    N_SAMPLES, HOP_LENGTH
+)  # 3000: number of frames in a mel spectrogram input
 
 N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
-FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 100 mel frames in 1s (10ms each)
-TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 50 audio tokens in 1s (20ms each)
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
 
 
 def load_audio(file: str, sr: int = SAMPLE_RATE):

@@ -59,7 +61,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
     """
     if torch.is_tensor(array):
         if array.shape[axis] > length:
-            array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
+            array = array.index_select(
+                dim=axis, index=torch.arange(length, device=array.device)
+            )
 
         if array.shape[axis] < length:
             pad_widths = [(0, 0)] * array.ndim

@@ -89,11 +93,15 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
     )
     """
     assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
-    with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
+    with np.load(
+        os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+    ) as f:
         return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
 
 
-def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
+def log_mel_spectrogram(
+    audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS
+):
     """
     Compute the log-Mel spectrogram of

whisper/decoding.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch

@@ -16,7 +16,9 @@ if TYPE_CHECKING:
 
 
 @torch.no_grad()
-def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
+def detect_language(
+    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
+) -> Tuple[Tensor, List[dict]]:
     """
     Detect the spoken language in the audio, and return them as list of strings, along with the ids
     of the most probable language tokens and the probability distribution over all language tokens.

@@ -31,8 +33,13 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
     """
     if tokenizer is None:
         tokenizer = get_tokenizer(model.is_multilingual)
-    if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
-        raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
+    if (
+        tokenizer.language is None
+        or tokenizer.language_token not in tokenizer.sot_sequence
+    ):
+        raise ValueError(
+            "This model doesn't have language tokens so it can't perform lang id"
+        )
 
     single = mel.ndim == 2
     if single:

@@ -70,31 +77,36 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None)
 @dataclass(frozen=True)
 class DecodingOptions:
-    task: str = "transcribe"  # whether to perform X->X "transcribe" or X->English "translate"
-    language: Optional[str] = None  # language that the audio is in; uses detected language if None
+    # whether to perform X->X "transcribe" or X->English "translate"
+    task: str = "transcribe"
+
+    # language that the audio is in; uses detected language if None
+    language: Optional[str] = None
 
     # sampling-related options
     temperature: float = 0.0
     sample_len: Optional[int] = None  # maximum number of tokens to sample
-    best_of: Optional[int] = None  # number of independent samples to collect, when t > 0
-    beam_size: Optional[int] = None  # number of beams in beam search, when t == 0
-    patience: Optional[float] = None  # patience in beam search (https://arxiv.org/abs/2204.05424)
+    best_of: Optional[int] = None  # number of independent sample trajectories, if t > 0
+    beam_size: Optional[int] = None  # number of beams in beam search, if t == 0
+    patience: Optional[float] = None  # patience in beam search (arxiv:2204.05424)
 
-    # options for ranking generations (either beams or best-of-N samples)
-    length_penalty: Optional[float] = None  # "alpha" in Google NMT, None defaults to length norm
+    # "alpha" in Google NMT, or None for length norm, when ranking generations
+    # to select which to return among the beams or best-of-N samples
+    length_penalty: Optional[float] = None
 
-    # prompt, prefix, and token suppression
-    prompt: Optional[Union[str, List[int]]] = None  # text or tokens for the previous context
-    prefix: Optional[Union[str, List[int]]] = None  # text or tokens to prefix the current context
-    suppress_blank: bool = True  # this will suppress blank outputs
+    # text or tokens to feed as the prompt or the prefix; for more info:
+    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
+    prompt: Optional[Union[str, List[int]]] = None  # for the previous context
+    prefix: Optional[Union[str, List[int]]] = None  # to prefix the current context
 
     # list of tokens ids (or comma-separated token ids) to suppress
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
     suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
+    suppress_blank: bool = True  # this will suppress blank outputs
 
     # timestamp sampling options
     without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
-    max_initial_timestamp: Optional[float] = 1.0  # the initial timestamp cannot be later than this
+    max_initial_timestamp: Optional[float] = 1.0
 
     # implementation details
     fp16: bool = True  # use fp16 for most of the calculation

@@ -158,7 +170,9 @@ class PyTorchInference(Inference):
 
 
 class SequenceRanker:
-    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
+    def rank(
+        self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
+    ) -> List[int]:
         """
         Given a list of groups of samples and their cumulative log probabilities,
         return the indices of the samples in each group to select as the final result

@@ -196,7 +210,9 @@ class TokenDecoder:
     def reset(self):
         """Initialize any stateful variables for decoding a new sequence"""
 
-    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
         """Specify how to select the next token, based on the current trace and logits
 
         Parameters

@@ -251,7 +267,9 @@ class GreedyDecoder(TokenDecoder):
         self.temperature = temperature
         self.eot = eot
 
-    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
         if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
         else:

@@ -274,7 +292,13 @@ class GreedyDecoder(TokenDecoder):
 
 
 class BeamSearchDecoder(TokenDecoder):
-    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
+    def __init__(
+        self,
+        beam_size: int,
+        eot: int,
+        inference: Inference,
+        patience: Optional[float] = None,
+    ):
         self.beam_size = beam_size
         self.eot = eot
         self.inference = inference

@@ -282,12 +306,16 @@ class BeamSearchDecoder(TokenDecoder):
         self.max_candidates: int = round(beam_size * self.patience)
         self.finished_sequences = None
 
-        assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
+        assert (
+            self.max_candidates > 0
+        ), f"Invalid beam size ({beam_size}) or patience ({patience})"
 
     def reset(self):
         self.finished_sequences = None
 
-    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+    def update(
+        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
+    ) -> Tuple[Tensor, bool]:
         if tokens.shape[0] % self.beam_size != 0:
             raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

@@ -331,7 +359,9 @@ class BeamSearchDecoder(TokenDecoder):
 
         # add newly finished sequences to self.finished_sequences
         assert len(self.finished_sequences) == len(finished_sequences)
-        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
+        for previously_finished, newly_finished in zip(
+            self.finished_sequences, finished_sequences
+        ):
             for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                 if len(previously_finished) >= self.max_candidates:
                     break  # the candidate list is full

@@ -339,7 +369,8 @@ class BeamSearchDecoder(TokenDecoder):
 
         # mark as completed if all audio has enough number of samples
         completed = all(
-            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
+            len(sequences) >= self.max_candidates
+            for sequences in self.finished_sequences
         )
         return tokens, completed

@@ -347,7 +378,9 @@ class BeamSearchDecoder(TokenDecoder):
         # collect all finished sequences, including patience, and add unfinished ones if not enough
         sum_logprobs = sum_logprobs.cpu()
         for i, sequences in enumerate(self.finished_sequences):
-            if len(sequences) < self.beam_size:  # when not enough sequences are finished
+            if (
+                len(sequences) < self.beam_size
+            ):  # when not enough sequences are finished
                 for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                     sequence = preceding_tokens[i, j].tolist() + [self.eot]
                     sequences[tuple(sequence)] = sum_logprobs[i][j].item()

@@ -355,7 +388,8 @@ class BeamSearchDecoder(TokenDecoder):
                     break
 
         tokens: List[List[Tensor]] = [
-            [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
+            [torch.tensor(seq) for seq in sequences.keys()]
+            for sequences in self.finished_sequences
         ]
         sum_logprobs: List[List[float]] = [
             list(sequences.values()) for sequences in self.finished_sequences

@@ -399,7 +433,10 @@ class SuppressTokens(LogitFilter):
 
 class ApplyTimestampRules(LogitFilter):
     def __init__(
-        self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
+        self,
+        tokenizer: Tokenizer,
+        sample_begin: int,
+        max_initial_timestamp_index: Optional[int],
     ):
         self.tokenizer = tokenizer
         self.sample_begin = sample_begin

@@ -414,8 +451,12 @@ class ApplyTimestampRules(LogitFilter):
         for k in range(tokens.shape[0]):
             sampled_tokens = tokens[k, self.sample_begin :]
             seq = [t for t in sampled_tokens.tolist()]
-            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
-            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
+            last_was_timestamp = (
+                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
+            )
+            penultimate_was_timestamp = (
+                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
+            )
 
             if last_was_timestamp:
                 if penultimate_was_timestamp:  # has to be non-timestamp

@@ -423,7 +464,9 @@ class ApplyTimestampRules(LogitFilter):
                 else:  # cannot be normal text tokens
                     logits[k, : self.tokenizer.eot] = -np.inf
 
-            timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
+            timestamps = sampled_tokens[
+                sampled_tokens.ge(self.tokenizer.timestamp_begin)
+            ]
             if timestamps.numel() > 0:
                 # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                 logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf

@@ -434,13 +477,17 @@ class ApplyTimestampRules(LogitFilter):
 
             # apply the `max_initial_timestamp` option
             if self.max_initial_timestamp_index is not None:
-                last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+                last_allowed = (
+                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+                )
                 logits[:, last_allowed + 1 :] = -np.inf
 
         # if sum of probability over timestamps is above any other token, sample timestamp
         logprobs = F.log_softmax(logits.float(), dim=-1)
         for k in range(tokens.shape[0]):
-            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
+            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
+                dim=-1
+            )
             max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
             if timestamp_logprob > max_text_token_logprob:
                 logits[k, : self.tokenizer.timestamp_begin] = -np.inf

@@ -456,7 +503,9 @@ class DecodingTask:
         self.model = model
 
         language = options.language or "en"
-        tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
+        tokenizer = get_tokenizer(
+            model.is_multilingual, language=language, task=options.task
+        )
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)

@@ -496,9 +545,13 @@ class DecodingTask:
             precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
             max_initial_timestamp_index = None
             if options.max_initial_timestamp:
-                max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
+                max_initial_timestamp_index = round(
+                    self.options.max_initial_timestamp / precision
+                )
             self.logit_filters.append(
-                ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
+                ApplyTimestampRules(
+                    tokenizer, self.sample_begin, max_initial_timestamp_index
+                )
             )
 
     def _verify_options(self, options: DecodingOptions) -> DecodingOptions:

@@ -509,7 +562,9 @@ class DecodingTask:
                 raise ValueError("best_of with greedy sampling (T=0) is not compatible")
         if options.patience is not None and options.beam_size is None:
             raise ValueError("patience requires beam_size to be given")
-        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
+        if options.length_penalty is not None and not (
+            0 <= options.length_penalty <= 1
+        ):
             raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
 
         return options

@@ -519,7 +574,9 @@ class DecodingTask:
 
         if prefix := self.options.prefix:
             prefix_tokens = (
-                self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
+                self.tokenizer.encode(" " + prefix.strip())
+                if isinstance(prefix, str)
+                else prefix
             )
             if self.sample_len is not None:
                 max_prefix_len = self.n_ctx // 2 - self.sample_len

@@ -528,9 +585,15 @@ class DecodingTask:
 
         if prompt := self.options.prompt:
             prompt_tokens = (
-                self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
+                self.tokenizer.encode(" " + prompt.strip())
+                if isinstance(prompt, str)
+                else prompt
+            )
+            tokens = (
+                [self.tokenizer.sot_prev]
+                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
+                + tokens
             )
-            tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
 
         return tuple(tokens)

@@ -554,7 +617,7 @@ class DecodingTask:
                 self.tokenizer.translate,
                 self.tokenizer.sot,
                 self.tokenizer.sot_prev,
-                self.tokenizer.sot_lm
+                self.tokenizer.sot_lm,
             ]
         )
         if self.tokenizer.no_speech is not None:

@@ -567,14 +630,21 @@ class DecodingTask:
         if self.options.fp16:
             mel = mel.half()
 
-        if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
+        if mel.shape[-2:] == (
+            self.model.dims.n_audio_ctx,
+            self.model.dims.n_audio_state,
+        ):
             # encoded audio features are given; skip audio encoding
             audio_features = mel
         else:
             audio_features = self.model.encoder(mel)
 
-        if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
-            return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
+        if audio_features.dtype != (
+            torch.float16 if self.options.fp16 else torch.float32
+        ):
+            return TypeError(
+                f"audio_features has an incorrect dtype: {audio_features.dtype}"
+            )
 
         return audio_features

@@ -583,7 +653,9 @@ class DecodingTask:
         lang_probs = None
 
         if self.options.language is None or self.options.task == "lang_id":
-            lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
+            lang_tokens, lang_probs = self.model.detect_language(
+                audio_features, self.tokenizer
+            )
             languages = [max(probs, key=probs.get) for probs in lang_probs]
             if self.options.language is None:
                 tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

@@ -600,7 +672,9 @@ class DecodingTask:
         for i in range(self.sample_len):
             logits = self.inference.logits(tokens, audio_features)
 
-            if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
+            if (
+                i == 0 and self.tokenizer.no_speech is not None
+            ):  # save no_speech_probs
                 probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                 no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

@@ -634,8 +708,12 @@ class DecodingTask:
         languages, language_probs = self._detect_language(audio_features, tokens)
         if self.options.task == "lang_id":
             return [
-                DecodingResult(audio_features=features, language=language, language_probs=probs)
-                for features, language, probs in zip(audio_features, languages, language_probs)
+                DecodingResult(
+                    audio_features=features, language=language, language_probs=probs
+                )
+                for features, language, probs in zip(
+                    audio_features, languages, language_probs
+                )
             ]
 
         # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling

@@ -656,7 +734,8 @@ class DecodingTask:
        # get the final candidates for each group, and slice between the first sampled token and EOT
         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
         tokens: List[List[Tensor]] = [
-            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
+            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
+            for s in tokens
         ]
 
         # select the top-ranked sample in each group

@@ -665,9 +744,18 @@ class DecodingTask:
         texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
 
         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
-        avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
+        avg_logprobs: List[float] = [
+            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
+        ]
 
-        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
+        fields = (
+            texts,
+            languages,
+            tokens,
+            audio_features,
+            avg_logprobs,
+            no_speech_probs,
+        )
         if len(set(map(len, fields))) != 1:
             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

@@ -682,12 +770,16 @@ class DecodingTask:
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
             )
-            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
+            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
+                *fields
+            )
         ]
 
 
 @torch.no_grad()
-def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
+def decode(
+    model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()
+) -> Union[DecodingResult, List[DecodingResult]]:
     """
     Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

whisper/model.py

@@ -1,16 +1,15 @@
 import base64
 import gzip
 from dataclasses import dataclass
-from typing import Dict
-from typing import Iterable, Optional
+from typing import Dict, Iterable, Optional
 
 import numpy as np
 import torch
 import torch.nn.functional as F
-from torch import Tensor
-from torch import nn
+from torch import Tensor, nn
 
-from .decoding import detect_language as detect_language_function, decode as decode_function
+from .decoding import decode as decode_function
+from .decoding import detect_language as detect_language_function
 from .transcribe import transcribe as transcribe_function

@@ -36,12 +35,16 @@ class LayerNorm(nn.LayerNorm):
 class Linear(nn.Linear):
     def forward(self, x: Tensor) -> Tensor:
         return F.linear(
-            x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
+            x,
+            self.weight.to(x.dtype),
+            None if self.bias is None else self.bias.to(x.dtype),
         )
 
 
 class Conv1d(nn.Conv1d):
-    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
+    def _conv_forward(
+        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
+    ) -> Tensor:
         return super()._conv_forward(
             x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
         )

@@ -87,7 +90,9 @@ class MultiHeadAttention(nn.Module):
         wv, qk = self.qkv_attention(q, k, v, mask)
         return self.out(wv), qk
 
-    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
+    def qkv_attention(
+        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
+    ):
         n_batch, n_ctx, n_state = q.shape
         scale = (n_state // self.n_head) ** -0.25
         q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale

@@ -110,11 +115,15 @@ class ResidualAttentionBlock(nn.Module):
         self.attn = MultiHeadAttention(n_state, n_head)
         self.attn_ln = LayerNorm(n_state)
 
-        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
+        self.cross_attn = (
+            MultiHeadAttention(n_state, n_head) if cross_attention else None
+        )
         self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
 
         n_mlp = n_state * 4
-        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
+        self.mlp = nn.Sequential(
+            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+        )
         self.mlp_ln = LayerNorm(n_state)
 
     def forward(

@@ -132,7 +141,9 @@ class ResidualAttentionBlock(nn.Module):
 
 
 class AudioEncoder(nn.Module):
-    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
+    def __init__(
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+    ):
         super().__init__()
         self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
         self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)

@@ -163,14 +174,19 @@ class AudioEncoder(nn.Module):
 
 
 class TextDecoder(nn.Module):
-    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
+    def __init__(
+        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+    ):
         super().__init__()
 
         self.token_embedding = nn.Embedding(n_vocab, n_state)
         self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
 
         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
-            [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
+            [
+                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
+                for _ in range(n_layer)
+            ]
         )
         self.ln = LayerNorm(n_state)

@@ -185,14 +201,19 @@ class TextDecoder(nn.Module):
             the encoded audio features to be attended on
         """
         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
+        x = (
+            self.token_embedding(x)
+            + self.positional_embedding[offset : offset + x.shape[-1]]
+        )
         x = x.to(xa.dtype)
 
         for block in self.blocks:
             x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
 
         x = self.ln(x)
-        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
+        logits = (
+            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+        ).float()
 
         return logits

@@ -216,13 +237,19 @@ class Whisper(nn.Module):
             self.dims.n_text_layer,
         )
         # use the last half layers for alignment by default; see `set_alignment_heads()` below
-        all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
-        all_heads[self.dims.n_text_layer // 2:] = True
+        all_heads = torch.zeros(
+            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
+        )
+        all_heads[self.dims.n_text_layer // 2 :] = True
         self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
 
     def set_alignment_heads(self, dump: bytes):
-        array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
-        mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
+        array = np.frombuffer(
+            gzip.decompress(base64.b85decode(dump)), dtype=bool
+        ).copy()
+        mask = torch.from_numpy(array).reshape(
+            self.dims.n_text_layer, self.dims.n_text_head
+        )
         self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
 
     def embed_audio(self, mel: torch.Tensor):

@@ -231,7 +258,9 @@ class Whisper(nn.Module):
     def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
         return self.decoder(tokens, audio_features)
 
-    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
+    def forward(
+        self, mel: torch.Tensor, tokens: torch.Tensor
+    ) -> Dict[str, torch.Tensor]:
        return self.decoder(tokens, self.encoder(mel))
 
     @property

@@ -260,8 +289,9 @@ class Whisper(nn.Module):
         hooks = []
 
         def save_to_cache(module, _, output):
-            if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
-                cache[module] = output  # save as-is, for the first token or cross attention
+            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
+                # save as-is, for the first token or cross attention
+                cache[module] = output
             else:
                 cache[module] = torch.cat([cache[module], output], dim=1).detach()
             return cache[module]

whisper/normalizers/__init__.py

@@ -1,2 +1,2 @@
-from .basic import BasicTextNormalizer
-from .english import EnglishTextNormalizer
+from .basic import BasicTextNormalizer as BasicTextNormalizer
+from .english import EnglishTextNormalizer as EnglishTextNormalizer
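
The self-alias looks redundant, but `from .basic import BasicTextNormalizer as BasicTextNormalizer` is the conventional way to mark a re-export as intentional: strict type checkers, and linters without the F401 ignore added above, treat the plain form as a possibly unused import. A minimal sketch with a hypothetical package:

    # pkg/__init__.py
    from .mod import Thing           # implicit; strict checkers may not re-export it
    from .mod import Thing as Thing  # explicit re-export, recognized as intentional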

whisper/normalizers/basic.py

@@ -48,13 +48,16 @@ def remove_symbols(s: str):
     Replace any other markers, symbols, punctuations with a space, keeping diacritics
     """
     return "".join(
-        " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
+        " " if unicodedata.category(c)[0] in "MSP" else c
+        for c in unicodedata.normalize("NFKC", s)
     )
 
 
 class BasicTextNormalizer:
     def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
-        self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
+        self.clean = (
+            remove_symbols_and_diacritics if remove_diacritics else remove_symbols
+        )
         self.split_letters = split_letters
 
     def __call__(self, s: str):

@@ -66,6 +69,8 @@ class BasicTextNormalizer:
         if self.split_letters:
             s = " ".join(regex.findall(r"\X", s, regex.U))
 
-        s = re.sub(r"\s+", " ", s)  # replace any successive whitespace characters with a space
+        s = re.sub(
+            r"\s+", " ", s
+        )  # replace any successive whitespace characters with a space
 
         return s

whisper/normalizers/english.py

@@ -84,7 +84,8 @@ class EnglishNumberNormalizer:
             name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
         }
         self.tens_ordinal = {
-            name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
+            name.replace("y", "ieth"): (value, "th")
+            for name, value in self.tens.items()
         }
         self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}

@@ -108,7 +109,10 @@ class EnglishNumberNormalizer:
         self.multipliers_ordinal = {
             name + "th": (value, "th") for name, value in self.multipliers.items()
         }
-        self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
+        self.multipliers_suffixed = {
+            **self.multipliers_plural,
+            **self.multipliers_ordinal,
+        }
         self.decimals = {*self.ones, *self.tens, *self.zeros}
 
         self.preceding_prefixers = {

@@ -128,7 +132,8 @@ class EnglishNumberNormalizer:
             "cents": "¢",
         }
         self.prefixes = set(
-            list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
+            list(self.preceding_prefixers.values())
+            + list(self.following_prefixers.values())
         )
         self.suffixers = {
             "per": {"cent": "%"},

@@ -218,7 +223,9 @@ class EnglishNumberNormalizer:
                    if value is None:
                         value = ones
                     elif isinstance(value, str) or prev in self.ones:
-                        if prev in self.tens and ones < 10:  # replace the last zero with the digit
+                        if (
+                            prev in self.tens and ones < 10
+                        ):  # replace the last zero with the digit
                             assert value[-1] == "0"
                             value = value[:-1] + str(ones)
                         else:

@@ -522,14 +529,14 @@ class EnglishTextNormalizer:
         s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
         s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parenthesis
         s = re.sub(self.ignore_patterns, "", s)
-        s = re.sub(r"\s+'", "'", s)  # standardize when there's a space before an apostrophe
+        s = re.sub(r"\s+'", "'", s)  # when there's a space before an apostrophe
 
         for pattern, replacement in self.replacers.items():
             s = re.sub(pattern, replacement, s)
 
         s = re.sub(r"(\d),(\d)", r"\1\2", s)  # remove commas between digits
         s = re.sub(r"\.([^0-9]|$)", r" \1", s)  # remove periods not followed by numbers
-        s = remove_symbols_and_diacritics(s, keep=".%$¢€£")  # keep some symbols for numerics
+        s = remove_symbols_and_diacritics(s, keep=".%$¢€£")  # keep numeric symbols
 
         s = self.standardize_numbers(s)
         s = self.standardize_spellings(s)

@@ -538,6 +545,6 @@ class EnglishTextNormalizer:
         s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
         s = re.sub(r"([^0-9])%", r"\1 ", s)
 
-        s = re.sub(r"\s+", " ", s)  # replace any successive whitespace characters with a space
+        s = re.sub(r"\s+", " ", s)  # replace any successive whitespaces with a space
 
         return s

View File

@@ -1,7 +1,7 @@
import subprocess import subprocess
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, TYPE_CHECKING from typing import TYPE_CHECKING, List
import numba import numba
import numpy as np import numpy as np
@@ -26,13 +26,16 @@ def median_filter(x: torch.Tensor, filter_width: int):
# `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
x = x[None, None, :] x = x[None, None, :]
assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number" assert (
filter_width > 0 and filter_width % 2 == 1
), "`filter_width` should be an odd number"
result = None result = None
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect") x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
if x.is_cuda: if x.is_cuda:
try: try:
from .triton_ops import median_filter_cuda from .triton_ops import median_filter_cuda
result = median_filter_cuda(x, filter_width) result = median_filter_cuda(x, filter_width)
except (RuntimeError, subprocess.CalledProcessError): except (RuntimeError, subprocess.CalledProcessError):
warnings.warn( warnings.warn(
@@ -49,6 +52,7 @@ def median_filter(x: torch.Tensor, filter_width: int):
return result return result
@numba.jit @numba.jit
def backtrace(trace: np.ndarray): def backtrace(trace: np.ndarray):
i = trace.shape[0] - 1 i = trace.shape[0] - 1
@@ -106,7 +110,9 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
M, N = x.shape M, N = x.shape
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}" assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M) x_skew = (
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
)
x_skew = x_skew.T.contiguous() x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0 cost[0, 0] = 0
@@ -122,10 +128,12 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
trace.stride(0), trace.stride(0),
N, N,
M, M,
BLOCK_SIZE=BLOCK_SIZE BLOCK_SIZE=BLOCK_SIZE,
) )
trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1] trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
:, : N + 1
]
return backtrace(trace.cpu().numpy()) return backtrace(trace.cpu().numpy())
@@ -181,8 +189,10 @@ def find_alignment(
with torch.no_grad(): with torch.no_grad():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1) sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist() token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist()
for hook in hooks: for hook in hooks:
hook.remove() hook.remove()
@@ -196,7 +206,7 @@ def find_alignment(
weights = median_filter(weights, medfilt_width) weights = median_filter(weights, medfilt_width)
matrix = weights.mean(axis=0) matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence):-1] matrix = matrix[len(tokenizer.sot_sequence) : -1]
text_indices, time_indices = dtw(-matrix) text_indices, time_indices = dtw(-matrix)
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot]) words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
@@ -207,7 +217,8 @@ def find_alignment(
start_times = jump_times[word_boundaries[:-1]] start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]] end_times = jump_times[word_boundaries[1:]]
word_probabilities = [ word_probabilities = [
np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) np.mean(text_token_probs[i:j])
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
] ]
# hack: ensure the first and second word is not longer than twice the median word duration. # hack: ensure the first and second word is not longer than twice the median word duration.
@@ -218,7 +229,8 @@ def find_alignment(
median_duration = np.median(word_durations) median_duration = np.median(word_durations)
max_duration = median_duration * 2 max_duration = median_duration * 2
if len(word_durations) >= 2 and word_durations[1] > max_duration: if len(word_durations) >= 2 and word_durations[1] > max_duration:
end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration) boundary = max(end_times[2] / 2, end_times[2] - max_duration)
end_times[0] = start_times[1] = boundary
if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration: if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
start_times[0] = max(0, end_times[0] - max_duration) start_times[0] = max(0, end_times[0] - max_duration)
@@ -271,19 +283,20 @@ def add_word_timestamps(
     tokenizer: Tokenizer,
     mel: torch.Tensor,
     num_frames: int,
-    prepend_punctuations: str = "\"\'“¿([{-",
-    append_punctuations: str = "\"\'.。,!?::”)]}、",
-    **hyperparams,
+    prepend_punctuations: str = "\"'“¿([{-",
+    append_punctuations: str = "\"'.。,!?::”)]}、",
+    **kwargs,
 ):
     if len(segments) == 0:
         return

     text_tokens = [t for segment in segments for t in segment["tokens"]]
-    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
+    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)

     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
-    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
+    segment_lengths = [len(s["tokens"]) for s in segments]
+    token_sources = np.repeat(np.arange(len(segments)), segment_lengths)

     for segment in segments:
         segment["words"] = []

@@ -295,7 +308,12 @@ def add_word_timestamps(
             start = round(time_offset + timing.start, 2)
             end = round(time_offset + timing.end, 2)
             segment["words"].append(
-                dict(word=timing.word, start=start, end=end, probability=timing.probability)
+                dict(
+                    word=timing.word,
+                    start=start,
+                    end=end,
+                    probability=timing.probability,
+                )
             )

     for segment in segments:
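
For reference, each entry appended above has this shape (values invented for illustration):

# hypothetical contents of segment["words"] after the append above
segment_words = [
    {"word": " Hello", "start": 7.24, "end": 7.58, "probability": 0.99},
    {"word": " world", "start": 7.58, "end": 7.86, "probability": 0.97},
]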

whisper/tokenizer.py

@@ -1,7 +1,7 @@
 import os
 import string
 from dataclasses import dataclass
-from functools import lru_cache, cached_property
+from functools import cached_property, lru_cache
 from typing import List, Optional, Tuple, Union

 import numpy as np

@@ -138,7 +138,9 @@ class Tokenizer:
     def encode(self, text, **kwargs):
         return self.tokenizer.encode(text, **kwargs)

-    def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
+    def decode(
+        self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs
+    ):
         return self.tokenizer.decode(token_ids, **kwargs)

     def decode_with_timestamps(self, tokens) -> str:

@@ -154,8 +156,9 @@ class Tokenizer:
                 outputs.append([])
             else:
                 outputs[-1].append(token)
-        outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
-        return "".join(outputs)
+        return "".join(
+            [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
+        )

     @cached_property
     def eot(self) -> int:

@@ -197,7 +200,7 @@ class Tokenizer:
     def language_token(self) -> int:
         """Returns the token id corresponding to the value of the `language` field"""
         if self.language is None:
-            raise ValueError(f"This tokenizer does not have language token configured")
+            raise ValueError("This tokenizer does not have language token configured")

         additional_tokens = dict(
             zip(

@@ -242,8 +245,10 @@ class Tokenizer:
        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
-        symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
-        symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
+        symbols += (
+            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+        )

         # symbols that may be a single token or multiple tokens depending on the tokenizer.
         # In case they're multiple tokens, suppress the first token, which is safe because:

@@ -255,7 +260,10 @@ class Tokenizer:
         # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
         result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
         for symbol in symbols + list(miscellaneous):
-            for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
+            for tokens in [
+                self.tokenizer.encode(symbol),
+                self.tokenizer.encode(" " + symbol),
+            ]:
                 if len(tokens) == 1 or symbol in miscellaneous:
                     result.add(tokens[0])
@@ -367,4 +375,6 @@ def get_tokenizer(
     if task is not None:
         sot_sequence.append(transcribe if task == "transcribe" else translate)

-    return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
+    return Tokenizer(
+        tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
+    )
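
A minimal usage sketch of the function above (run against a local Whisper checkout; the printed shapes are illustrative):

from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
# sot_sequence is (sot, language token, task token) for multilingual setups
print(tokenizer.sot_sequence)
# non_speech_tokens is the suppression set built above; DecodingOptions
# expands its default suppress_tokens="-1" to this set
print(len(tokenizer.non_speech_tokens), tokenizer.non_speech_tokens[:5])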

whisper/transcribe.py

@@ -1,17 +1,32 @@
 import argparse
 import os
 import warnings
-from typing import Optional, Tuple, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Tuple, Union

 import numpy as np
 import torch
 import tqdm

-from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, pad_or_trim
+from .audio import (
+    FRAMES_PER_SECOND,
+    HOP_LENGTH,
+    N_FRAMES,
+    SAMPLE_RATE,
+    log_mel_spectrogram,
+    pad_or_trim,
+)
 from .decoding import DecodingOptions, DecodingResult
 from .timing import add_word_timestamps
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
-from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
+from .utils import (
+    exact_div,
+    format_timestamp,
+    get_writer,
+    make_safe,
+    optional_float,
+    optional_int,
+    str2bool,
+)

 if TYPE_CHECKING:
     from .model import Whisper

@@ -29,8 +44,8 @@ def transcribe(
     condition_on_previous_text: bool = True,
     initial_prompt: Optional[str] = None,
     word_timestamps: bool = False,
-    prepend_punctuations: str = "\"\'“¿([{-",
-    append_punctuations: str = "\"\'.。,!?::”)]}、",
+    prepend_punctuations: str = "\"'“¿([{-",
+    append_punctuations: str = "\"'.。,!?::”)]}、",
     **decode_options,
 ):
     """

@@ -108,12 +123,16 @@ def transcribe(
             decode_options["language"] = "en"
         else:
             if verbose:
-                print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
+                print(
+                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
+                )
             mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
             _, probs = model.detect_language(mel_segment)
             decode_options["language"] = max(probs, key=probs.get)
             if verbose is not None:
-                print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
+                print(
+                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
+                )

     language: str = decode_options["language"]
     task: str = decode_options.get("task", "transcribe")
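
The standalone equivalent of the detection step above, as a sketch ("audio.mp3" is a placeholder path):

import whisper
from whisper.audio import N_FRAMES

model = whisper.load_model("tiny")
mel = whisper.log_mel_spectrogram(whisper.load_audio("audio.mp3"))
mel_segment = whisper.pad_or_trim(mel, N_FRAMES).to(model.device)
_, probs = model.detect_language(mel_segment)
print(max(probs, key=probs.get))  # e.g. "en"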
@@ -123,7 +142,9 @@ def transcribe(
         warnings.warn("Word-level timestamps on translations may not be reliable.")

     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
-        temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
+        temperatures = (
+            [temperature] if isinstance(temperature, (int, float)) else temperature
+        )
         decode_result = None

         for t in temperatures:

@@ -140,9 +161,15 @@ def transcribe(
             decode_result = model.decode(segment, options)

             needs_fallback = False
-            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
+            if (
+                compression_ratio_threshold is not None
+                and decode_result.compression_ratio > compression_ratio_threshold
+            ):
                 needs_fallback = True  # too repetitive
-            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
+            if (
+                logprob_threshold is not None
+                and decode_result.avg_logprob < logprob_threshold
+            ):
                 needs_fallback = True  # average log probability is too low

             if not needs_fallback:
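
Condensed, the retry policy reads as below; a sketch, not the function itself, with the threshold defaults taken from the CLI:

temperatures = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)  # default escalation sequence

def is_acceptable(result, compression_ratio_threshold=2.4, logprob_threshold=-1.0):
    # mirrors the two checks above: retry at a higher temperature when the
    # decoded text is too repetitive or the model's confidence is too low
    if result.compression_ratio > compression_ratio_threshold:
        return False
    if result.avg_logprob < logprob_threshold:
        return False
    return True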
@@ -186,7 +213,9 @@ def transcribe(
     # show the progress bar when verbose is False (if True, transcribed text will be printed)
     num_frames = mel.shape[-1]
-    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
+    with tqdm.tqdm(
+        total=num_frames, unit="frames", disable=verbose is not False
+    ) as pbar:
         while seek < num_frames:
             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
             mel_segment = mel[:, seek:]

@@ -201,7 +230,10 @@ def transcribe(
             if no_speech_threshold is not None:
                 # no voice activity check
                 should_skip = result.no_speech_prob > no_speech_threshold
-                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+                if (
+                    logprob_threshold is not None
+                    and result.avg_logprob > logprob_threshold
+                ):
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     should_skip = False

@@ -214,22 +246,35 @@ def transcribe(
             current_tokens = []

             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
-            if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
-                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]:
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
+                0
+            ].add_(1)
+            if (
+                len(consecutive) > 0
+            ):  # if the output contains two consecutive timestamp tokens
+                if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [
+                    False,
+                    True,
+                ]:
                     consecutive = consecutive.tolist() + [len(tokens)]
                 last_slice = 0
                 for current_slice in consecutive:
                     sliced_tokens = tokens[last_slice:current_slice]
-                    start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
-                    end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
-                    current_segments.append(new_segment(
-                        start=time_offset + start_timestamp_pos * time_precision,
-                        end=time_offset + end_timestamp_pos * time_precision,
-                        tokens=sliced_tokens,
-                        result=result,
-                    ))
+                    start_timestamp_pos = (
+                        sliced_tokens[0].item() - tokenizer.timestamp_begin
+                    )
+                    end_timestamp_pos = (
+                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                    )
+                    current_segments.append(
+                        new_segment(
+                            start=time_offset + start_timestamp_pos * time_precision,
+                            end=time_offset + end_timestamp_pos * time_precision,
+                            tokens=sliced_tokens,
+                            result=result,
+                        )
+                    )
                     current_tokens.append(sliced_tokens.tolist())
                     last_slice = current_slice
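
The pairing logic above is compact enough to sanity-check in isolation (toy mask, invented values):

import torch

# True wherever a decoded token is a timestamp token
timestamp_tokens = torch.tensor([False, False, True, True, False, True, True])
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
print(consecutive.tolist())  # [3, 6]: index of the second token in each pair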
@@ -238,23 +283,32 @@ def transcribe(
                     seek += segment_size
                 else:
                     # otherwise, ignore the unfinished segment and seek to the last timestamp
-                    last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+                    last_timestamp_pos = (
+                        tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+                    )
                     seek += last_timestamp_pos * input_stride
                 all_tokens.extend(tokens[: last_slice + 1].tolist())
             else:
                 duration = segment_duration
                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-                if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
+                if (
+                    len(timestamps) > 0
+                    and timestamps[-1].item() != tokenizer.timestamp_begin
+                ):
                     # no consecutive timestamps but it has a timestamp; use the last one.
-                    last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
+                    last_timestamp_pos = (
+                        timestamps[-1].item() - tokenizer.timestamp_begin
+                    )
                     duration = last_timestamp_pos * time_precision

-                current_segments.append(new_segment(
-                    start=time_offset,
-                    end=time_offset + duration,
-                    tokens=tokens,
-                    result=result,
-                ))
+                current_segments.append(
+                    new_segment(
+                        start=time_offset,
+                        end=time_offset + duration,
+                        tokens=tokens,
+                        result=result,
+                    )
+                )
                 current_tokens.append(tokens.tolist())
                 seek += segment_size

@@ -272,9 +326,13 @@ def transcribe(
                     prepend_punctuations=prepend_punctuations,
                     append_punctuations=append_punctuations,
                 )
-                word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]]
+                word_end_timestamps = [
+                    w["end"] for s in current_segments for w in s["words"]
+                ]
                 if len(consecutive) > 0 and len(word_end_timestamps) > 0:
-                    seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND)
+                    seek_shift = round(
+                        (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
+                    )
                     if seek_shift > 0:
                         seek = previous_seek + seek_shift

@@ -293,21 +351,24 @@ def transcribe(
                     current_tokens[i] = []

             all_segments.extend(current_segments)
-            all_tokens.extend([token for segment in current_tokens for token in segment])
+            all_tokens.extend(
+                [token for segment in current_tokens for token in segment]
+            )

             # update progress bar
             pbar.update(min(num_frames, seek) - previous_seek)

     return dict(
-        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]),
+        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
         segments=all_segments,
-        language=language
+        language=language,
     )

 def cli():
     from . import available_models

+    # fmt: off
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
@@ -339,6 +400,7 @@ def cli():
     parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
     parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+    # fmt: on

     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")

@@ -350,7 +412,9 @@ def cli():
     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
         if args["language"] is not None:
-            warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
+            warnings.warn(
+                f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
+            )
         args["language"] = "en"

     temperature = args.pop("temperature")

@@ -363,6 +427,7 @@ def cli():
         torch.set_num_threads(threads)

     from . import load_model
+
     model = load_model(model_name, device=device, download_root=model_dir)
     writer = get_writer(output_format, output_dir)

@@ -371,5 +436,5 @@ def cli():
         writer(result, audio_path)

-if __name__ == '__main__':
+if __name__ == "__main__":
     cli()

whisper/triton_ops.py

@@ -1,8 +1,7 @@
-import math
+from functools import lru_cache

 import numpy as np
 import torch
-from functools import lru_cache

 try:
     import triton

@@ -12,7 +11,9 @@ except ImportError:
 @triton.jit
-def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
+def dtw_kernel(
+    cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
+):
     offsets = tl.arange(0, BLOCK_SIZE)
     mask = offsets < M

@@ -42,37 +43,53 @@ def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_
 @lru_cache(maxsize=None)
 def median_kernel(filter_width: int):
     @triton.jit
-    def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr):  # x.shape[-1] == filter_width
+    def kernel(
+        y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
+    ):  # x.shape[-1] == filter_width
         row_idx = tl.program_id(0)
         offsets = tl.arange(0, BLOCK_SIZE)
         mask = offsets < y_stride
-        x_ptr = x + row_idx * x_stride
+        x_ptr = x + row_idx * x_stride  # noqa: F841
         y_ptr = y + row_idx * y_stride

-        LOAD_ALL_ROWS_HERE
-        BUBBLESORT_HERE
-        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)
+        LOAD_ALL_ROWS_HERE  # noqa: F821
+        BUBBLESORT_HERE  # noqa: F821
+        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)  # noqa: F821

     kernel = triton.JITFunction(kernel.fn)
-    kernel.src = kernel.src.replace("        LOAD_ALL_ROWS_HERE", "\n".join([
-        f"        row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
-        for i in range(filter_width)
-    ]))
-    kernel.src = kernel.src.replace("        BUBBLESORT_HERE", "\n\n".join([
-        "\n\n".join([
-            "\n".join([
-                f"        smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
-                f"        larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
-                f"        row{j} = smaller",
-                f"        row{j + 1} = larger",
-            ])
-            for j in range(filter_width - i - 1)
-        ])
-        for i in range(filter_width // 2 + 1)
-    ]))
+    kernel.src = kernel.src.replace(
+        "        LOAD_ALL_ROWS_HERE",
+        "\n".join(
+            [
+                f"        row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
+                for i in range(filter_width)
+            ]
+        ),
+    )
+    kernel.src = kernel.src.replace(
+        "        BUBBLESORT_HERE",
+        "\n\n".join(
+            [
+                "\n\n".join(
+                    [
+                        "\n".join(
+                            [
+                                f"        smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
+                                f"        larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
+                                f"        row{j} = smaller",
+                                f"        row{j + 1} = larger",
+                            ]
+                        )
+                        for j in range(filter_width - i - 1)
+                    ]
+                )
+                for i in range(filter_width // 2 + 1)
+            ]
+        ),
+    )
     kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
     return kernel
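
To see what the substitution generates, here is a pure-Python replay for filter_width=3 (a sketch; indentation inside the strings is stripped for brevity, and no triton is needed to run it):

filter_width = 3
loads = "\n".join(
    f"row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" for i in range(filter_width)
)
sort_passes = "\n\n".join(
    "\n\n".join(
        "\n".join(
            [
                f"smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
                f"larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
                f"row{j} = smaller",
                f"row{j + 1} = larger",
            ]
        )
        for j in range(filter_width - i - 1)
    )
    for i in range(filter_width // 2 + 1)
)
print(loads)        # loads row0..row2
print(sort_passes)  # two partial bubble-sort passes; row1 ends up as the median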

whisper/utils.py

@@ -7,11 +7,14 @@ from typing import Callable, TextIO
 system_encoding = sys.getdefaultencoding()

 if system_encoding != "utf-8":
+
     def make_safe(string):
         # replaces any character not representable using the system default encoding with an '?',
         # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
         return string.encode(system_encoding, errors="replace").decode(system_encoding)
+
 else:
+
     def make_safe(string):
         # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
         return string

@@ -43,7 +46,9 @@ def compression_ratio(text) -> float:
     return len(text_bytes) / len(zlib.compress(text_bytes))
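
A quick feel for the metric that the transcription fallback thresholds compare against (a restated sketch; the strings are invented and exact ratios vary):

import zlib

def compression_ratio(text) -> float:
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))

varied = "The committee reviewed seventeen unrelated proposals before lunch."
looped = "I'm sorry. " * 20  # degenerate repetition, the failure mode guarded against

print(compression_ratio(varied))  # close to 1
print(compression_ratio(looped))  # well above the 2.4 default threshold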
-def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
+def format_timestamp(
+    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
+):
     assert seconds >= 0, "non-negative timestamp expected"
     milliseconds = round(seconds * 1000.0)

@@ -57,7 +62,9 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
     milliseconds -= seconds * 1_000

     hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
-    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+    return (
+        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+    )

 class ResultWriter:
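
Worked examples for the helper above (the second call matches the SRT writer's settings below):

print(format_timestamp(61.5))  # "01:01.500"
print(format_timestamp(61.5, always_include_hours=True, decimal_marker=","))
# "00:01:01,500"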
@@ -68,7 +75,9 @@ class ResultWriter:
     def __call__(self, result: dict, audio_path: str):
         audio_basename = os.path.basename(audio_path)
-        output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
+        output_path = os.path.join(
+            self.output_dir, audio_basename + "." + self.extension
+        )

         with open(output_path, "w", encoding="utf-8") as f:
             self.write_result(result, file=f)

@@ -82,7 +91,7 @@ class WriteTXT(ResultWriter):
     def write_result(self, result: dict, file: TextIO):
         for segment in result["segments"]:
-            print(segment['text'].strip(), file=file, flush=True)
+            print(segment["text"].strip(), file=file, flush=True)

 class SubtitlesWriter(ResultWriter):

@@ -93,7 +102,7 @@ class SubtitlesWriter(ResultWriter):
         for segment in result["segments"]:
             segment_start = self.format_timestamp(segment["start"])
             segment_end = self.format_timestamp(segment["end"])
-            segment_text = segment['text'].strip().replace('-->', '->')
+            segment_text = segment["text"].strip().replace("-->", "->")

             if word_timings := segment.get("words", None):
                 all_words = [timing["word"] for timing in word_timings]

@@ -106,7 +115,10 @@ class SubtitlesWriter(ResultWriter):
                         yield last, start, segment_text
                     yield start, end, "".join(
-                        [f"<u>{word}</u>" if j == i else word for j, word in enumerate(all_words)]
+                        [
+                            f"<u>{word}</u>" if j == i else word
+                            for j, word in enumerate(all_words)
+                        ]
                     )
                     last = end
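
The karaoke-style highlight reduces to this one-liner (toy inputs):

all_words = [" Hello", " word", " timestamps"]
i = 1  # the word currently being spoken
line = "".join(f"<u>{w}</u>" if j == i else w for j, w in enumerate(all_words))
print(line)  # " Hello<u> word</u> timestamps"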
@@ -126,7 +138,7 @@ class SubtitlesWriter(ResultWriter):
 class WriteVTT(SubtitlesWriter):
     extension: str = "vtt"
     always_include_hours: bool = False
-    decimal_marker: str = '.'
+    decimal_marker: str = "."

     def write_result(self, result: dict, file: TextIO):
         print("WEBVTT\n", file=file)

@@ -137,7 +149,7 @@ class WriteVTT(SubtitlesWriter):
 class WriteSRT(SubtitlesWriter):
     extension: str = "srt"
     always_include_hours: bool = True
-    decimal_marker: str = ','
+    decimal_marker: str = ","

     def write_result(self, result: dict, file: TextIO):
         for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):

@@ -153,14 +165,15 @@ class WriteTSV(ResultWriter):
     an environment setting a language encoding that causes the decimal in a floating point number
     to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
     """
+
     extension: str = "tsv"

     def write_result(self, result: dict, file: TextIO):
         print("start", "end", "text", sep="\t", file=file)
         for segment in result["segments"]:
-            print(round(1000 * segment['start']), file=file, end="\t")
-            print(round(1000 * segment['end']), file=file, end="\t")
-            print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
+            print(round(1000 * segment["start"]), file=file, end="\t")
+            print(round(1000 * segment["end"]), file=file, end="\t")
+            print(segment["text"].strip().replace("\t", " "), file=file, flush=True)

 class WriteJSON(ResultWriter):

@@ -189,4 +202,3 @@ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO],
         return write_all
     return writers[output_format](output_dir)
-