apply formatting with black (#1038)
* applying black (with the default 88-column limit)
* add flake8
* add isort
* fix isort
@@ -1,7 +1,7 @@
 import subprocess
 import warnings
 from dataclasses import dataclass
-from typing import List, TYPE_CHECKING
+from typing import TYPE_CHECKING, List

 import numba
 import numpy as np
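Note: the only code change in this hunk comes from isort. With isort's default `order_by_type` behavior, all-uppercase names (constants) sort ahead of CamelCase ones within a `from` import, so `TYPE_CHECKING` moves in front of `List`. For the three tools named in the commit message to coexist, flake8 is commonly configured with `max-line-length = 88` and `E203` in its ignore list, and isort with its black profile; that is the typical setup, not necessarily the exact configuration this commit ships.

# isort's order_by_type: constants (ALL_CAPS) before classes (CamelCase) before others
from typing import TYPE_CHECKING, List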
@@ -26,13 +26,16 @@ def median_filter(x: torch.Tensor, filter_width: int):
     # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
     x = x[None, None, :]

-    assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"
+    assert (
+        filter_width > 0 and filter_width % 2 == 1
+    ), "`filter_width` should be an odd number"

     result = None
     x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
     if x.is_cuda:
         try:
             from .triton_ops import median_filter_cuda
+
             result = median_filter_cuda(x, filter_width)
         except (RuntimeError, subprocess.CalledProcessError):
             warnings.warn(
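Note: for readers skimming the diff, the function being reformatted reflect-pads the input and takes a windowed median along the last axis. A minimal pure-PyTorch sketch for a 1D signal (a hypothetical `median_filter_ref`, not the repo's implementation or its Triton fast path):

import torch
import torch.nn.functional as F


def median_filter_ref(x: torch.Tensor, filter_width: int) -> torch.Tensor:
    # reflect padding requires a 3D/4D input, hence the two dummy dimensions
    assert filter_width > 0 and filter_width % 2 == 1
    pad = filter_width // 2
    padded = F.pad(x[None, None, :], (pad, pad), mode="reflect")
    windows = padded[0, 0].unfold(0, filter_width, 1)  # shape: (len(x), filter_width)
    return windows.median(dim=-1).values  # one median per sliding window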
@@ -49,6 +52,7 @@ def median_filter(x: torch.Tensor, filter_width: int):

     return result

+
 @numba.jit
 def backtrace(trace: np.ndarray):
     i = trace.shape[0] - 1
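Note: only the first line of `backtrace` appears in this hunk. For context, a DTW backtrace of this shape walks a trace matrix from the bottom-right corner back to the origin, emitting one (text, time) index pair per step. A generic sketch; the step codes and boundary handling here are assumptions, not the repo's exact code:

import numpy as np


def backtrace_sketch(trace: np.ndarray) -> np.ndarray:
    # assumed step codes: 0 = diagonal, 1 = up (i - 1), 2 = left (j - 1)
    i, j = trace.shape[0] - 1, trace.shape[1] - 1
    path = []
    while i > 0 or j > 0:
        path.append((i - 1, j - 1))
        if i == 0:
            j -= 1  # pinned to the top row: can only move left
        elif j == 0:
            i -= 1  # pinned to the first column: can only move up
        elif trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        else:
            j -= 1
    return np.array(path)[::-1].T  # row 0: text indices, row 1: time indices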
@@ -106,7 +110,9 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
     M, N = x.shape
     assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"

-    x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
+    x_skew = (
+        F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
+    )
     x_skew = x_skew.T.contiguous()
     cost = torch.ones(N + M + 2, M + 2) * np.inf
     cost[0, 0] = 0
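Note: the wrapped line implements a skew trick: padding every row with `inf` and re-windowing the flat buffer shifts row `i` right by `i` positions, so after the transpose each column of `x_skew` holds one anti-diagonal and the DTW recurrence can sweep the matrix one diagonal at a time. A small NumPy illustration with hypothetical sizes:

import numpy as np

M, N = 3, 4
x = np.arange(M * N, dtype=float).reshape(M, N)
# pad each row with M + 1 infs, then re-window the flat buffer to rows of N + M
x_pad = np.pad(x, ((0, 0), (0, M + 1)), constant_values=np.inf)
x_skew = x_pad.flatten()[: M * (N + M)].reshape(M, N + M)
assert x_skew[1, 1 + 2] == x[1, 2]  # each row i ends up shifted right by i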
@@ -122,10 +128,12 @@ def dtw_cuda(x, BLOCK_SIZE=1024):
         trace.stride(0),
         N,
         M,
-        BLOCK_SIZE=BLOCK_SIZE
+        BLOCK_SIZE=BLOCK_SIZE,
     )

-    trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1]
+    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
+        :, : N + 1
+    ]
     return backtrace(trace.cpu().numpy())
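Note: the comma added after `BLOCK_SIZE=BLOCK_SIZE` is black's doing: black appends a trailing comma to any call it keeps split across lines, and in turn treats an existing ("magic") trailing comma as an instruction never to collapse that call back onto one line. A toy example with a hypothetical `compute`:

def compute(a, b):  # hypothetical stand-in
    return a + b


# with the trailing comma, black keeps the call exploded:
result = compute(
    1,
    2,
)
# without it, black would collapse this to: result = compute(1, 2)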
@@ -181,8 +189,10 @@ def find_alignment(

     with torch.no_grad():
         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
-        token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1)
-        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()
+        sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
+        token_probs = sampled_logits.softmax(dim=-1)
+        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
+        text_token_probs = text_token_probs.tolist()

     for hook in hooks:
         hook.remove()
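Note: the split into `sampled_logits` keeps each statement under 88 columns; the spacing inside the slice is black applying the PEP 8 rule that a slice colon with a complex bound is treated like a binary operator and spaced on both sides. This is exactly the pattern flake8's E203 would flag, which is why E203 is conventionally ignored when black is in charge:

x = list(range(20))
prefix = [0, 0, 0]
a = x[1:10]           # simple bounds: no spaces around the colon
b = x[len(prefix) :]  # complex bound: spaces on both sides (flake8 E203 territory)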
@@ -196,7 +206,7 @@ def find_alignment(
     weights = median_filter(weights, medfilt_width)

     matrix = weights.mean(axis=0)
-    matrix = matrix[len(tokenizer.sot_sequence):-1]
+    matrix = matrix[len(tokenizer.sot_sequence) : -1]
     text_indices, time_indices = dtw(-matrix)

     words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
@@ -207,7 +217,8 @@ def find_alignment(
     start_times = jump_times[word_boundaries[:-1]]
     end_times = jump_times[word_boundaries[1:]]
     word_probabilities = [
-        np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
+        np.mean(text_token_probs[i:j])
+        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
     ]

     # hack: ensure the first and second word is not longer than twice the median word duration.
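Note: the reflowed comprehension averages the per-token probabilities between consecutive word boundaries. A tiny worked example with made-up numbers:

import numpy as np

text_token_probs = [0.9, 0.8, 0.95, 0.7]
word_boundaries = [0, 2, 4]  # two words covering tokens [0:2] and [2:4]
word_probabilities = [
    np.mean(text_token_probs[i:j])
    for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]
# -> [0.85, 0.825] (up to floating-point rounding)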
@@ -218,7 +229,8 @@ def find_alignment(
     median_duration = np.median(word_durations)
     max_duration = median_duration * 2
     if len(word_durations) >= 2 and word_durations[1] > max_duration:
-        end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration)
+        boundary = max(end_times[2] / 2, end_times[2] - max_duration)
+        end_times[0] = start_times[1] = boundary
     if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration:
         start_times[0] = max(0, end_times[0] - max_duration)
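Note: pulling the `max(...)` out into `boundary` is a readability refactor with identical behavior: when the second word runs longer than twice the median word duration, the shared boundary between the first two words is recomputed from the end of the third word. With made-up times:

import numpy as np

start_times = np.array([0.0, 1.0, 4.0])
end_times = np.array([1.0, 4.0, 4.5])
word_durations = end_times - start_times      # [1.0, 3.0, 0.5]
max_duration = np.median(word_durations) * 2  # 2.0
# word_durations[1] = 3.0 > 2.0, so move the first/second word boundary:
boundary = max(end_times[2] / 2, end_times[2] - max_duration)  # max(2.25, 2.5) = 2.5
end_times[0] = start_times[1] = boundary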
@@ -271,19 +283,20 @@ def add_word_timestamps(
     tokenizer: Tokenizer,
     mel: torch.Tensor,
     num_frames: int,
-    prepend_punctuations: str = "\"\'“¿([{-",
-    append_punctuations: str = "\"\'.。,,!!??::”)]}、",
-    **hyperparams,
+    prepend_punctuations: str = "\"'“¿([{-",
+    append_punctuations: str = "\"'.。,,!!??::”)]}、",
+    **kwargs,
 ):
     if len(segments) == 0:
         return

     text_tokens = [t for segment in segments for t in segment["tokens"]]
-    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams)
+    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
     merge_punctuations(alignment, prepend_punctuations, append_punctuations)

     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
-    token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments])
+    segment_lengths = [len(s["tokens"]) for s in segments]
+    token_sources = np.repeat(np.arange(len(segments)), segment_lengths)

     for segment in segments:
         segment["words"] = []
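Note: hoisting the list into `segment_lengths` is what lets the `np.repeat` call fit in 88 columns; the call itself labels every token with the index of the segment it came from:

import numpy as np

segment_lengths = [2, 1, 3]  # e.g. three segments with 2, 1, and 3 tokens
token_sources = np.repeat(np.arange(3), segment_lengths)
# -> array([0, 0, 1, 2, 2, 2]): token k belongs to segment token_sources[k]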
@@ -295,7 +308,12 @@ def add_word_timestamps(
         start = round(time_offset + timing.start, 2)
         end = round(time_offset + timing.end, 2)
         segment["words"].append(
-            dict(word=timing.word, start=start, end=end, probability=timing.probability)
+            dict(
+                word=timing.word,
+                start=start,
+                end=end,
+                probability=timing.probability,
+            )
         )

     for segment in segments:
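Note: exploding the `dict(...)` call is pure formatting; each appended entry is still a plain dict with the same four keys. A resulting `segment["words"]` element looks like this (values are illustrative, not from the commit):

word_entry = {"word": " hello", "start": 1.02, "end": 1.38, "probability": 0.97}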