Expose the device_index argument (#5)

This commit is contained in:
Guillaume Klein
2023-02-13 11:06:40 +01:00
committed by GitHub
parent 0bcbbfa8c2
commit 269b3dfb10

View File

@@ -43,6 +43,7 @@ class WhisperModel:
self, self,
model_path, model_path,
device="auto", device="auto",
device_index=0,
compute_type="default", compute_type="default",
cpu_threads=0, cpu_threads=0,
): ):
@@ -51,6 +52,7 @@ class WhisperModel:
Args: Args:
model_path: Path to the converted model. model_path: Path to the converted model.
device: Device to use for computation ("cpu", "cuda", "auto"). device: Device to use for computation ("cpu", "cuda", "auto").
device_index: Device ID to use.
compute_type: Type to use for computation. compute_type: Type to use for computation.
See https://opennmt.net/CTranslate2/quantization.html. See https://opennmt.net/CTranslate2/quantization.html.
cpu_threads: Number of threads to use when running on CPU (4 by default). cpu_threads: Number of threads to use when running on CPU (4 by default).
@@ -59,6 +61,7 @@ class WhisperModel:
self.model = ctranslate2.models.Whisper( self.model = ctranslate2.models.Whisper(
model_path, model_path,
device=device, device=device,
device_index=device_index,
compute_type=compute_type, compute_type=compute_type,
intra_threads=cpu_threads, intra_threads=cpu_threads,
) )