Compare commits
32 Commits
master
...
6e42088656
| Author | SHA1 | Date | |
|---|---|---|---|
|
6e42088656
|
|||
|
|
d57c5b40b0 | ||
|
|
83a368e98a | ||
|
|
eb8390233c | ||
|
4a59bb011d
|
|||
|
|
fbcf58bf98 | ||
|
|
1195359984 | ||
|
|
c22db5125d | ||
|
|
8862bee1f8 | ||
|
|
8d400e9870 | ||
|
|
bced5f04c0 | ||
|
|
65551c081f | ||
|
|
f53be1e811 | ||
|
|
4acdb5c619 | ||
|
|
a1c3583c96 | ||
|
|
2036d12634 | ||
|
|
2f6913efc8 | ||
|
|
e11d58599d | ||
|
|
49a80eb8a8 | ||
|
|
8d5e6d56d9 | ||
|
|
6eec07739e | ||
|
|
847fec4492 | ||
|
|
46080e584e | ||
|
|
3d1de60ef3 | ||
|
4ee1d54c14
|
|||
|
e50d82c18c
|
|||
|
4b64ef1f70
|
|||
|
d04e685ca2
|
|||
|
b835bdaaf1
|
|||
|
9f24e2c735
|
|||
|
9a646b69e6
|
|||
|
49af9564ab
|
@@ -1,3 +1,4 @@
|
|||||||
include faster_whisper/assets/silero_vad.onnx
|
include faster_whisper/assets/silero_vad.onnx
|
||||||
include requirements.txt
|
include requirements.txt
|
||||||
include requirements.conversion.txt
|
include requirements.conversion.txt
|
||||||
|
include faster_whisper/assets/pyannote_vad_model.bin
|
||||||
|
|||||||
48
README.md
48
README.md
@@ -69,34 +69,40 @@ segments, info = model.transcribe("audio.mp3", beam_size=5, language="en")
|
|||||||
|
|
||||||
* Python 3.8 or greater
|
* Python 3.8 or greater
|
||||||
|
|
||||||
Unlike openai-whisper, FFmpeg does **not** need to be installed on the system. The audio is decoded with the Python library [PyAV](https://github.com/PyAV-Org/PyAV) which bundles the FFmpeg libraries in its package.
|
|
||||||
|
|
||||||
### GPU
|
### GPU
|
||||||
|
|
||||||
GPU execution requires the following NVIDIA libraries to be installed:
|
GPU execution requires the following NVIDIA libraries to be installed:
|
||||||
|
|
||||||
* [cuBLAS for CUDA 11](https://developer.nvidia.com/cublas)
|
* [cuBLAS for CUDA 12](https://developer.nvidia.com/cublas)
|
||||||
* [cuDNN 8 for CUDA 11](https://developer.nvidia.com/cudnn)
|
* [cuDNN 8 for CUDA 12](https://developer.nvidia.com/cudnn)
|
||||||
|
|
||||||
There are multiple ways to install these libraries. The recommended way is described in the official NVIDIA documentation, but we also suggest other installation methods below.
|
**Note**: Latest versions of `ctranslate2` support CUDA 12 only. For CUDA 11, the current workaround is downgrading to the `3.24.0` version of `ctranslate2` (This can be done with `pip install --force-reinstall ctranslate2==3.24.0` or specifying the version in a `requirements.txt`).
|
||||||
|
|
||||||
|
There are multiple ways to install the NVIDIA libraries mentioned above. The recommended way is described in the official NVIDIA documentation, but we also suggest other installation methods below.
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>Other installation methods (click to expand)</summary>
|
<summary>Other installation methods (click to expand)</summary>
|
||||||
|
|
||||||
|
|
||||||
|
**Note:** For all these methods below, keep in mind the above note regarding CUDA versions. Depending on your setup, you may need to install the _CUDA 11_ versions of libraries that correspond to the CUDA 12 libraries listed in the instructions below.
|
||||||
|
|
||||||
#### Use Docker
|
#### Use Docker
|
||||||
|
|
||||||
The libraries are installed in this official NVIDIA Docker image: `nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04`.
|
The libraries (cuBLAS, cuDNN) are installed in these official NVIDIA CUDA Docker images: `nvidia/cuda:12.0.0-runtime-ubuntu20.04` or `nvidia/cuda:12.0.0-runtime-ubuntu22.04`.
|
||||||
|
|
||||||
#### Install with `pip` (Linux only)
|
#### Install with `pip` (Linux only)
|
||||||
|
|
||||||
On Linux these libraries can be installed with `pip`. Note that `LD_LIBRARY_PATH` must be set before launching Python.
|
On Linux these libraries can be installed with `pip`. Note that `LD_LIBRARY_PATH` must be set before launching Python.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install nvidia-cublas-cu11 nvidia-cudnn-cu11
|
pip install nvidia-cublas-cu12 nvidia-cudnn-cu12
|
||||||
|
|
||||||
export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'`
|
export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'`
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Note**: Version 9+ of `nvidia-cudnn-cu12` appears to cause issues due its reliance on cuDNN 9 (Faster-Whisper does not currently support cuDNN 9). Ensure your version of the Python package is for cuDNN 8.
|
||||||
|
|
||||||
#### Download the libraries from Purfview's repository (Windows & Linux)
|
#### Download the libraries from Purfview's repository (Windows & Linux)
|
||||||
|
|
||||||
Purfview's [whisper-standalone-win](https://github.com/Purfview/whisper-standalone-win) provides the required NVIDIA libraries for Windows & Linux in a [single archive](https://github.com/Purfview/whisper-standalone-win/releases/tag/libs). Decompress the archive and place the libraries in a directory included in the `PATH`.
|
Purfview's [whisper-standalone-win](https://github.com/Purfview/whisper-standalone-win) provides the required NVIDIA libraries for Windows & Linux in a [single archive](https://github.com/Purfview/whisper-standalone-win/releases/tag/libs). Decompress the archive and place the libraries in a directory included in the `PATH`.
|
||||||
@@ -159,6 +165,35 @@ for segment in segments:
|
|||||||
segments, _ = model.transcribe("audio.mp3")
|
segments, _ = model.transcribe("audio.mp3")
|
||||||
segments = list(segments) # The transcription will actually run here.
|
segments = list(segments) # The transcription will actually run here.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### multi-segment language detection
|
||||||
|
|
||||||
|
To directly use the model for improved language detection, the following code snippet can be used:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from faster_whisper import WhisperModel
|
||||||
|
model = WhisperModel("medium", device="cuda", compute_type="float16")
|
||||||
|
language_info = model.detect_language_multi_segment("audio.mp3")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Batched faster-whisper
|
||||||
|
|
||||||
|
|
||||||
|
The batched version of faster-whisper is inspired by [whisper-x](https://github.com/m-bain/whisperX) licensed under the BSD-2 Clause license and integrates its VAD model to this library. We modify this implementation and also replaced the feature extraction with a faster torch-based implementation. Batched version improves the speed upto 10-12x compared to openAI implementation and 3-4x compared to the sequential faster_whisper version. It works by transcribing semantically meaningful audio chunks as batches leading to faster inference.
|
||||||
|
|
||||||
|
The following code snippet illustrates how to run inference with batched version on an example audio file. Please also refer to the test scripts of batched faster whisper.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from faster_whisper import WhisperModel, BatchedInferencePipeline
|
||||||
|
|
||||||
|
model = WhisperModel("medium", device="cuda", compute_type="float16")
|
||||||
|
batched_model = BatchedInferencePipeline(model=model)
|
||||||
|
segments, info = batched_model.transcribe("audio.mp3", batch_size=16)
|
||||||
|
|
||||||
|
for segment in segments:
|
||||||
|
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
|
||||||
|
```
|
||||||
|
|
||||||
### Faster Distil-Whisper
|
### Faster Distil-Whisper
|
||||||
|
|
||||||
The Distil-Whisper checkpoints are compatible with the Faster-Whisper package. In particular, the latest [distil-large-v3](https://huggingface.co/distil-whisper/distil-large-v3)
|
The Distil-Whisper checkpoints are compatible with the Faster-Whisper package. In particular, the latest [distil-large-v3](https://huggingface.co/distil-whisper/distil-large-v3)
|
||||||
@@ -227,6 +262,7 @@ See more model and transcription options in the [`WhisperModel`](https://github.
|
|||||||
Here is a non exhaustive list of open-source projects using faster-whisper. Feel free to add your project to the list!
|
Here is a non exhaustive list of open-source projects using faster-whisper. Feel free to add your project to the list!
|
||||||
|
|
||||||
|
|
||||||
|
* [faster-whisper-server](https://github.com/fedirz/faster-whisper-server) is an OpenAI compatible server using `faster-whisper`. It's easily deployable with Docker, works with OpenAI SDKs/CLI, supports streaming, and live transcription.
|
||||||
* [WhisperX](https://github.com/m-bain/whisperX) is an award-winning Python library that offers speaker diarization and accurate word-level timestamps using wav2vec2 alignment
|
* [WhisperX](https://github.com/m-bain/whisperX) is an award-winning Python library that offers speaker diarization and accurate word-level timestamps using wav2vec2 alignment
|
||||||
* [whisper-ctranslate2](https://github.com/Softcatala/whisper-ctranslate2) is a command line client based on faster-whisper and compatible with the original client from openai/whisper.
|
* [whisper-ctranslate2](https://github.com/Softcatala/whisper-ctranslate2) is a command line client based on faster-whisper and compatible with the original client from openai/whisper.
|
||||||
* [whisper-diarize](https://github.com/MahmoudAshraf97/whisper-diarization) is a speaker diarization tool that is based on faster-whisper and NVIDIA NeMo.
|
* [whisper-diarize](https://github.com/MahmoudAshraf97/whisper-diarization) is a speaker diarization tool that is based on faster-whisper and NVIDIA NeMo.
|
||||||
|
|||||||
BIN
benchmark/benchmark.m4a
Normal file
BIN
benchmark/benchmark.m4a
Normal file
Binary file not shown.
94
benchmark/memory_benchmark.py
Normal file
94
benchmark/memory_benchmark.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import py3nvml.py3nvml as nvml
|
||||||
|
|
||||||
|
from memory_profiler import memory_usage
|
||||||
|
from utils import MyThread, get_logger, inference
|
||||||
|
|
||||||
|
logger = get_logger("faster-whisper")
|
||||||
|
parser = argparse.ArgumentParser(description="Memory benchmark")
|
||||||
|
parser.add_argument(
|
||||||
|
"--gpu_memory", action="store_true", help="Measure GPU memory usage"
|
||||||
|
)
|
||||||
|
parser.add_argument("--device-index", type=int, default=0, help="GPU device index")
|
||||||
|
parser.add_argument(
|
||||||
|
"--interval",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="Interval at which measurements are collected",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
device_idx = args.device_index
|
||||||
|
interval = args.interval
|
||||||
|
|
||||||
|
|
||||||
|
def measure_memory(func: Callable[[], None]):
|
||||||
|
if args.gpu_memory:
|
||||||
|
logger.info(
|
||||||
|
"Measuring maximum GPU memory usage on GPU device."
|
||||||
|
" Make sure to not have additional processes running on the same GPU."
|
||||||
|
)
|
||||||
|
# init nvml
|
||||||
|
nvml.nvmlInit()
|
||||||
|
handle = nvml.nvmlDeviceGetHandleByIndex(device_idx)
|
||||||
|
gpu_name = nvml.nvmlDeviceGetName(handle)
|
||||||
|
gpu_memory_limit = nvml.nvmlDeviceGetMemoryInfo(handle).total >> 20
|
||||||
|
gpu_power_limit = nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000.0
|
||||||
|
info = {"gpu_memory_usage": [], "gpu_power_usage": []}
|
||||||
|
|
||||||
|
def _get_gpu_info():
|
||||||
|
while True:
|
||||||
|
info["gpu_memory_usage"].append(
|
||||||
|
nvml.nvmlDeviceGetMemoryInfo(handle).used >> 20
|
||||||
|
)
|
||||||
|
info["gpu_power_usage"].append(
|
||||||
|
nvml.nvmlDeviceGetPowerUsage(handle) / 1000
|
||||||
|
)
|
||||||
|
time.sleep(interval)
|
||||||
|
|
||||||
|
if stop:
|
||||||
|
break
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
stop = False
|
||||||
|
thread = MyThread(_get_gpu_info, params=())
|
||||||
|
thread.start()
|
||||||
|
func()
|
||||||
|
stop = True
|
||||||
|
thread.join()
|
||||||
|
result = thread.get_result()
|
||||||
|
|
||||||
|
# shutdown nvml
|
||||||
|
nvml.nvmlShutdown()
|
||||||
|
max_memory_usage = max(result["gpu_memory_usage"])
|
||||||
|
max_power_usage = max(result["gpu_power_usage"])
|
||||||
|
print("GPU name: %s" % gpu_name)
|
||||||
|
print("GPU device index: %s" % device_idx)
|
||||||
|
print(
|
||||||
|
"Maximum GPU memory usage: %dMiB / %dMiB (%.2f%%)"
|
||||||
|
% (
|
||||||
|
max_memory_usage,
|
||||||
|
gpu_memory_limit,
|
||||||
|
(max_memory_usage / gpu_memory_limit) * 100,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"Maximum GPU power usage: %dW / %dW (%.2f%%)"
|
||||||
|
% (
|
||||||
|
max_power_usage,
|
||||||
|
gpu_power_limit,
|
||||||
|
(max_power_usage / gpu_power_limit) * 100,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("Measuring maximum increase of memory usage.")
|
||||||
|
max_usage = memory_usage(func, max_usage=True, interval=interval)
|
||||||
|
print("Maximum increase of RAM memory usage: %d MiB" % max_usage)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
measure_memory(inference)
|
||||||
1742
benchmark/normalizer.json
Normal file
1742
benchmark/normalizer.json
Normal file
File diff suppressed because it is too large
Load Diff
6
benchmark/requirements.benchmark.txt
Normal file
6
benchmark/requirements.benchmark.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
transformers
|
||||||
|
jiwer
|
||||||
|
evaluate
|
||||||
|
datasets
|
||||||
|
memory_profiler
|
||||||
|
py3nvml
|
||||||
31
benchmark/speed_benchmark.py
Normal file
31
benchmark/speed_benchmark.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import argparse
|
||||||
|
import timeit
|
||||||
|
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from utils import inference
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Speed benchmark")
|
||||||
|
parser.add_argument(
|
||||||
|
"--repeat",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="Times an experiment will be run.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def measure_speed(func: Callable[[], None]):
|
||||||
|
# as written in https://docs.python.org/3/library/timeit.html#timeit.Timer.repeat,
|
||||||
|
# min should be taken rather than the average
|
||||||
|
runtimes = timeit.repeat(
|
||||||
|
func,
|
||||||
|
repeat=args.repeat,
|
||||||
|
number=10,
|
||||||
|
)
|
||||||
|
print(runtimes)
|
||||||
|
print("Min execution time: %.3fs" % (min(runtimes) / 10.0))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
measure_speed(inference)
|
||||||
39
benchmark/utils.py
Normal file
39
benchmark/utils.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from threading import Thread
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from faster_whisper import WhisperModel
|
||||||
|
|
||||||
|
model_path = "large-v3"
|
||||||
|
model = WhisperModel(model_path, device="cuda")
|
||||||
|
|
||||||
|
|
||||||
|
def inference():
|
||||||
|
segments, info = model.transcribe("benchmark.m4a", language="fr")
|
||||||
|
for segment in segments:
|
||||||
|
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: Optional[str] = None) -> logging.Logger:
|
||||||
|
formatter = logging.Formatter("%(levelname)s: %(message)s")
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
class MyThread(Thread):
|
||||||
|
def __init__(self, func, params):
|
||||||
|
super(MyThread, self).__init__()
|
||||||
|
self.func = func
|
||||||
|
self.params = params
|
||||||
|
self.result = None
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.result = self.func(*self.params)
|
||||||
|
|
||||||
|
def get_result(self):
|
||||||
|
return self.result
|
||||||
64
benchmark/wer_benchmark.py
Normal file
64
benchmark/wer_benchmark.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
from evaluate import load
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
|
||||||
|
|
||||||
|
from faster_whisper import WhisperModel
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="WER benchmark")
|
||||||
|
parser.add_argument(
|
||||||
|
"--audio_numb",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Specify the number of validation audio files in the dataset."
|
||||||
|
" Set to None to retrieve all audio files.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
model_path = "large-v3"
|
||||||
|
model = WhisperModel(model_path, device="cuda")
|
||||||
|
|
||||||
|
# load the dataset with streaming mode
|
||||||
|
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
|
||||||
|
|
||||||
|
# define the evaluation metric
|
||||||
|
wer_metric = load("wer")
|
||||||
|
|
||||||
|
with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
|
||||||
|
normalizer = EnglishTextNormalizer(json.load(f))
|
||||||
|
|
||||||
|
|
||||||
|
def inference(batch):
|
||||||
|
batch["transcription"] = []
|
||||||
|
for sample in batch["audio"]:
|
||||||
|
segments, info = model.transcribe(sample["array"], language="en")
|
||||||
|
batch["transcription"].append("".join([segment.text for segment in segments]))
|
||||||
|
batch["reference"] = batch["text"]
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
dataset = dataset.map(function=inference, batched=True, batch_size=16)
|
||||||
|
|
||||||
|
all_transcriptions = []
|
||||||
|
all_references = []
|
||||||
|
|
||||||
|
# iterate over the dataset and run inference
|
||||||
|
for i, result in tqdm(enumerate(dataset), desc="Evaluating..."):
|
||||||
|
all_transcriptions.append(result["transcription"])
|
||||||
|
all_references.append(result["reference"])
|
||||||
|
if args.audio_numb and i == (args.audio_numb - 1):
|
||||||
|
break
|
||||||
|
|
||||||
|
# normalize predictions and references
|
||||||
|
all_transcriptions = [normalizer(transcription) for transcription in all_transcriptions]
|
||||||
|
all_references = [normalizer(reference) for reference in all_references]
|
||||||
|
|
||||||
|
# compute the WER metric
|
||||||
|
wer = 100 * wer_metric.compute(
|
||||||
|
predictions=all_transcriptions, references=all_references
|
||||||
|
)
|
||||||
|
print("WER: %.3f" % wer)
|
||||||
6
docker/Dockerfile
Normal file
6
docker/Dockerfile
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04
|
||||||
|
WORKDIR /root
|
||||||
|
RUN apt-get update -y && apt-get install -y python3-pip
|
||||||
|
COPY infer.py jfk.flac ./
|
||||||
|
RUN pip3 install faster-whisper
|
||||||
|
CMD ["python3", "infer.py"]
|
||||||
7
docker/infer.py
Normal file
7
docker/infer.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from faster_whisper import WhisperModel
|
||||||
|
|
||||||
|
jfk_path = "jfk.flac"
|
||||||
|
model = WhisperModel("tiny", device="cuda")
|
||||||
|
segments, info = model.transcribe(jfk_path, word_timestamps=True)
|
||||||
|
for segment in segments:
|
||||||
|
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
|
||||||
BIN
docker/jfk.flac
Normal file
BIN
docker/jfk.flac
Normal file
Binary file not shown.
@@ -1,5 +1,5 @@
|
|||||||
from faster_whisper.audio import decode_audio
|
from faster_whisper.audio import decode_audio
|
||||||
from faster_whisper.transcribe import WhisperModel
|
from faster_whisper.transcribe import BatchedInferencePipeline, WhisperModel
|
||||||
from faster_whisper.utils import available_models, download_model, format_timestamp
|
from faster_whisper.utils import available_models, download_model, format_timestamp
|
||||||
from faster_whisper.version import __version__
|
from faster_whisper.version import __version__
|
||||||
|
|
||||||
@@ -7,6 +7,7 @@ __all__ = [
|
|||||||
"available_models",
|
"available_models",
|
||||||
"decode_audio",
|
"decode_audio",
|
||||||
"WhisperModel",
|
"WhisperModel",
|
||||||
|
"BatchedInferencePipeline",
|
||||||
"download_model",
|
"download_model",
|
||||||
"format_timestamp",
|
"format_timestamp",
|
||||||
"__version__",
|
"__version__",
|
||||||
|
|||||||
BIN
faster_whisper/assets/pyannote_vad_model.bin
Normal file
BIN
faster_whisper/assets/pyannote_vad_model.bin
Normal file
Binary file not shown.
Binary file not shown.
@@ -1,19 +1,7 @@
|
|||||||
"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV
|
|
||||||
|
|
||||||
The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional
|
|
||||||
system dependencies. FFmpeg does not need to be installed on the system.
|
|
||||||
|
|
||||||
However, the API is quite low-level so we need to manipulate audio frames directly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import gc
|
|
||||||
import io
|
|
||||||
import itertools
|
|
||||||
|
|
||||||
from typing import BinaryIO, Union
|
from typing import BinaryIO, Union
|
||||||
|
|
||||||
import av
|
import torch
|
||||||
import numpy as np
|
import torchaudio
|
||||||
|
|
||||||
|
|
||||||
def decode_audio(
|
def decode_audio(
|
||||||
@@ -29,91 +17,42 @@ def decode_audio(
|
|||||||
split_stereo: Return separate left and right channels.
|
split_stereo: Return separate left and right channels.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A float32 Numpy array.
|
A float32 Torch Tensor.
|
||||||
|
|
||||||
If `split_stereo` is enabled, the function returns a 2-tuple with the
|
If `split_stereo` is enabled, the function returns a 2-tuple with the
|
||||||
separated left and right channels.
|
separated left and right channels.
|
||||||
"""
|
"""
|
||||||
resampler = av.audio.resampler.AudioResampler(
|
|
||||||
format="s16",
|
waveform, audio_sf = torchaudio.load(input_file) # waveform: channels X T
|
||||||
layout="mono" if not split_stereo else "stereo",
|
|
||||||
rate=sampling_rate,
|
if audio_sf != sampling_rate:
|
||||||
|
waveform = torchaudio.functional.resample(
|
||||||
|
waveform, orig_freq=audio_sf, new_freq=sampling_rate
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_buffer = io.BytesIO()
|
|
||||||
dtype = None
|
|
||||||
|
|
||||||
with av.open(input_file, mode="r", metadata_errors="ignore") as container:
|
|
||||||
frames = container.decode(audio=0)
|
|
||||||
frames = _ignore_invalid_frames(frames)
|
|
||||||
frames = _group_frames(frames, 500000)
|
|
||||||
frames = _resample_frames(frames, resampler)
|
|
||||||
|
|
||||||
for frame in frames:
|
|
||||||
array = frame.to_ndarray()
|
|
||||||
dtype = array.dtype
|
|
||||||
raw_buffer.write(array)
|
|
||||||
|
|
||||||
# It appears that some objects related to the resampler are not freed
|
|
||||||
# unless the garbage collector is manually run.
|
|
||||||
del resampler
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)
|
|
||||||
|
|
||||||
# Convert s16 back to f32.
|
|
||||||
audio = audio.astype(np.float32) / 32768.0
|
|
||||||
|
|
||||||
if split_stereo:
|
if split_stereo:
|
||||||
left_channel = audio[0::2]
|
return waveform[0], waveform[1]
|
||||||
right_channel = audio[1::2]
|
|
||||||
return left_channel, right_channel
|
|
||||||
|
|
||||||
return audio
|
return waveform.mean(0)
|
||||||
|
|
||||||
|
|
||||||
def _ignore_invalid_frames(frames):
|
|
||||||
iterator = iter(frames)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
yield next(iterator)
|
|
||||||
except StopIteration:
|
|
||||||
break
|
|
||||||
except av.error.InvalidDataError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
|
||||||
def _group_frames(frames, num_samples=None):
|
|
||||||
fifo = av.audio.fifo.AudioFifo()
|
|
||||||
|
|
||||||
for frame in frames:
|
|
||||||
frame.pts = None # Ignore timestamp check.
|
|
||||||
fifo.write(frame)
|
|
||||||
|
|
||||||
if num_samples is not None and fifo.samples >= num_samples:
|
|
||||||
yield fifo.read()
|
|
||||||
|
|
||||||
if fifo.samples > 0:
|
|
||||||
yield fifo.read()
|
|
||||||
|
|
||||||
|
|
||||||
def _resample_frames(frames, resampler):
|
|
||||||
# Add None to flush the resampler.
|
|
||||||
for frame in itertools.chain(frames, [None]):
|
|
||||||
yield from resampler.resample(frame)
|
|
||||||
|
|
||||||
|
|
||||||
def pad_or_trim(array, length: int, *, axis: int = -1):
|
def pad_or_trim(array, length: int, *, axis: int = -1):
|
||||||
"""
|
"""
|
||||||
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
||||||
"""
|
"""
|
||||||
|
axis = axis % array.ndim
|
||||||
if array.shape[axis] > length:
|
if array.shape[axis] > length:
|
||||||
array = array.take(indices=range(length), axis=axis)
|
idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1)
|
||||||
|
return array[idx]
|
||||||
|
|
||||||
if array.shape[axis] < length:
|
if array.shape[axis] < length:
|
||||||
pad_widths = [(0, 0)] * array.ndim
|
pad_widths = (
|
||||||
pad_widths[axis] = (0, length - array.shape[axis])
|
[
|
||||||
array = np.pad(array, pad_widths)
|
0,
|
||||||
|
]
|
||||||
|
* array.ndim
|
||||||
|
* 2
|
||||||
|
)
|
||||||
|
pad_widths[2 * axis] = length - array.shape[axis]
|
||||||
|
array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))
|
||||||
|
|
||||||
return array
|
return array
|
||||||
|
|||||||
@@ -1,16 +1,21 @@
|
|||||||
import numpy as np
|
import torch
|
||||||
|
|
||||||
|
|
||||||
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
|
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
|
||||||
class FeatureExtractor:
|
class FeatureExtractor:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
device: str = "auto",
|
||||||
feature_size=80,
|
feature_size=80,
|
||||||
sampling_rate=16000,
|
sampling_rate=16000,
|
||||||
hop_length=160,
|
hop_length=160,
|
||||||
chunk_length=30,
|
chunk_length=30,
|
||||||
n_fft=400,
|
n_fft=400,
|
||||||
):
|
):
|
||||||
|
if device == "auto":
|
||||||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
else:
|
||||||
|
self.device = device
|
||||||
self.n_fft = n_fft
|
self.n_fft = n_fft
|
||||||
self.hop_length = hop_length
|
self.hop_length = hop_length
|
||||||
self.chunk_length = chunk_length
|
self.chunk_length = chunk_length
|
||||||
@@ -22,21 +27,22 @@ class FeatureExtractor:
|
|||||||
sampling_rate, n_fft, n_mels=feature_size
|
sampling_rate, n_fft, n_mels=feature_size
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
|
@staticmethod
|
||||||
|
def get_mel_filters(sr, n_fft, n_mels=128):
|
||||||
|
"""
|
||||||
|
Implementation of librosa.filters.mel in Pytorch
|
||||||
|
"""
|
||||||
# Initialize the weights
|
# Initialize the weights
|
||||||
n_mels = int(n_mels)
|
n_mels = int(n_mels)
|
||||||
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
|
|
||||||
|
|
||||||
# Center freqs of each FFT bin
|
# Center freqs of each FFT bin
|
||||||
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
|
fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr)
|
||||||
|
|
||||||
# 'Center freqs' of mel bands - uniformly spaced between limits
|
# 'Center freqs' of mel bands - uniformly spaced between limits
|
||||||
min_mel = 0.0
|
min_mel = 0.0
|
||||||
max_mel = 45.245640471924965
|
max_mel = 45.245640471924965
|
||||||
|
|
||||||
mels = np.linspace(min_mel, max_mel, n_mels + 2)
|
mels = torch.linspace(min_mel, max_mel, n_mels + 2)
|
||||||
|
|
||||||
mels = np.asanyarray(mels)
|
|
||||||
|
|
||||||
# Fill in the linear scale
|
# Fill in the linear scale
|
||||||
f_min = 0.0
|
f_min = 0.0
|
||||||
@@ -46,125 +52,63 @@ class FeatureExtractor:
|
|||||||
# And now the nonlinear scale
|
# And now the nonlinear scale
|
||||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||||
logstep = np.log(6.4) / 27.0 # step size for log region
|
logstep = torch.log(torch.tensor(6.4)) / 27.0 # step size for log region
|
||||||
|
|
||||||
# If we have vector data, vectorize
|
# If we have vector data, vectorize
|
||||||
log_t = mels >= min_log_mel
|
log_t = mels >= min_log_mel
|
||||||
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
|
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
|
||||||
|
|
||||||
mel_f = freqs
|
mel_f = freqs
|
||||||
|
|
||||||
fdiff = np.diff(mel_f)
|
fdiff = torch.diff(mel_f)
|
||||||
ramps = np.subtract.outer(mel_f, fftfreqs)
|
ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1)
|
||||||
|
|
||||||
for i in range(n_mels):
|
lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
|
||||||
# lower and upper slopes for all bins
|
upper = ramps[2:] / fdiff[1:].unsqueeze(1)
|
||||||
lower = -ramps[i] / fdiff[i]
|
|
||||||
upper = ramps[i + 2] / fdiff[i + 1]
|
|
||||||
|
|
||||||
# .. then intersect them with each other and zero
|
# Intersect them with each other and zero, vectorized across all i
|
||||||
weights[i] = np.maximum(0, np.minimum(lower, upper))
|
weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper))
|
||||||
|
|
||||||
# Slaney-style mel is scaled to be approx constant energy per channel
|
# Slaney-style mel is scaled to be approx constant energy per channel
|
||||||
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
|
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
|
||||||
weights *= enorm[:, np.newaxis]
|
weights *= enorm.unsqueeze(1)
|
||||||
|
|
||||||
return weights
|
return weights
|
||||||
|
|
||||||
def fram_wave(self, waveform, center=True):
|
def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
|
||||||
"""
|
"""
|
||||||
Transform a raw waveform into a list of smaller waveforms.
|
Compute the log-Mel spectrogram of the provided audio.
|
||||||
The window length defines how much of the signal is
|
|
||||||
contain in each frame (smalle waveform), while the hope length defines the step
|
|
||||||
between the beginning of each new frame.
|
|
||||||
Centering is done by reflecting the waveform which is first centered around
|
|
||||||
`frame_idx * hop_length`.
|
|
||||||
"""
|
"""
|
||||||
frames = []
|
|
||||||
for i in range(0, waveform.shape[0] + 1, self.hop_length):
|
|
||||||
half_window = (self.n_fft - 1) // 2 + 1
|
|
||||||
if center:
|
|
||||||
start = i - half_window if i > half_window else 0
|
|
||||||
end = (
|
|
||||||
i + half_window
|
|
||||||
if i < waveform.shape[0] - half_window
|
|
||||||
else waveform.shape[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
frame = waveform[start:end]
|
|
||||||
|
|
||||||
if start == 0:
|
|
||||||
padd_width = (-i + half_window, 0)
|
|
||||||
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
|
|
||||||
|
|
||||||
elif end == waveform.shape[0]:
|
|
||||||
padd_width = (0, (i - waveform.shape[0] + half_window))
|
|
||||||
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
|
|
||||||
|
|
||||||
else:
|
|
||||||
frame = waveform[i : i + self.n_fft]
|
|
||||||
frame_width = frame.shape[0]
|
|
||||||
if frame_width < waveform.shape[0]:
|
|
||||||
frame = np.lib.pad(
|
|
||||||
frame,
|
|
||||||
pad_width=(0, self.n_fft - frame_width),
|
|
||||||
mode="constant",
|
|
||||||
constant_values=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
frames.append(frame)
|
|
||||||
return np.stack(frames, 0)
|
|
||||||
|
|
||||||
def stft(self, frames, window):
|
|
||||||
"""
|
|
||||||
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
|
|
||||||
Should give the same results as `torch.stft`.
|
|
||||||
"""
|
|
||||||
frame_size = frames.shape[1]
|
|
||||||
fft_size = self.n_fft
|
|
||||||
|
|
||||||
if fft_size is None:
|
|
||||||
fft_size = frame_size
|
|
||||||
|
|
||||||
if fft_size < frame_size:
|
|
||||||
raise ValueError("FFT size must greater or equal the frame size")
|
|
||||||
# number of FFT bins to store
|
|
||||||
num_fft_bins = (fft_size >> 1) + 1
|
|
||||||
|
|
||||||
data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
|
|
||||||
fft_signal = np.zeros(fft_size)
|
|
||||||
|
|
||||||
for f, frame in enumerate(frames):
|
|
||||||
if window is not None:
|
|
||||||
np.multiply(frame, window, out=fft_signal[:frame_size])
|
|
||||||
else:
|
|
||||||
fft_signal[:frame_size] = frame
|
|
||||||
data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
|
|
||||||
return data.T
|
|
||||||
|
|
||||||
def __call__(self, waveform, padding=True, chunk_length=None):
|
|
||||||
"""
|
|
||||||
Compute the log-Mel spectrogram of the provided audio, gives similar results
|
|
||||||
whisper's original torch implementation with 1e-5 tolerance.
|
|
||||||
"""
|
|
||||||
if chunk_length is not None:
|
if chunk_length is not None:
|
||||||
self.n_samples = chunk_length * self.sampling_rate
|
self.n_samples = chunk_length * self.sampling_rate
|
||||||
self.nb_max_frames = self.n_samples // self.hop_length
|
self.nb_max_frames = self.n_samples // self.hop_length
|
||||||
|
|
||||||
|
if waveform.dtype is not torch.float32:
|
||||||
|
waveform = waveform.to(torch.float32)
|
||||||
|
|
||||||
|
waveform = (
|
||||||
|
waveform.to(self.device)
|
||||||
|
if self.device == "cuda" and not waveform.is_cuda
|
||||||
|
else waveform
|
||||||
|
)
|
||||||
|
|
||||||
if padding:
|
if padding:
|
||||||
waveform = np.pad(waveform, [(0, self.n_samples)])
|
waveform = torch.nn.functional.pad(waveform, (0, self.n_samples))
|
||||||
|
|
||||||
window = np.hanning(self.n_fft + 1)[:-1]
|
window = torch.hann_window(self.n_fft).to(waveform.device)
|
||||||
|
|
||||||
frames = self.fram_wave(waveform)
|
stft = torch.stft(
|
||||||
stft = self.stft(frames, window=window)
|
waveform, self.n_fft, self.hop_length, window=window, return_complex=True
|
||||||
magnitudes = np.abs(stft[:, :-1]) ** 2
|
)
|
||||||
|
magnitudes = stft[..., :-1].abs() ** 2
|
||||||
|
|
||||||
filters = self.mel_filters
|
mel_spec = self.mel_filters.to(waveform.device) @ magnitudes
|
||||||
mel_spec = filters @ magnitudes
|
|
||||||
|
|
||||||
log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
|
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||||
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
|
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||||
log_spec = (log_spec + 4.0) / 4.0
|
log_spec = (log_spec + 4.0) / 4.0
|
||||||
|
|
||||||
return log_spec
|
# When the model is running on multiple GPUs, the output should be moved
|
||||||
|
# to the CPU since we don't know which GPU will handle the next job.
|
||||||
|
return log_spec.cpu() if to_cpu else log_spec
|
||||||
|
|||||||
@@ -105,6 +105,42 @@ class Tokenizer:
|
|||||||
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def non_speech_tokens(self) -> Tuple[int]:
|
||||||
|
"""
|
||||||
|
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
||||||
|
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
|
||||||
|
|
||||||
|
- ♪♪♪
|
||||||
|
- ( SPEAKING FOREIGN LANGUAGE )
|
||||||
|
- [DAVID] Hey there,
|
||||||
|
|
||||||
|
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||||
|
"""
|
||||||
|
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
||||||
|
symbols += (
|
||||||
|
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||||
|
)
|
||||||
|
|
||||||
|
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||||
|
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||||
|
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
|
||||||
|
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
|
||||||
|
miscellaneous = set("♩♪♫♬♭♮♯")
|
||||||
|
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||||
|
|
||||||
|
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||||
|
result = {self.encode(" -")[0], self.encode(" '")[0]}
|
||||||
|
for symbol in symbols + list(miscellaneous):
|
||||||
|
for tokens in [
|
||||||
|
self.encode(symbol),
|
||||||
|
self.encode(" " + symbol),
|
||||||
|
]:
|
||||||
|
if len(tokens) == 1 or symbol in miscellaneous:
|
||||||
|
result.add(tokens[0])
|
||||||
|
|
||||||
|
return tuple(sorted(result))
|
||||||
|
|
||||||
def split_to_word_tokens(
|
def split_to_word_tokens(
|
||||||
self, tokens: List[int]
|
self, tokens: List[int]
|
||||||
) -> Tuple[List[str], List[List[int]]]:
|
) -> Tuple[List[str], List[List[int]]]:
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -54,8 +54,9 @@ def download_model(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
size_or_id: Size of the model to download from https://huggingface.co/Systran
|
size_or_id: Size of the model to download from https://huggingface.co/Systran
|
||||||
(tiny, tiny.en, base, base.en, small, small.en medium, medium.en, large-v1, large-v2,
|
(tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en,
|
||||||
large-v3, large), or a CTranslate2-converted model ID from the Hugging Face Hub
|
distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2,
|
||||||
|
distil-large-v3), or a CTranslate2-converted model ID from the Hugging Face Hub
|
||||||
(e.g. Systran/faster-whisper-large-v3).
|
(e.g. Systran/faster-whisper-large-v3).
|
||||||
output_dir: Directory where the model should be saved. If not set, the model is saved in
|
output_dir: Directory where the model should be saved. If not set, the model is saved in
|
||||||
the cache directory.
|
the cache directory.
|
||||||
|
|||||||
@@ -1,11 +1,18 @@
|
|||||||
import bisect
|
import bisect
|
||||||
import functools
|
import functools
|
||||||
import os
|
import os
|
||||||
import warnings
|
|
||||||
|
|
||||||
from typing import List, NamedTuple, Optional
|
from abc import ABC
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import List, NamedTuple, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pyannote.audio.core.io import AudioFile
|
||||||
|
from pyannote.audio.pipelines import VoiceActivityDetection
|
||||||
|
from pyannote.audio.pipelines.utils import PipelineModel
|
||||||
|
from pyannote.core import Annotation, Segment, SlidingWindowFeature
|
||||||
|
|
||||||
from faster_whisper.utils import get_assets_path
|
from faster_whisper.utils import get_assets_path
|
||||||
|
|
||||||
@@ -25,9 +32,6 @@ class VadOptions(NamedTuple):
|
|||||||
split aggressively just before max_speech_duration_s.
|
split aggressively just before max_speech_duration_s.
|
||||||
min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
|
min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
|
||||||
before separating it
|
before separating it
|
||||||
window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
|
|
||||||
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
|
|
||||||
Values other than these may affect model performance!!
|
|
||||||
speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
|
speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -35,12 +39,11 @@ class VadOptions(NamedTuple):
|
|||||||
min_speech_duration_ms: int = 250
|
min_speech_duration_ms: int = 250
|
||||||
max_speech_duration_s: float = float("inf")
|
max_speech_duration_s: float = float("inf")
|
||||||
min_silence_duration_ms: int = 2000
|
min_silence_duration_ms: int = 2000
|
||||||
window_size_samples: int = 1024
|
|
||||||
speech_pad_ms: int = 400
|
speech_pad_ms: int = 400
|
||||||
|
|
||||||
|
|
||||||
def get_speech_timestamps(
|
def get_speech_timestamps(
|
||||||
audio: np.ndarray,
|
audio: torch.Tensor,
|
||||||
vad_options: Optional[VadOptions] = None,
|
vad_options: Optional[VadOptions] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
@@ -61,15 +64,8 @@ def get_speech_timestamps(
|
|||||||
min_speech_duration_ms = vad_options.min_speech_duration_ms
|
min_speech_duration_ms = vad_options.min_speech_duration_ms
|
||||||
max_speech_duration_s = vad_options.max_speech_duration_s
|
max_speech_duration_s = vad_options.max_speech_duration_s
|
||||||
min_silence_duration_ms = vad_options.min_silence_duration_ms
|
min_silence_duration_ms = vad_options.min_silence_duration_ms
|
||||||
window_size_samples = vad_options.window_size_samples
|
window_size_samples = 512
|
||||||
speech_pad_ms = vad_options.speech_pad_ms
|
speech_pad_ms = vad_options.speech_pad_ms
|
||||||
|
|
||||||
if window_size_samples not in [512, 1024, 1536]:
|
|
||||||
warnings.warn(
|
|
||||||
"Unusual window_size_samples! Supported window_size_samples:\n"
|
|
||||||
" - [512, 1024, 1536] for 16000 sampling_rate"
|
|
||||||
)
|
|
||||||
|
|
||||||
sampling_rate = 16000
|
sampling_rate = 16000
|
||||||
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
|
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
|
||||||
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
||||||
@@ -84,14 +80,14 @@ def get_speech_timestamps(
|
|||||||
audio_length_samples = len(audio)
|
audio_length_samples = len(audio)
|
||||||
|
|
||||||
model = get_vad_model()
|
model = get_vad_model()
|
||||||
state = model.get_initial_state(batch_size=1)
|
state, context = model.get_initial_states(batch_size=1)
|
||||||
|
|
||||||
speech_probs = []
|
speech_probs = []
|
||||||
for current_start_sample in range(0, audio_length_samples, window_size_samples):
|
for current_start_sample in range(0, audio_length_samples, window_size_samples):
|
||||||
chunk = audio[current_start_sample : current_start_sample + window_size_samples]
|
chunk = audio[current_start_sample : current_start_sample + window_size_samples]
|
||||||
if len(chunk) < window_size_samples:
|
if len(chunk) < window_size_samples:
|
||||||
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
||||||
speech_prob, state = model(chunk, state, sampling_rate)
|
speech_prob, state, context = model(chunk, state, context, sampling_rate)
|
||||||
speech_probs.append(speech_prob)
|
speech_probs.append(speech_prob)
|
||||||
|
|
||||||
triggered = False
|
triggered = False
|
||||||
@@ -188,12 +184,12 @@ def get_speech_timestamps(
|
|||||||
return speeches
|
return speeches
|
||||||
|
|
||||||
|
|
||||||
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
|
def collect_chunks(audio: torch.Tensor, chunks: List[dict]) -> torch.Tensor:
|
||||||
"""Collects and concatenates audio chunks."""
|
"""Collects and concatenates audio chunks."""
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return np.array([], dtype=np.float32)
|
return torch.tensor([], dtype=torch.float32)
|
||||||
|
|
||||||
return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
|
return torch.cat([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
|
||||||
|
|
||||||
|
|
||||||
class SpeechTimestampsMap:
|
class SpeechTimestampsMap:
|
||||||
@@ -261,12 +257,12 @@ class SileroVADModel:
|
|||||||
sess_options=opts,
|
sess_options=opts,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_initial_state(self, batch_size: int):
|
def get_initial_states(self, batch_size: int):
|
||||||
h = np.zeros((2, batch_size, 64), dtype=np.float32)
|
state = np.zeros((2, batch_size, 128), dtype=np.float32)
|
||||||
c = np.zeros((2, batch_size, 64), dtype=np.float32)
|
context = np.zeros((batch_size, 64), dtype=np.float32)
|
||||||
return h, c
|
return state, context
|
||||||
|
|
||||||
def __call__(self, x, state, sr: int):
|
def __call__(self, x, state, context, sr: int):
|
||||||
if len(x.shape) == 1:
|
if len(x.shape) == 1:
|
||||||
x = np.expand_dims(x, 0)
|
x = np.expand_dims(x, 0)
|
||||||
if len(x.shape) > 2:
|
if len(x.shape) > 2:
|
||||||
@@ -276,16 +272,325 @@ class SileroVADModel:
|
|||||||
if sr / x.shape[1] > 31.25:
|
if sr / x.shape[1] > 31.25:
|
||||||
raise ValueError("Input audio chunk is too short")
|
raise ValueError("Input audio chunk is too short")
|
||||||
|
|
||||||
h, c = state
|
x = np.concatenate([context, x], axis=1)
|
||||||
|
|
||||||
ort_inputs = {
|
ort_inputs = {
|
||||||
"input": x,
|
"input": x,
|
||||||
"h": h,
|
"state": state,
|
||||||
"c": c,
|
|
||||||
"sr": np.array(sr, dtype="int64"),
|
"sr": np.array(sr, dtype="int64"),
|
||||||
}
|
}
|
||||||
|
|
||||||
out, h, c = self.session.run(None, ort_inputs)
|
out, state = self.session.run(None, ort_inputs)
|
||||||
state = (h, c)
|
context = x[..., -64:]
|
||||||
|
|
||||||
return out, state
|
return out, state, context
|
||||||
|
|
||||||
|
|
||||||
|
# BSD 2-Clause License
|
||||||
|
|
||||||
|
# Copyright (c) 2024, Max Bain
|
||||||
|
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
# list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
|
# and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
|
||||||
|
# The code below is copied from whisper-x (https://github.com/m-bain/whisperX)
|
||||||
|
# and adapted for faster_whisper.
|
||||||
|
class SegmentX:
|
||||||
|
def __init__(self, start, end, speaker=None):
|
||||||
|
self.start = start
|
||||||
|
self.end = end
|
||||||
|
self.speaker = speaker
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceActivitySegmentation(VoiceActivityDetection, ABC):
|
||||||
|
"""Pipeline wrapper class for Voice Activity Segmentation based on VAD scores."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
segmentation: PipelineModel = "pyannote/segmentation",
|
||||||
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
|
fscore: bool = False,
|
||||||
|
use_auth_token: Optional[str] = None,
|
||||||
|
**inference_kwargs,
|
||||||
|
):
|
||||||
|
"""Initialize the pipeline with the model name and the optional device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dict parameters of VoiceActivityDetection class from pyannote:
|
||||||
|
segmentation (PipelineModel): Loaded model name.
|
||||||
|
device (torch.device or None): Device to perform the segmentation.
|
||||||
|
fscore (bool): Flag indicating whether to compute F-score during inference.
|
||||||
|
use_auth_token (str or None): Optional authentication token for model access.
|
||||||
|
inference_kwargs (dict): Additional arguments from VoiceActivityDetection pipeline.
|
||||||
|
"""
|
||||||
|
super().__init__(
|
||||||
|
segmentation=segmentation,
|
||||||
|
device=device,
|
||||||
|
fscore=fscore,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
**inference_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self, file: AudioFile, hook: Optional[Callable] = None
|
||||||
|
) -> SlidingWindowFeature:
|
||||||
|
"""Apply voice activity detection on the audio file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (AudioFile): Processed file.
|
||||||
|
hook (callable): Hook called with signature: hook("step_name", step_artefact, file=file)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
segmentations (SlidingWindowFeature): Voice activity segmentation.
|
||||||
|
"""
|
||||||
|
# setup hook (e.g. for debugging purposes)
|
||||||
|
hook = self.setup_hook(file, hook=hook)
|
||||||
|
|
||||||
|
# apply segmentation model if needed
|
||||||
|
# output shape is (num_chunks, num_frames, 1)
|
||||||
|
if self.training:
|
||||||
|
if self.CACHED_SEGMENTATION in file:
|
||||||
|
segmentations = file[self.CACHED_SEGMENTATION]
|
||||||
|
else:
|
||||||
|
segmentations = self._segmentation(file)
|
||||||
|
file[self.CACHED_SEGMENTATION] = segmentations
|
||||||
|
else:
|
||||||
|
segmentations: SlidingWindowFeature = self._segmentation(file)
|
||||||
|
|
||||||
|
return segmentations
|
||||||
|
|
||||||
|
|
||||||
|
class BinarizeVadScores:
|
||||||
|
"""Binarize detection scores using hysteresis thresholding.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
|
||||||
|
RNN-based Voice Activity Detection", InterSpeech 2015.
|
||||||
|
|
||||||
|
Modified by Max Bain to include WhisperX's min-cut operation
|
||||||
|
https://arxiv.org/abs/2303.00747
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
onset: float = 0.5,
|
||||||
|
offset: Optional[float] = None,
|
||||||
|
min_duration_on: float = 0.0,
|
||||||
|
min_duration_off: float = 0.0,
|
||||||
|
pad_onset: float = 0.0,
|
||||||
|
pad_offset: float = 0.0,
|
||||||
|
max_duration: float = float("inf"),
|
||||||
|
):
|
||||||
|
"""Initializes the parameters for Binarizing the VAD scores.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
onset (float, optional):
|
||||||
|
Onset threshold. Defaults to 0.5.
|
||||||
|
offset (float, optional):
|
||||||
|
Offset threshold. Defaults to `onset`.
|
||||||
|
min_duration_on (float, optional):
|
||||||
|
Remove active regions shorter than that many seconds. Defaults to 0s.
|
||||||
|
min_duration_off (float, optional):
|
||||||
|
Fill inactive regions shorter than that many seconds. Defaults to 0s.
|
||||||
|
pad_onset (float, optional):
|
||||||
|
Extend active regions by moving their start time by that many seconds.
|
||||||
|
Defaults to 0s.
|
||||||
|
pad_offset (float, optional):
|
||||||
|
Extend active regions by moving their end time by that many seconds.
|
||||||
|
Defaults to 0s.
|
||||||
|
max_duration (float):
|
||||||
|
The maximum length of an active segment.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.onset = onset
|
||||||
|
self.offset = offset or onset
|
||||||
|
|
||||||
|
self.pad_onset = pad_onset
|
||||||
|
self.pad_offset = pad_offset
|
||||||
|
|
||||||
|
self.min_duration_on = min_duration_on
|
||||||
|
self.min_duration_off = min_duration_off
|
||||||
|
|
||||||
|
self.max_duration = max_duration
|
||||||
|
|
||||||
|
def __get_active_regions(self, scores: SlidingWindowFeature) -> Annotation:
|
||||||
|
"""Extract active regions from VAD scores.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scores (SlidingWindowFeature): Detection scores.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
active (Annotation): Active regions.
|
||||||
|
"""
|
||||||
|
num_frames, num_classes = scores.data.shape
|
||||||
|
frames = scores.sliding_window
|
||||||
|
timestamps = [frames[i].middle for i in range(num_frames)]
|
||||||
|
# annotation meant to store 'active' regions
|
||||||
|
active = Annotation()
|
||||||
|
for k, k_scores in enumerate(scores.data.T):
|
||||||
|
label = k if scores.labels is None else scores.labels[k]
|
||||||
|
|
||||||
|
# initial state
|
||||||
|
start = timestamps[0]
|
||||||
|
is_active = k_scores[0] > self.onset
|
||||||
|
curr_scores = [k_scores[0]]
|
||||||
|
curr_timestamps = [start]
|
||||||
|
t = start
|
||||||
|
# optionally add `strict=False` for python 3.10 or later
|
||||||
|
for t, y in zip(timestamps[1:], k_scores[1:]):
|
||||||
|
# currently active
|
||||||
|
if is_active:
|
||||||
|
curr_duration = t - start
|
||||||
|
if curr_duration > self.max_duration:
|
||||||
|
search_after = len(curr_scores) // 2
|
||||||
|
# divide segment
|
||||||
|
min_score_div_idx = search_after + np.argmin(
|
||||||
|
curr_scores[search_after:]
|
||||||
|
)
|
||||||
|
min_score_t = curr_timestamps[min_score_div_idx]
|
||||||
|
region = Segment(
|
||||||
|
start - self.pad_onset, min_score_t + self.pad_offset
|
||||||
|
)
|
||||||
|
active[region, k] = label
|
||||||
|
start = curr_timestamps[min_score_div_idx]
|
||||||
|
curr_scores = curr_scores[min_score_div_idx + 1 :]
|
||||||
|
curr_timestamps = curr_timestamps[min_score_div_idx + 1 :]
|
||||||
|
# switching from active to inactive
|
||||||
|
elif y < self.offset:
|
||||||
|
region = Segment(start - self.pad_onset, t + self.pad_offset)
|
||||||
|
active[region, k] = label
|
||||||
|
start = t
|
||||||
|
is_active = False
|
||||||
|
curr_scores = []
|
||||||
|
curr_timestamps = []
|
||||||
|
curr_scores.append(y)
|
||||||
|
curr_timestamps.append(t)
|
||||||
|
# currently inactive
|
||||||
|
else:
|
||||||
|
# switching from inactive to active
|
||||||
|
if y > self.onset:
|
||||||
|
start = t
|
||||||
|
is_active = True
|
||||||
|
|
||||||
|
# if active at the end, add final region
|
||||||
|
if is_active:
|
||||||
|
region = Segment(start - self.pad_onset, t + self.pad_offset)
|
||||||
|
active[region, k] = label
|
||||||
|
|
||||||
|
return active
|
||||||
|
|
||||||
|
def __call__(self, scores: SlidingWindowFeature) -> Annotation:
|
||||||
|
"""Binarize detection scores.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scores (SlidingWindowFeature): Detection scores.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
active (Annotation): Binarized scores.
|
||||||
|
"""
|
||||||
|
active = self.__get_active_regions(scores)
|
||||||
|
# because of padding, some active regions might be overlapping: merge them.
|
||||||
|
# also: fill same speaker gaps shorter than min_duration_off
|
||||||
|
if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
|
||||||
|
if self.max_duration < float("inf"):
|
||||||
|
raise NotImplementedError("This would break current max_duration param")
|
||||||
|
active = active.support(collar=self.min_duration_off)
|
||||||
|
|
||||||
|
# remove tracks shorter than min_duration_on
|
||||||
|
if self.min_duration_on > 0:
|
||||||
|
for segment, track in list(active.itertracks()):
|
||||||
|
if segment.duration < self.min_duration_on:
|
||||||
|
del active[segment, track]
|
||||||
|
|
||||||
|
return active
|
||||||
|
|
||||||
|
|
||||||
|
def merge_chunks(
|
||||||
|
segments,
|
||||||
|
chunk_length,
|
||||||
|
onset: float = 0.5,
|
||||||
|
offset: Optional[float] = None,
|
||||||
|
edge_padding: float = 0.1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Merge operation described in whisper-x paper
|
||||||
|
"""
|
||||||
|
curr_end = 0
|
||||||
|
merged_segments = []
|
||||||
|
seg_idxs = []
|
||||||
|
speaker_idxs = []
|
||||||
|
|
||||||
|
assert chunk_length > 0
|
||||||
|
binarize = BinarizeVadScores(max_duration=chunk_length, onset=onset, offset=offset)
|
||||||
|
segments = binarize(segments)
|
||||||
|
segments_list = []
|
||||||
|
for speech_turn in segments.get_timeline():
|
||||||
|
segments_list.append(
|
||||||
|
SegmentX(
|
||||||
|
max(0.0, speech_turn.start - edge_padding),
|
||||||
|
speech_turn.end + edge_padding,
|
||||||
|
"UNKNOWN",
|
||||||
|
)
|
||||||
|
) # 100ms edge padding to account for edge errors
|
||||||
|
|
||||||
|
if len(segments_list) == 0:
|
||||||
|
print("No active speech found in audio")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Make sur the starting point is the start of the segment.
|
||||||
|
curr_start = segments_list[0].start
|
||||||
|
|
||||||
|
for idx, seg in enumerate(segments_list):
|
||||||
|
# if any segment start timing is less than previous segment end timing,
|
||||||
|
# reset the edge padding. Similarly for end timing.
|
||||||
|
if idx > 0:
|
||||||
|
if seg.start < segments_list[idx - 1].end:
|
||||||
|
seg.start += edge_padding
|
||||||
|
if idx < len(segments_list) - 1:
|
||||||
|
if seg.end > segments_list[idx + 1].start:
|
||||||
|
seg.end -= edge_padding
|
||||||
|
|
||||||
|
if seg.end - curr_start > chunk_length and curr_end - curr_start > 0:
|
||||||
|
merged_segments.append(
|
||||||
|
{
|
||||||
|
"start": curr_start,
|
||||||
|
"end": curr_end,
|
||||||
|
"segments": seg_idxs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
curr_start = seg.start
|
||||||
|
seg_idxs = []
|
||||||
|
speaker_idxs = []
|
||||||
|
curr_end = seg.end
|
||||||
|
seg_idxs.append((seg.start, seg.end))
|
||||||
|
speaker_idxs.append(seg.speaker)
|
||||||
|
# add final
|
||||||
|
merged_segments.append(
|
||||||
|
{
|
||||||
|
"start": curr_start,
|
||||||
|
"end": curr_end,
|
||||||
|
"segments": seg_idxs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return merged_segments
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
"""Version information."""
|
"""Version information."""
|
||||||
|
|
||||||
__version__ = "1.0.1"
|
__version__ = "1.0.3"
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
av==11.*
|
|
||||||
ctranslate2>=4.0,<5
|
ctranslate2>=4.0,<5
|
||||||
huggingface_hub>=0.13
|
huggingface_hub>=0.13
|
||||||
tokenizers>=0.13,<0.16
|
tokenizers>=0.13,<1
|
||||||
onnxruntime>=1.14,<2
|
onnxruntime>=1.14,<2
|
||||||
|
pyannote-audio>=3.1.1
|
||||||
|
torch>=2.1.1
|
||||||
|
torchaudio>=2.1.2
|
||||||
|
tqdm
|
||||||
@@ -11,3 +11,8 @@ def data_dir():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def jfk_path(data_dir):
|
def jfk_path(data_dir):
|
||||||
return os.path.join(data_dir, "jfk.flac")
|
return os.path.join(data_dir, "jfk.flac")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def physcisworks_path(data_dir):
|
||||||
|
return os.path.join(data_dir, "physicsworks.wav")
|
||||||
|
|||||||
BIN
tests/data/physicsworks.wav
Normal file
BIN
tests/data/physicsworks.wav
Normal file
Binary file not shown.
@@ -1,6 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from faster_whisper import WhisperModel, decode_audio
|
from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
|
||||||
|
from faster_whisper.tokenizer import Tokenizer
|
||||||
|
from faster_whisper.transcribe import get_suppressed_tokens
|
||||||
|
|
||||||
|
|
||||||
def test_supported_languages():
|
def test_supported_languages():
|
||||||
@@ -37,6 +39,50 @@ def test_transcribe(jfk_path):
|
|||||||
assert segment.text == "".join(word.word for word in segment.words)
|
assert segment.text == "".join(word.word for word in segment.words)
|
||||||
assert segment.start == segment.words[0].start
|
assert segment.start == segment.words[0].start
|
||||||
assert segment.end == segment.words[-1].end
|
assert segment.end == segment.words[-1].end
|
||||||
|
batched_model = BatchedInferencePipeline(model=model, use_vad_model=False)
|
||||||
|
result, info = batched_model.transcribe(jfk_path, word_timestamps=True)
|
||||||
|
assert info.language == "en"
|
||||||
|
assert info.language_probability > 0.7
|
||||||
|
segments = []
|
||||||
|
for segment in result:
|
||||||
|
segments.append(
|
||||||
|
{"start": segment.start, "end": segment.end, "text": segment.text}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(segments) == 1
|
||||||
|
assert segment.text == (
|
||||||
|
" And so my fellow Americans ask not what your country can do for you, "
|
||||||
|
"ask what you can do for your country."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_batched_transcribe(physcisworks_path):
|
||||||
|
model = WhisperModel("tiny")
|
||||||
|
batched_model = BatchedInferencePipeline(model=model)
|
||||||
|
result, info = batched_model.transcribe(physcisworks_path, batch_size=16)
|
||||||
|
assert info.language == "en"
|
||||||
|
assert info.language_probability > 0.7
|
||||||
|
segments = []
|
||||||
|
for segment in result:
|
||||||
|
segments.append(
|
||||||
|
{"start": segment.start, "end": segment.end, "text": segment.text}
|
||||||
|
)
|
||||||
|
# number of near 30 sec segments
|
||||||
|
assert len(segments) == 8
|
||||||
|
|
||||||
|
result, info = batched_model.transcribe(
|
||||||
|
physcisworks_path,
|
||||||
|
batch_size=16,
|
||||||
|
without_timestamps=False,
|
||||||
|
word_timestamps=True,
|
||||||
|
)
|
||||||
|
segments = []
|
||||||
|
for segment in result:
|
||||||
|
assert segment.words is not None
|
||||||
|
segments.append(
|
||||||
|
{"start": segment.start, "end": segment.end, "text": segment.text}
|
||||||
|
)
|
||||||
|
assert len(segments) > 8
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_with_timestamps(jfk_path):
|
def test_prefix_with_timestamps(jfk_path):
|
||||||
@@ -97,3 +143,116 @@ def test_stereo_diarization(data_dir):
|
|||||||
segments, _ = model.transcribe(right)
|
segments, _ = model.transcribe(right)
|
||||||
transcription = "".join(segment.text for segment in segments).strip()
|
transcription = "".join(segment.text for segment in segments).strip()
|
||||||
assert transcription == "The horizon seems extremely distant."
|
assert transcription == "The horizon seems extremely distant."
|
||||||
|
|
||||||
|
|
||||||
|
def test_multisegment_lang_id(physcisworks_path):
|
||||||
|
model = WhisperModel("tiny")
|
||||||
|
language_info = model.detect_language_multi_segment(physcisworks_path)
|
||||||
|
assert language_info["language_code"] == "en"
|
||||||
|
assert language_info["language_confidence"] > 0.8
|
||||||
|
|
||||||
|
|
||||||
|
def test_suppressed_tokens_minus_1():
|
||||||
|
model = WhisperModel("tiny.en")
|
||||||
|
|
||||||
|
tokenizer = Tokenizer(model.hf_tokenizer, False)
|
||||||
|
tokens = get_suppressed_tokens(tokenizer, [-1])
|
||||||
|
assert tokens == (
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
7,
|
||||||
|
8,
|
||||||
|
9,
|
||||||
|
10,
|
||||||
|
14,
|
||||||
|
25,
|
||||||
|
26,
|
||||||
|
27,
|
||||||
|
28,
|
||||||
|
29,
|
||||||
|
31,
|
||||||
|
58,
|
||||||
|
59,
|
||||||
|
60,
|
||||||
|
61,
|
||||||
|
62,
|
||||||
|
63,
|
||||||
|
90,
|
||||||
|
91,
|
||||||
|
92,
|
||||||
|
93,
|
||||||
|
357,
|
||||||
|
366,
|
||||||
|
438,
|
||||||
|
532,
|
||||||
|
685,
|
||||||
|
705,
|
||||||
|
796,
|
||||||
|
930,
|
||||||
|
1058,
|
||||||
|
1220,
|
||||||
|
1267,
|
||||||
|
1279,
|
||||||
|
1303,
|
||||||
|
1343,
|
||||||
|
1377,
|
||||||
|
1391,
|
||||||
|
1635,
|
||||||
|
1782,
|
||||||
|
1875,
|
||||||
|
2162,
|
||||||
|
2361,
|
||||||
|
2488,
|
||||||
|
3467,
|
||||||
|
4008,
|
||||||
|
4211,
|
||||||
|
4600,
|
||||||
|
4808,
|
||||||
|
5299,
|
||||||
|
5855,
|
||||||
|
6329,
|
||||||
|
7203,
|
||||||
|
9609,
|
||||||
|
9959,
|
||||||
|
10563,
|
||||||
|
10786,
|
||||||
|
11420,
|
||||||
|
11709,
|
||||||
|
11907,
|
||||||
|
13163,
|
||||||
|
13697,
|
||||||
|
13700,
|
||||||
|
14808,
|
||||||
|
15306,
|
||||||
|
16410,
|
||||||
|
16791,
|
||||||
|
17992,
|
||||||
|
19203,
|
||||||
|
19510,
|
||||||
|
20724,
|
||||||
|
22305,
|
||||||
|
22935,
|
||||||
|
27007,
|
||||||
|
30109,
|
||||||
|
30420,
|
||||||
|
33409,
|
||||||
|
34949,
|
||||||
|
40283,
|
||||||
|
40493,
|
||||||
|
40549,
|
||||||
|
47282,
|
||||||
|
49146,
|
||||||
|
50257,
|
||||||
|
50357,
|
||||||
|
50358,
|
||||||
|
50359,
|
||||||
|
50360,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_suppressed_tokens_minus_value():
|
||||||
|
model = WhisperModel("tiny.en")
|
||||||
|
|
||||||
|
tokenizer = Tokenizer(model.hf_tokenizer, False)
|
||||||
|
tokens = get_suppressed_tokens(tokenizer, [13])
|
||||||
|
assert tokens == (13, 50257, 50357, 50358, 50359, 50360)
|
||||||
|
|||||||
Reference in New Issue
Block a user