Compare commits

33 Commits

| SHA1 |
|---|
| 28a4d11a73 |
| 6e42088656 |
| d57c5b40b0 |
| 83a368e98a |
| eb8390233c |
| 4a59bb011d |
| fbcf58bf98 |
| 1195359984 |
| c22db5125d |
| 8862bee1f8 |
| 8d400e9870 |
| bced5f04c0 |
| 65551c081f |
| f53be1e811 |
| 4acdb5c619 |
| a1c3583c96 |
| 2036d12634 |
| 2f6913efc8 |
| e11d58599d |
| 49a80eb8a8 |
| 8d5e6d56d9 |
| 6eec07739e |
| 847fec4492 |
| 46080e584e |
| 3d1de60ef3 |
| 4ee1d54c14 |
| e50d82c18c |
| 4b64ef1f70 |
| d04e685ca2 |
| b835bdaaf1 |
| 9f24e2c735 |
| 9a646b69e6 |
| 49af9564ab |
README.md (18 lines changed)

````diff
@@ -75,28 +75,35 @@ Unlike openai-whisper, FFmpeg does **not** need to be installed on the system. T
 
 GPU execution requires the following NVIDIA libraries to be installed:
 
-* [cuBLAS for CUDA 11](https://developer.nvidia.com/cublas)
-* [cuDNN 8 for CUDA 11](https://developer.nvidia.com/cudnn)
+* [cuBLAS for CUDA 12](https://developer.nvidia.com/cublas)
+* [cuDNN 8 for CUDA 12](https://developer.nvidia.com/cudnn)
 
-There are multiple ways to install these libraries. The recommended way is described in the official NVIDIA documentation, but we also suggest other installation methods below.
+**Note**: The latest versions of `ctranslate2` support CUDA 12 only. For CUDA 11, the current workaround is downgrading to the `3.24.0` version of `ctranslate2` (this can be done with `pip install --force-reinstall ctranslate2==3.24.0` or by specifying the version in a `requirements.txt`).
+
+There are multiple ways to install the NVIDIA libraries mentioned above. The recommended way is described in the official NVIDIA documentation, but we also suggest other installation methods below.
 
 <details>
 <summary>Other installation methods (click to expand)</summary>
 
+**Note:** For all these methods below, keep in mind the above note regarding CUDA versions. Depending on your setup, you may need to install the _CUDA 11_ versions of libraries that correspond to the CUDA 12 libraries listed in the instructions below.
+
 #### Use Docker
 
-The libraries are installed in this official NVIDIA Docker image: `nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04`.
+The libraries (cuBLAS, cuDNN) are installed in these official NVIDIA CUDA Docker images: `nvidia/cuda:12.0.0-runtime-ubuntu20.04` or `nvidia/cuda:12.0.0-runtime-ubuntu22.04`.
 
 #### Install with `pip` (Linux only)
 
 On Linux these libraries can be installed with `pip`. Note that `LD_LIBRARY_PATH` must be set before launching Python.
 
 ```bash
-pip install nvidia-cublas-cu11 nvidia-cudnn-cu11
+pip install nvidia-cublas-cu12 nvidia-cudnn-cu12
 
 export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'`
 ```
 
+**Note**: Version 9+ of `nvidia-cudnn-cu12` appears to cause issues due to its reliance on cuDNN 9 (faster-whisper does not currently support cuDNN 9). Ensure your version of the Python package is for cuDNN 8.
+
 #### Download the libraries from Purfview's repository (Windows & Linux)
 
 Purfview's [whisper-standalone-win](https://github.com/Purfview/whisper-standalone-win) provides the required NVIDIA libraries for Windows & Linux in a [single archive](https://github.com/Purfview/whisper-standalone-win/releases/tag/libs). Decompress the archive and place the libraries in a directory included in the `PATH`.
@@ -227,6 +234,7 @@ See more model and transcription options in the [`WhisperModel`](https://github.
 
 Here is a non-exhaustive list of open-source projects using faster-whisper. Feel free to add your project to the list!
 
+* [faster-whisper-server](https://github.com/fedirz/faster-whisper-server) is an OpenAI-compatible server using `faster-whisper`. It's easily deployable with Docker, works with OpenAI SDKs/CLI, and supports streaming and live transcription.
 * [WhisperX](https://github.com/m-bain/whisperX) is an award-winning Python library that offers speaker diarization and accurate word-level timestamps using wav2vec2 alignment
 * [whisper-ctranslate2](https://github.com/Softcatala/whisper-ctranslate2) is a command line client based on faster-whisper and compatible with the original client from openai/whisper.
 * [whisper-diarize](https://github.com/MahmoudAshraf97/whisper-diarization) is a speaker diarization tool that is based on faster-whisper and NVIDIA NeMo.
````
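As a quick sanity check after installing the CUDA 12 libraries, here is a minimal sketch (assuming faster-whisper is installed and a CUDA-capable GPU is visible) that fails fast if cuBLAS or cuDNN cannot be loaded:

```python
from faster_whisper import WhisperModel

# "tiny" and float16 are just example choices; any model size and supported
# compute type exercises the same cuBLAS/cuDNN loading code path.
model = WhisperModel("tiny", device="cuda", compute_type="float16")
print("CUDA model loaded; cuBLAS and cuDNN were found.")
```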
benchmark/benchmark.m4a (BIN, new file)
Binary file not shown.
benchmark/memory_benchmark.py (new file, 94 lines)

```python
import argparse
import time

from typing import Callable

import py3nvml.py3nvml as nvml

from memory_profiler import memory_usage
from utils import MyThread, get_logger, inference

logger = get_logger("faster-whisper")
parser = argparse.ArgumentParser(description="Memory benchmark")
parser.add_argument(
    "--gpu_memory", action="store_true", help="Measure GPU memory usage"
)
parser.add_argument("--device-index", type=int, default=0, help="GPU device index")
parser.add_argument(
    "--interval",
    type=float,
    default=0.5,
    help="Interval at which measurements are collected",
)
args = parser.parse_args()
device_idx = args.device_index
interval = args.interval


def measure_memory(func: Callable[[], None]):
    if args.gpu_memory:
        logger.info(
            "Measuring maximum GPU memory usage on GPU device."
            " Make sure to not have additional processes running on the same GPU."
        )
        # init nvml
        nvml.nvmlInit()
        handle = nvml.nvmlDeviceGetHandleByIndex(device_idx)
        gpu_name = nvml.nvmlDeviceGetName(handle)
        gpu_memory_limit = nvml.nvmlDeviceGetMemoryInfo(handle).total >> 20
        gpu_power_limit = nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000.0
        info = {"gpu_memory_usage": [], "gpu_power_usage": []}

        def _get_gpu_info():
            while True:
                info["gpu_memory_usage"].append(
                    nvml.nvmlDeviceGetMemoryInfo(handle).used >> 20
                )
                info["gpu_power_usage"].append(
                    nvml.nvmlDeviceGetPowerUsage(handle) / 1000
                )
                time.sleep(interval)

                if stop:
                    break

            return info

        stop = False
        thread = MyThread(_get_gpu_info, params=())
        thread.start()
        func()
        stop = True
        thread.join()
        result = thread.get_result()

        # shutdown nvml
        nvml.nvmlShutdown()
        max_memory_usage = max(result["gpu_memory_usage"])
        max_power_usage = max(result["gpu_power_usage"])
        print("GPU name: %s" % gpu_name)
        print("GPU device index: %s" % device_idx)
        print(
            "Maximum GPU memory usage: %dMiB / %dMiB (%.2f%%)"
            % (
                max_memory_usage,
                gpu_memory_limit,
                (max_memory_usage / gpu_memory_limit) * 100,
            )
        )
        print(
            "Maximum GPU power usage: %dW / %dW (%.2f%%)"
            % (
                max_power_usage,
                gpu_power_limit,
                (max_power_usage / gpu_power_limit) * 100,
            )
        )
    else:
        logger.info("Measuring maximum increase of memory usage.")
        max_usage = memory_usage(func, max_usage=True, interval=interval)
        print("Maximum increase of RAM memory usage: %d MiB" % max_usage)


if __name__ == "__main__":
    measure_memory(inference)
```
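Based on the flags defined above, a run from the `benchmark/` directory such as `python3 memory_benchmark.py --gpu_memory --device-index 0 --interval 0.5` polls GPU memory and power through NVML on a background thread while the transcription runs; omitting `--gpu_memory` falls back to measuring peak RAM growth with `memory_profiler`.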
benchmark/normalizer.json (new file, 1742 lines)
File diff suppressed because it is too large.
benchmark/requirements.benchmark.txt (new file, 6 lines)

```
transformers
jiwer
evaluate
datasets
memory_profiler
py3nvml
```
benchmark/speed_benchmark.py (new file, 31 lines)

```python
import argparse
import timeit

from typing import Callable

from utils import inference

parser = argparse.ArgumentParser(description="Speed benchmark")
parser.add_argument(
    "--repeat",
    type=int,
    default=3,
    help="Times an experiment will be run.",
)
args = parser.parse_args()


def measure_speed(func: Callable[[], None]):
    # as written in https://docs.python.org/3/library/timeit.html#timeit.Timer.repeat,
    # min should be taken rather than the average
    runtimes = timeit.repeat(
        func,
        repeat=args.repeat,
        number=10,
    )
    print(runtimes)
    print("Min execution time: %.3fs" % (min(runtimes) / 10.0))


if __name__ == "__main__":
    measure_speed(inference)
```
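Per the `argparse` setup above, `python3 speed_benchmark.py --repeat 3` runs the transcription 10 times per repetition and reports the minimum per-run time, following the `timeit` documentation's recommendation to take the minimum rather than the average.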
benchmark/utils.py (new file, 39 lines)

```python
import logging

from threading import Thread
from typing import Optional

from faster_whisper import WhisperModel

model_path = "large-v3"
model = WhisperModel(model_path, device="cuda")


def inference():
    segments, info = model.transcribe("benchmark.m4a", language="fr")
    for segment in segments:
        print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))


def get_logger(name: Optional[str] = None) -> logging.Logger:
    formatter = logging.Formatter("%(levelname)s: %(message)s")
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger


class MyThread(Thread):
    def __init__(self, func, params):
        super(MyThread, self).__init__()
        self.func = func
        self.params = params
        self.result = None

    def run(self):
        self.result = self.func(*self.params)

    def get_result(self):
        return self.result
```
benchmark/wer_benchmark.py (new file, 61 lines)

```python
import argparse
import json

from datasets import load_dataset
from evaluate import load
from tqdm import tqdm
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer

from faster_whisper import WhisperModel

parser = argparse.ArgumentParser(description="WER benchmark")
parser.add_argument(
    "--audio_numb",
    type=int,
    default=None,
    help="Specify the number of validation audio files in the dataset."
    " Set to None to retrieve all audio files.",
)
args = parser.parse_args()

model_path = "large-v3"
model = WhisperModel(model_path, device="cuda")

# load the dataset with streaming mode
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)

# define the evaluation metric
wer_metric = load("wer")
normalizer = EnglishTextNormalizer(json.load(open("normalizer.json")))


def inference(batch):
    batch["transcription"] = []
    for sample in batch["audio"]:
        segments, info = model.transcribe(sample["array"], language="en")
        batch["transcription"].append("".join([segment.text for segment in segments]))
    batch["reference"] = batch["text"]
    return batch


dataset = dataset.map(function=inference, batched=True, batch_size=16)

all_transcriptions = []
all_references = []

# iterate over the dataset and run inference
for i, result in tqdm(enumerate(dataset), desc="Evaluating..."):
    all_transcriptions.append(result["transcription"])
    all_references.append(result["reference"])
    if args.audio_numb and i == (args.audio_numb - 1):
        break

# normalize predictions and references
all_transcriptions = [normalizer(transcription) for transcription in all_transcriptions]
all_references = [normalizer(reference) for reference in all_references]

# compute the WER metric
wer = 100 * wer_metric.compute(
    predictions=all_transcriptions, references=all_references
)
print("WER: %.3f" % wer)
```
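With the `--audio_numb` flag defined above, for example `python3 wer_benchmark.py --audio_numb 100`, evaluation stops after the first 100 validation utterances; omitting the flag streams the entire LibriSpeech clean validation split.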
docker/Dockerfile (new file, 6 lines)

```dockerfile
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04
WORKDIR /root
RUN apt-get update -y && apt-get install -y python3-pip
COPY infer.py jfk.flac ./
RUN pip3 install faster-whisper
CMD ["python3", "infer.py"]
```
docker/infer.py (new file, 7 lines)

```python
from faster_whisper import WhisperModel

jfk_path = "jfk.flac"
model = WhisperModel("tiny", device="cuda")
segments, info = model.transcribe(jfk_path, word_timestamps=True)
for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
```
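Together these two files form a self-contained GPU smoke test. A plausible invocation, assuming the NVIDIA Container Toolkit is installed, is `docker build -t faster-whisper-gpu docker/` followed by `docker run --gpus all faster-whisper-gpu`; the image tag here is only an example.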
docker/jfk.flac (BIN, new file)
Binary file not shown.
faster_whisper/tokenizer.py

```diff
@@ -105,6 +105,42 @@ class Tokenizer:
             [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
         )
 
+    @cached_property
+    def non_speech_tokens(self) -> Tuple[int]:
+        """
+        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
+        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
+
+        - ♪♪♪
+        - ( SPEAKING FOREIGN LANGUAGE )
+        - [DAVID] Hey there,
+
+        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
+        """
+        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
+        symbols += (
+            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+        )
+
+        # symbols that may be a single token or multiple tokens depending on the tokenizer.
+        # In case they're multiple tokens, suppress the first token, which is safe because:
+        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
+        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
+        miscellaneous = set("♩♪♫♬♭♮♯")
+        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
+
+        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+        result = {self.encode(" -")[0], self.encode(" '")[0]}
+        for symbol in symbols + list(miscellaneous):
+            for tokens in [
+                self.encode(symbol),
+                self.encode(" " + symbol),
+            ]:
+                if len(tokens) == 1 or symbol in miscellaneous:
+                    result.add(tokens[0])
+
+        return tuple(sorted(result))
+
     def split_to_word_tokens(
         self, tokens: List[int]
     ) -> Tuple[List[str], List[List[int]]]:
```
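The new property can be inspected directly. A minimal sketch, mirroring how the tests at the end of this compare construct a `Tokenizer` (the model is downloaded on first use):

```python
from faster_whisper import WhisperModel
from faster_whisper.tokenizer import Tokenizer

model = WhisperModel("tiny.en")
tokenizer = Tokenizer(model.hf_tokenizer, False)  # False: not multilingual

# Sorted tuple of token IDs covering speaker tags, bracketed annotations,
# music symbols, etc.; computed once and cached by @cached_property.
print(tokenizer.non_speech_tokens[:10])
```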
faster_whisper/transcribe.py

```diff
@@ -69,6 +69,7 @@ class TranscriptionOptions(NamedTuple):
     max_new_tokens: Optional[int]
     clip_timestamps: Union[str, List[float]]
     hallucination_silence_threshold: Optional[float]
+    hotwords: Optional[str]
 
 
 class TranscriptionInfo(NamedTuple):
@@ -92,12 +93,15 @@ class WhisperModel:
         num_workers: int = 1,
         download_root: Optional[str] = None,
         local_files_only: bool = False,
+        files: dict = None,
+        **model_kwargs,
     ):
         """Initializes the Whisper model.
 
         Args:
           model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
-            small, small.en, medium, medium.en, large-v1, large-v2, large-v3, or large), a path to a
+            small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1,
+            large-v2, large-v3, large, distil-large-v2 or distil-large-v3), a path to a
             converted model directory, or a CTranslate2-converted Whisper model ID from the HF Hub.
             When a size or a model ID is configured, the converted model is downloaded
             from the Hugging Face Hub.
@@ -118,10 +122,18 @@ class WhisperModel:
             are saved in the standard Hugging Face cache directory.
           local_files_only: If True, avoid downloading the file and return the path to the
             local cached file if it exists.
+          files: Load model files from the memory. This argument is a dictionary mapping file names
+            to file contents as file-like or bytes objects. If this is set, model_path acts as an
+            identifier for this model.
         """
         self.logger = get_logger()
 
-        if os.path.isdir(model_size_or_path):
+        tokenizer_bytes, preprocessor_bytes = None, None
+        if files:
+            model_path = model_size_or_path
+            tokenizer_bytes = files.pop("tokenizer.json", None)
+            preprocessor_bytes = files.pop("preprocessor_config.json", None)
+        elif os.path.isdir(model_size_or_path):
             model_path = model_size_or_path
         else:
             model_path = download_model(
@@ -137,17 +149,20 @@ class WhisperModel:
             compute_type=compute_type,
             intra_threads=cpu_threads,
             inter_threads=num_workers,
+            files=files,
+            **model_kwargs,
         )
 
         tokenizer_file = os.path.join(model_path, "tokenizer.json")
-        if os.path.isfile(tokenizer_file):
+        if tokenizer_bytes:
+            self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes)
+        elif os.path.isfile(tokenizer_file):
             self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
         else:
             self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
                 "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
             )
-        self.feat_kwargs = self._get_feature_kwargs(model_path)
+        self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes)
         self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
         self.num_samples_per_token = self.feature_extractor.hop_length * 2
         self.frames_per_second = (
@@ -165,19 +180,21 @@ class WhisperModel:
         """The languages supported by the model."""
         return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
 
-    def _get_feature_kwargs(self, model_path) -> dict:
-        preprocessor_config_file = os.path.join(model_path, "preprocessor_config.json")
+    def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
         config = {}
-        if os.path.isfile(preprocessor_config_file):
-            try:
-                with open(preprocessor_config_file, "r", encoding="utf-8") as json_file:
-                    config = json.load(json_file)
-                valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
-                config = {k: v for k, v in config.items() if k in valid_keys}
-            except json.JSONDecodeError as e:
-                self.logger.warning(
-                    "Could not load preprocessor_config.json: %s", str(e)
-                )
+        try:
+            config_path = os.path.join(model_path, "preprocessor_config.json")
+            if preprocessor_bytes:
+                config = json.loads(preprocessor_bytes)
+            elif os.path.isfile(config_path):
+                with open(config_path, "r", encoding="utf-8") as file:
+                    config = json.load(file)
+            else:
+                return config
+            valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
+            return {k: v for k, v in config.items() if k in valid_keys}
+        except json.JSONDecodeError as e:
+            self.logger.warning("Could not load preprocessor config: %s", e)
 
         return config
```
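The new `files` argument makes it possible to load a converted model without touching the filesystem at inference time. A hypothetical sketch (the directory and file names below are illustrative; the docstring above only requires a mapping of file names to bytes or file-like objects):

```python
from faster_whisper import WhisperModel

# Read a CTranslate2-converted model into memory first (file names are examples).
names = ["model.bin", "config.json", "tokenizer.json", "vocabulary.txt"]
files = {name: open(f"/path/to/ct2-model/{name}", "rb").read() for name in names}

# model_size_or_path now only serves as an identifier for this in-memory model.
model = WhisperModel("my-in-memory-model", files=files)
```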
```diff
@@ -220,6 +237,7 @@ class WhisperModel:
         chunk_length: Optional[int] = None,
         clip_timestamps: Union[str, List[float]] = "0",
         hallucination_silence_threshold: Optional[float] = None,
+        hotwords: Optional[str] = None,
         language_detection_threshold: Optional[float] = None,
         language_detection_segments: int = 1,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
@@ -259,7 +277,7 @@ class WhisperModel:
           prefix: Optional text to provide as a prefix for the first window.
           suppress_blank: Suppress blank outputs at the beginning of the sampling.
           suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
-            of symbols as defined in the model config.json file.
+            of symbols as defined in `tokenizer.non_speech_tokens()`
           without_timestamps: Only sample text tokens.
           max_initial_timestamp: The initial timestamp cannot be later than this.
           word_timestamps: Extract word-level timestamps using the cross-attention pattern
@@ -277,17 +295,18 @@ class WhisperModel:
             the maximum will be set by the default max_length.
           chunk_length: The length of audio segments. If it is not None, it will overwrite the
             default chunk_length of the FeatureExtractor.
-          clip_timestamps: Union[str, List[float]]
+          clip_timestamps:
             Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
             process. The last end timestamp defaults to the end of the file.
             vad_filter will be ignored if clip_timestamps is used.
-          hallucination_silence_threshold: Optional[float]
+          hallucination_silence_threshold:
             When word_timestamps is True, skip silent periods longer than this threshold
             (in seconds) when a possible hallucination is detected
+          hotwords:
+            Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
           language_detection_threshold: If the maximum probability of the language tokens is higher
             than this value, the language is detected.
           language_detection_segments: Number of segments to consider for the language detection.
 
         Returns:
           A tuple with:
```
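A minimal usage sketch for the new `hotwords` option (the audio path and phrases are placeholders); per the docstring it is ignored whenever `prefix` is set:

```python
from faster_whisper import WhisperModel

model = WhisperModel("tiny")  # any model size works; "tiny" keeps the example light
segments, info = model.transcribe(
    "audio.wav",  # placeholder path
    hotwords="Systran faster-whisper CTranslate2",  # bias decoding toward these phrases
)
for segment in segments:
    print(segment.text)
```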
```diff
@@ -351,16 +370,27 @@ class WhisperModel:
                     or language_detection_segments < 1
                 ):
                     language_detection_segments = 1
-                seek = 0
-                detected_language_info = {}
+                start_timestamp = (
+                    float(clip_timestamps.split(",")[0])
+                    if isinstance(clip_timestamps, str)
+                    else clip_timestamps[0]
+                )
                 content_frames = (
                     features.shape[-1] - self.feature_extractor.nb_max_frames
                 )
-                while (
-                    seek <= content_frames
-                    and seek
-                    < self.feature_extractor.nb_max_frames * language_detection_segments
-                ):
+                seek = (
+                    int(start_timestamp * self.frames_per_second)
+                    if start_timestamp * self.frames_per_second < content_frames
+                    else 0
+                )
+                end_frames = min(
+                    seek
+                    + self.feature_extractor.nb_max_frames
+                    * language_detection_segments,
+                    content_frames,
+                )
+                detected_language_info = {}
+                while seek <= end_frames:
                     segment = features[
                         :, seek : seek + self.feature_extractor.nb_max_frames
                     ]
@@ -432,7 +462,11 @@ class WhisperModel:
             initial_prompt=initial_prompt,
             prefix=prefix,
             suppress_blank=suppress_blank,
-            suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
+            suppress_tokens=(
+                get_suppressed_tokens(tokenizer, suppress_tokens)
+                if suppress_tokens
+                else suppress_tokens
+            ),
             without_timestamps=without_timestamps,
             max_initial_timestamp=max_initial_timestamp,
             word_timestamps=word_timestamps,
@@ -441,6 +475,7 @@ class WhisperModel:
             max_new_tokens=max_new_tokens,
             clip_timestamps=clip_timestamps,
             hallucination_silence_threshold=hallucination_silence_threshold,
+            hotwords=hotwords,
         )
 
         segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -457,7 +492,6 @@ class WhisperModel:
             vad_options=vad_parameters,
             all_language_probs=all_language_probs,
         )
-
         return segments, info
 
     def generate_segments(
@@ -471,7 +505,8 @@ class WhisperModel:
         content_duration = float(content_frames * self.feature_extractor.time_per_frame)
 
         if isinstance(options.clip_timestamps, str):
-            TranscriptionOptions.clip_timestamps = [
+            options = options._replace(
+                clip_timestamps=[
                     float(ts)
                     for ts in (
                         options.clip_timestamps.split(",")
@@ -479,6 +514,7 @@ class WhisperModel:
                         else []
                     )
                 ]
+            )
         seek_points: List[int] = [
             round(ts * self.frames_per_second) for ts in options.clip_timestamps
         ]
@@ -496,6 +532,7 @@ class WhisperModel:
         clip_idx = 0
         seek = seek_clips[clip_idx][0]
         all_tokens = []
+        all_prompt_text = []
         prompt_reset_since = 0
 
         if options.initial_prompt is not None:
@@ -547,6 +584,7 @@ class WhisperModel:
                 previous_tokens,
                 without_timestamps=options.without_timestamps,
                 prefix=options.prefix if seek == 0 else None,
+                hotwords=options.hotwords,
             )
 
             if seek > 0 or encoder_output is None:
@@ -759,7 +797,15 @@ class WhisperModel:
             if segment["start"] == segment["end"] or not text.strip():
                 continue
+
+            check_prompt_num = 1
+            if all(
+                [
+                    text.strip() != i.strip()
+                    for i in all_prompt_text[-check_prompt_num:]
+                ]
+            ):
                 all_tokens.extend(tokens)
+                all_prompt_text.append(text)
             idx += 1
 
             yield Segment(
@@ -939,11 +985,18 @@ class WhisperModel:
         previous_tokens: List[int],
         without_timestamps: bool = False,
         prefix: Optional[str] = None,
+        hotwords: Optional[str] = None,
     ) -> List[int]:
         prompt = []
 
-        if previous_tokens:
+        if previous_tokens or (hotwords and not prefix):
             prompt.append(tokenizer.sot_prev)
+            if hotwords and not prefix:
+                hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
+                if len(hotwords_tokens) >= self.max_length // 2:
+                    hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
+                prompt.extend(hotwords_tokens)
+            if previous_tokens:
                 prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
 
         prompt.extend(tokenizer.sot_sequence)
@@ -1186,15 +1239,16 @@ def get_compression_ratio(text: str) -> float:
 
 def get_suppressed_tokens(
     tokenizer: Tokenizer,
-    suppress_tokens: Optional[List[int]],
+    suppress_tokens: Tuple[int],
 ) -> Optional[List[int]]:
-    if not suppress_tokens or -1 in suppress_tokens:
-        return suppress_tokens
-
-    suppress_tokens = list(suppress_tokens)
-
-    # Ensure the following special tokens are suppressed when the user does
-    # not use the default set (-1).
+    if -1 in suppress_tokens:
+        suppress_tokens = [t for t in suppress_tokens if t >= 0]
+        suppress_tokens.extend(tokenizer.non_speech_tokens)
+    elif suppress_tokens is None or len(suppress_tokens) == 0:
+        suppress_tokens = []  # interpret empty string as an empty list
+    else:
+        assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
+
     suppress_tokens.extend(
         [
             tokenizer.transcribe,
@@ -1205,7 +1259,7 @@ def get_suppressed_tokens(
         ]
     )
 
-    return sorted(set(suppress_tokens))
+    return tuple(sorted(set(suppress_tokens)))
 
 
 def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
```
faster_whisper/utils.py

```diff
@@ -54,8 +54,9 @@ def download_model(
 
     Args:
       size_or_id: Size of the model to download from https://huggingface.co/Systran
-        (tiny, tiny.en, base, base.en, small, small.en medium, medium.en, large-v1, large-v2,
-        large-v3, large), or a CTranslate2-converted model ID from the Hugging Face Hub
+        (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en,
+        distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2,
+        distil-large-v3), or a CTranslate2-converted model ID from the Hugging Face Hub
         (e.g. Systran/faster-whisper-large-v3).
       output_dir: Directory where the model should be saved. If not set, the model is saved in
         the cache directory.
```
faster_whisper/vad.py

```diff
@@ -1,7 +1,6 @@
 import bisect
 import functools
 import os
-import warnings
 
 from typing import List, NamedTuple, Optional
 
@@ -25,9 +24,6 @@ class VadOptions(NamedTuple):
         split aggressively just before max_speech_duration_s.
       min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
         before separating it
-      window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
-        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
-        Values other than these may affect model performance!!
      speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
     """
 
@@ -35,7 +31,6 @@ class VadOptions(NamedTuple):
     min_speech_duration_ms: int = 250
     max_speech_duration_s: float = float("inf")
     min_silence_duration_ms: int = 2000
-    window_size_samples: int = 1024
     speech_pad_ms: int = 400
 
 
@@ -61,15 +56,8 @@ def get_speech_timestamps(
     min_speech_duration_ms = vad_options.min_speech_duration_ms
     max_speech_duration_s = vad_options.max_speech_duration_s
     min_silence_duration_ms = vad_options.min_silence_duration_ms
-    window_size_samples = vad_options.window_size_samples
+    window_size_samples = 512
     speech_pad_ms = vad_options.speech_pad_ms
-
-    if window_size_samples not in [512, 1024, 1536]:
-        warnings.warn(
-            "Unusual window_size_samples! Supported window_size_samples:\n"
-            " - [512, 1024, 1536] for 16000 sampling_rate"
-        )
-
     sampling_rate = 16000
     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
     speech_pad_samples = sampling_rate * speech_pad_ms / 1000
@@ -84,14 +72,14 @@ def get_speech_timestamps(
     audio_length_samples = len(audio)
 
     model = get_vad_model()
-    state = model.get_initial_state(batch_size=1)
+    state, context = model.get_initial_states(batch_size=1)
 
     speech_probs = []
     for current_start_sample in range(0, audio_length_samples, window_size_samples):
         chunk = audio[current_start_sample : current_start_sample + window_size_samples]
         if len(chunk) < window_size_samples:
             chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
-        speech_prob, state = model(chunk, state, sampling_rate)
+        speech_prob, state, context = model(chunk, state, context, sampling_rate)
         speech_probs.append(speech_prob)
 
     triggered = False
@@ -261,12 +249,12 @@ class SileroVADModel:
             sess_options=opts,
         )
 
-    def get_initial_state(self, batch_size: int):
-        h = np.zeros((2, batch_size, 64), dtype=np.float32)
-        c = np.zeros((2, batch_size, 64), dtype=np.float32)
-        return h, c
+    def get_initial_states(self, batch_size: int):
+        state = np.zeros((2, batch_size, 128), dtype=np.float32)
+        context = np.zeros((batch_size, 64), dtype=np.float32)
+        return state, context
 
-    def __call__(self, x, state, sr: int):
+    def __call__(self, x, state, context, sr: int):
         if len(x.shape) == 1:
             x = np.expand_dims(x, 0)
         if len(x.shape) > 2:
@@ -276,16 +264,15 @@ class SileroVADModel:
         if sr / x.shape[1] > 31.25:
             raise ValueError("Input audio chunk is too short")
 
-        h, c = state
+        x = np.concatenate([context, x], axis=1)
 
         ort_inputs = {
             "input": x,
-            "h": h,
-            "c": c,
+            "state": state,
             "sr": np.array(sr, dtype="int64"),
         }
 
-        out, h, c = self.session.run(None, ort_inputs)
-        state = (h, c)
+        out, state = self.session.run(None, ort_inputs)
+        context = x[..., -64:]
 
-        return out, state
+        return out, state, context
```
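The VAD call is now stateful in two pieces: a recurrent `state` tensor plus a 64-sample `context` carried between windows, with the window size fixed at 512 samples. A minimal sketch of the new loop, mirroring the diff above (the silent input is a placeholder and this uses internal API):

```python
import numpy as np

from faster_whisper.vad import get_vad_model

audio = np.zeros(16000, dtype=np.float32)  # placeholder: one second of 16 kHz silence
model = get_vad_model()
state, context = model.get_initial_states(batch_size=1)

speech_probs = []
for start in range(0, len(audio), 512):
    chunk = audio[start : start + 512]
    if len(chunk) < 512:
        chunk = np.pad(chunk, (0, 512 - len(chunk)))
    # Each call returns the speech probability plus updated state and context.
    speech_prob, state, context = model(chunk, state, context, 16000)
    speech_probs.append(speech_prob)
```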
faster_whisper/version.py

```diff
@@ -1,3 +1,3 @@
 """Version information."""
 
-__version__ = "1.0.1"
+__version__ = "1.0.3"
```
requirements.txt

```diff
@@ -1,5 +1,5 @@
-av==11.*
+av>=11.0,<13
 ctranslate2>=4.0,<5
 huggingface_hub>=0.13
-tokenizers>=0.13,<0.16
+tokenizers>=0.13,<1
 onnxruntime>=1.14,<2
```
tests/test_transcribe.py

```diff
@@ -1,6 +1,8 @@
 import os
 
 from faster_whisper import WhisperModel, decode_audio
+from faster_whisper.tokenizer import Tokenizer
+from faster_whisper.transcribe import get_suppressed_tokens
 
 
 def test_supported_languages():
@@ -97,3 +99,109 @@ def test_stereo_diarization(data_dir):
     segments, _ = model.transcribe(right)
     transcription = "".join(segment.text for segment in segments).strip()
     assert transcription == "The horizon seems extremely distant."
+
+
+def test_suppressed_tokens_minus_1():
+    model = WhisperModel("tiny.en")
+
+    tokenizer = Tokenizer(model.hf_tokenizer, False)
+    tokens = get_suppressed_tokens(tokenizer, [-1])
+    assert tokens == (
+        1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63,
+        90, 91, 92, 93, 357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220,
+        1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, 1875, 2162, 2361,
+        2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609,
+        9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808,
+        15306, 16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007,
+        30109, 30420, 33409, 34949, 40283, 40493, 40549, 47282, 49146, 50257,
+        50357, 50358, 50359, 50360,
+    )
+
+
+def test_suppressed_tokens_minus_value():
+    model = WhisperModel("tiny.en")
+
+    tokenizer = Tokenizer(model.hf_tokenizer, False)
+    tokens = get_suppressed_tokens(tokenizer, [13])
+    assert tokens == (13, 50257, 50357, 50358, 50359, 50360)
```