Expose the device_index argument (#5)
This commit is contained in:
@@ -43,6 +43,7 @@ class WhisperModel:
|
||||
self,
|
||||
model_path,
|
||||
device="auto",
|
||||
device_index=0,
|
||||
compute_type="default",
|
||||
cpu_threads=0,
|
||||
):
|
||||
@@ -51,6 +52,7 @@ class WhisperModel:
|
||||
Args:
|
||||
model_path: Path to the converted model.
|
||||
device: Device to use for computation ("cpu", "cuda", "auto").
|
||||
device_index: Device ID to use.
|
||||
compute_type: Type to use for computation.
|
||||
See https://opennmt.net/CTranslate2/quantization.html.
|
||||
cpu_threads: Number of threads to use when running on CPU (4 by default).
|
||||
@@ -59,6 +61,7 @@ class WhisperModel:
|
||||
self.model = ctranslate2.models.Whisper(
|
||||
model_path,
|
||||
device=device,
|
||||
device_index=device_index,
|
||||
compute_type=compute_type,
|
||||
intra_threads=cpu_threads,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user