Add typing to constructor and transcribe method

This commit is contained in:
Guillaume Klein
2023-02-27 11:22:02 +01:00
parent b1c69927f8
commit f0add58bdc

View File

@@ -1,6 +1,8 @@
import collections
import zlib
from typing import BinaryIO, List, Optional, Tuple, Union
import ctranslate2
import numpy as np
import tokenizers
@@ -44,12 +46,12 @@ class TranscriptionOptions(
class WhisperModel:
def __init__(
self,
model_path,
device="auto",
device_index=0,
compute_type="default",
cpu_threads=0,
num_workers=1,
model_path: str,
device: str = "auto",
device_index: int = 0,
compute_type: str = "default",
cpu_threads: int = 0,
num_workers: int = 1,
):
"""Initializes the Whisper model.
@@ -90,20 +92,27 @@ class WhisperModel:
def transcribe(
self,
input_file,
language=None,
task="transcribe",
beam_size=5,
best_of=5,
patience=1,
length_penalty=1,
temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
compression_ratio_threshold=2.4,
log_prob_threshold=-1.0,
no_speech_threshold=0.6,
condition_on_previous_text=True,
initial_prompt=None,
without_timestamps=False,
input_file: Union[str, BinaryIO],
language: Optional[str] = None,
task: str = "transcribe",
beam_size: int = 5,
best_of: int = 5,
patience: float = 1,
length_penalty: float = 1,
temperature: Union[float, List[float], Tuple[float, ...]] = [
0.0,
0.2,
0.4,
0.6,
0.8,
1.0,
],
compression_ratio_threshold: float = 2.4,
log_prob_threshold: float = -1.0,
no_speech_threshold: float = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
without_timestamps: bool = False,
):
"""Transcribes an input file.