diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 90e898f..0e49429 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -43,6 +43,7 @@ class WhisperModel:
         self,
         model_path,
         device="auto",
+        device_index=0,
         compute_type="default",
         cpu_threads=0,
     ):
@@ -51,6 +52,7 @@ class WhisperModel:
         Args:
           model_path: Path to the converted model.
           device: Device to use for computation ("cpu", "cuda", "auto").
+          device_index: Device ID to use.
           compute_type: Type to use for computation.
             See https://opennmt.net/CTranslate2/quantization.html.
           cpu_threads: Number of threads to use when running on CPU (4 by default).
@@ -59,6 +61,7 @@ class WhisperModel:
         self.model = ctranslate2.models.Whisper(
             model_path,
             device=device,
+            device_index=device_index,
             compute_type=compute_type,
             intra_threads=cpu_threads,
         )