Expose the device_index argument (#5)
This commit is contained in:
@@ -43,6 +43,7 @@ class WhisperModel:
|
|||||||
self,
|
self,
|
||||||
model_path,
|
model_path,
|
||||||
device="auto",
|
device="auto",
|
||||||
|
device_index=0,
|
||||||
compute_type="default",
|
compute_type="default",
|
||||||
cpu_threads=0,
|
cpu_threads=0,
|
||||||
):
|
):
|
||||||
@@ -51,6 +52,7 @@ class WhisperModel:
|
|||||||
Args:
|
Args:
|
||||||
model_path: Path to the converted model.
|
model_path: Path to the converted model.
|
||||||
device: Device to use for computation ("cpu", "cuda", "auto").
|
device: Device to use for computation ("cpu", "cuda", "auto").
|
||||||
|
device_index: Device ID to use.
|
||||||
compute_type: Type to use for computation.
|
compute_type: Type to use for computation.
|
||||||
See https://opennmt.net/CTranslate2/quantization.html.
|
See https://opennmt.net/CTranslate2/quantization.html.
|
||||||
cpu_threads: Number of threads to use when running on CPU (4 by default).
|
cpu_threads: Number of threads to use when running on CPU (4 by default).
|
||||||
@@ -59,6 +61,7 @@ class WhisperModel:
|
|||||||
self.model = ctranslate2.models.Whisper(
|
self.model = ctranslate2.models.Whisper(
|
||||||
model_path,
|
model_path,
|
||||||
device=device,
|
device=device,
|
||||||
|
device_index=device_index,
|
||||||
compute_type=compute_type,
|
compute_type=compute_type,
|
||||||
intra_threads=cpu_threads,
|
intra_threads=cpu_threads,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user