apply formatting with black (#1038)

* applying black (with the default 88-column limit)

* add flake8

* add isort

* fix isort
This commit is contained in:
Jong Wook Kim
2023-03-06 18:50:37 -05:00
committed by GitHub
parent 500d0fe966
commit b80bcf610d
21 changed files with 533 additions and 227 deletions

View File

@@ -1,7 +1,7 @@
import subprocess
import warnings
from dataclasses import dataclass
from typing import List, TYPE_CHECKING
from typing import TYPE_CHECKING, List
import numba
import numpy as np
@@ -26,13 +26,16 @@ def median_filter(x: torch.Tensor, filter_width: int):
# `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
x = x[None, None, :]
assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"
assert (
filter_width > 0 and filter_width % 2 == 1
), "`filter_width` should be an odd number"
result = None
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
if x.is_cuda:
try:
from .triton_ops import median_filter_cuda
result = median_filter_cuda(x, filter_width)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
@@ -49,6 +52,7 @@ def median_filter(x: torch.Tensor, filter_width: int):
return result
@numba.jit
def backtrace(trace: np.ndarray):
i = trace.shape[0] - 1
@@ -106,7 +110,9 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
M, N = x.shape
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
x_skew = (
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
)
x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0
@@ -122,10 +128,12 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
trace.stride(0),
N,
M,
BLOCK_SIZE=BLOCK_SIZE
BLOCK_SIZE=BLOCK_SIZE,
)
trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1]
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
:, : N + 1
]
return backtrace(trace.cpu().numpy())
@@ -181,8 +189,10 @@ def find_alignment(
with torch.no_grad():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist()
for hook in hooks:
hook.remove()
@@ -196,7 +206,7 @@ def find_alignment(
weights = median_filter(weights, medfilt_width)
matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence):-1]
matrix = matrix[len(tokenizer.sot_sequence) : -1]
text_indices, time_indices = dtw(-matrix)
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
@@ -207,7 +217,8 @@ def find_alignment(
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
word_probabilities = [
np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
np.mean(text_token_probs[i:j])
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]
# hack: ensure the first and second word is not longer than twice the median word duration.
@@ -218,7 +229,8 @@ def find_alignment(
median_duration = np.median(word_durations)
max_duration = median_duration * 2
if len(word_durations) >= 2 and word_durations[1] > max_duration:
end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration)
boundary = max(end_times[2] / 2, end_times[2] - max_duration)
end_times[0] = start_times[1] = boundary
if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
start_times[0] = max(0, end_times[0] - max_duration)
@@ -271,19 +283,20 @@ def add_word_timestamps(
tokenizer: Tokenizer,
mel: torch.Tensor,
num_frames: int,
prepend_punctuations: str = "\"\'“¿([{-",
append_punctuations: str = "\"\'.。,!?::”)]}、",
**hyperparams,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
**kwargs,
):
if len(segments) == 0:
return
text_tokens = [t for segment in segments for t in segment["tokens"]]
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
segment_lengths = [len(s["tokens"]) for s in segments]
token_sources = np.repeat(np.arange(len(segments)), segment_lengths)
for segment in segments:
segment["words"] = []
@@ -295,7 +308,12 @@ def add_word_timestamps(
start = round(time_offset + timing.start, 2)
end = round(time_offset + timing.end, 2)
segment["words"].append(
dict(word=timing.word, start=start, end=end, probability=timing.probability)
dict(
word=timing.word,
start=start,
end=end,
probability=timing.probability,
)
)
for segment in segments: