word-level timestamps in transcribe() (#869)
* word-level timestamps in `transcribe()` * moving to `timing.py` * numba implementation for dtw, replacing dtw-python * triton implementation for dtw * add test for dtw implementations * triton implementation of median_filter * a simple word-level timestamps test * add scipy as dev dependency * installs an older version of Triton if CUDA < 11.4 * fix broken merge * loosen nvcc version match regex * find_alignment() function * miscellaneous improvements * skip median filtering when the input is too small * Expose punctuation options in cli and transcribe() (#973) * fix merge error * fix merge error 2 * annotating that word_timestamps is experimental --------- Co-authored-by: ryanheise <ryan@ryanheise.com>
This commit is contained in:
14
tests/conftest.py
Normal file
14
tests/conftest.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import random as rand
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "requires_cuda")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random():
|
||||
rand.seed(42)
|
||||
numpy.random.seed(42)
|
||||
87
tests/test_timing.py
Normal file
87
tests/test_timing.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
import scipy.ndimage
|
||||
import torch
|
||||
|
||||
from whisper.timing import dtw_cpu, dtw_cuda, median_filter
|
||||
|
||||
|
||||
sizes = [
|
||||
(10, 20), (32, 16), (123, 1500), (234, 189),
|
||||
]
|
||||
shapes = [
|
||||
(10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("N, M", sizes)
|
||||
def test_dtw(N: int, M: int):
|
||||
steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
|
||||
np.random.shuffle(steps)
|
||||
x = np.random.random((N, M)).astype(np.float32)
|
||||
|
||||
i, j, k = 0, 0, 0
|
||||
trace = []
|
||||
while True:
|
||||
x[i, j] -= 1
|
||||
trace.append((i, j))
|
||||
|
||||
if k == len(steps):
|
||||
break
|
||||
|
||||
if k + 1 < len(steps) and steps[k] != steps[k + 1]:
|
||||
i += 1
|
||||
j += 1
|
||||
k += 2
|
||||
continue
|
||||
|
||||
if steps[k] == 0:
|
||||
i += 1
|
||||
if steps[k] == 1:
|
||||
j += 1
|
||||
k += 1
|
||||
|
||||
trace = np.array(trace).T
|
||||
dtw_trace = dtw_cpu(x)
|
||||
|
||||
assert np.allclose(trace, dtw_trace)
|
||||
|
||||
|
||||
@pytest.mark.requires_cuda
|
||||
@pytest.mark.parametrize("N, M", sizes)
|
||||
def test_dtw_cuda_equivalence(N: int, M: int):
|
||||
x_numpy = np.random.randn(N, M).astype(np.float32)
|
||||
x_cuda = torch.from_numpy(x_numpy).cuda()
|
||||
|
||||
trace_cpu = dtw_cpu(x_numpy)
|
||||
trace_cuda = dtw_cuda(x_cuda)
|
||||
|
||||
assert np.allclose(trace_cpu, trace_cuda)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", shapes)
|
||||
def test_median_filter(shape):
|
||||
x = torch.randn(*shape)
|
||||
|
||||
for filter_width in [3, 5, 7, 13]:
|
||||
filtered = median_filter(x, filter_width)
|
||||
|
||||
# using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
|
||||
pad_width = filter_width // 2
|
||||
padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
|
||||
scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
|
||||
scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
|
||||
|
||||
assert np.allclose(filtered, scipy_filtered)
|
||||
|
||||
|
||||
@pytest.mark.requires_cuda
|
||||
@pytest.mark.parametrize("shape", shapes)
|
||||
def test_median_filter_equivalence(shape):
|
||||
x = torch.randn(*shape)
|
||||
|
||||
for filter_width in [3, 5, 7, 13]:
|
||||
filtered_cpu = median_filter(x, filter_width)
|
||||
filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
|
||||
|
||||
assert np.allclose(filtered_cpu, filtered_gpu)
|
||||
@@ -13,10 +13,22 @@ def test_transcribe(model_name: str):
|
||||
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
|
||||
|
||||
language = "en" if model_name.endswith(".en") else None
|
||||
result = model.transcribe(audio_path, language=language, temperature=0.0)
|
||||
result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True)
|
||||
assert result["language"] == "en"
|
||||
|
||||
transcription = result["text"].lower()
|
||||
assert "my fellow americans" in transcription
|
||||
assert "your country" in transcription
|
||||
assert "do for you" in transcription
|
||||
|
||||
timing_checked = False
|
||||
for segment in result["segments"]:
|
||||
for timing in segment["words"]:
|
||||
assert timing["start"] < timing["end"]
|
||||
if timing["word"].strip(" ,") == "Americans":
|
||||
assert timing["start"] <= 1.8
|
||||
assert timing["end"] >= 1.8
|
||||
print(timing)
|
||||
timing_checked = True
|
||||
|
||||
assert timing_checked
|
||||
|
||||
Reference in New Issue
Block a user