Add num_workers parameter

This commit is contained in:
Guillaume Klein
2023-02-14 09:34:05 +01:00
parent c86353d323
commit cbbe633082
2 changed files with 11 additions and 0 deletions

View File

@@ -49,6 +49,7 @@ class WhisperModel:
device_index=0,
compute_type="default",
cpu_threads=0,
num_workers=1,
):
"""Initializes the Whisper model.
@@ -56,10 +57,17 @@ class WhisperModel:
model_path: Path to the converted model.
device: Device to use for computation ("cpu", "cuda", "auto").
device_index: Device ID to use.
The model can also be loaded on multiple GPUs by passing a list of IDs
(e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
when transcribe() is called from multiple Python threads (see also num_workers).
compute_type: Type to use for computation.
See https://opennmt.net/CTranslate2/quantization.html.
cpu_threads: Number of threads to use when running on CPU (4 by default).
A non-zero value overrides the OMP_NUM_THREADS environment variable.
num_workers: When transcribe() is called from multiple Python threads,
having multiple workers enables true parallelism when running the model
(concurrent calls to self.model.generate() will run in parallel).
This can improve the global throughput at the cost of increased memory usage.
"""
self.model = ctranslate2.models.Whisper(
model_path,
@@ -67,6 +75,7 @@ class WhisperModel:
device_index=device_index,
compute_type=compute_type,
intra_threads=cpu_threads,
inter_threads=num_workers,
)
with open(os.path.join(model_path, "vocabulary.txt")) as vocab_file: