Add typing to constructor and transcribe method

This commit is contained in:
Guillaume Klein
2023-02-27 11:22:02 +01:00
parent b1c69927f8
commit f0add58bdc

View File

@@ -1,6 +1,8 @@
import collections import collections
import zlib import zlib
from typing import BinaryIO, List, Optional, Tuple, Union
import ctranslate2 import ctranslate2
import numpy as np import numpy as np
import tokenizers import tokenizers
@@ -44,12 +46,12 @@ class TranscriptionOptions(
class WhisperModel: class WhisperModel:
def __init__( def __init__(
self, self,
model_path, model_path: str,
device="auto", device: str = "auto",
device_index=0, device_index: int = 0,
compute_type="default", compute_type: str = "default",
cpu_threads=0, cpu_threads: int = 0,
num_workers=1, num_workers: int = 1,
): ):
"""Initializes the Whisper model. """Initializes the Whisper model.
@@ -90,20 +92,27 @@ class WhisperModel:
def transcribe( def transcribe(
self, self,
input_file, input_file: Union[str, BinaryIO],
language=None, language: Optional[str] = None,
task="transcribe", task: str = "transcribe",
beam_size=5, beam_size: int = 5,
best_of=5, best_of: int = 5,
patience=1, patience: float = 1,
length_penalty=1, length_penalty: float = 1,
temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0], temperature: Union[float, List[float], Tuple[float, ...]] = [
compression_ratio_threshold=2.4, 0.0,
log_prob_threshold=-1.0, 0.2,
no_speech_threshold=0.6, 0.4,
condition_on_previous_text=True, 0.6,
initial_prompt=None, 0.8,
without_timestamps=False, 1.0,
],
compression_ratio_threshold: float = 2.4,
log_prob_threshold: float = -1.0,
no_speech_threshold: float = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
without_timestamps: bool = False,
): ):
"""Transcribes an input file. """Transcribes an input file.