Expose the device_index argument (#5)

This commit is contained in:
Guillaume Klein
2023-02-13 11:06:40 +01:00
committed by GitHub
parent 0bcbbfa8c2
commit 269b3dfb10

View File

@@ -43,6 +43,7 @@ class WhisperModel:
self,
model_path,
device="auto",
device_index=0,
compute_type="default",
cpu_threads=0,
):
@@ -51,6 +52,7 @@ class WhisperModel:
Args:
model_path: Path to the converted model.
device: Device to use for computation ("cpu", "cuda", "auto").
device_index: Device ID to use.
compute_type: Type to use for computation.
See https://opennmt.net/CTranslate2/quantization.html.
cpu_threads: Number of threads to use when running on CPU (4 by default).
@@ -59,6 +61,7 @@ class WhisperModel:
self.model = ctranslate2.models.Whisper(
model_path,
device=device,
device_index=device_index,
compute_type=compute_type,
intra_threads=cpu_threads,
)