Add typing to constructor and transcribe method
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import collections
|
||||
import zlib
|
||||
|
||||
from typing import BinaryIO, List, Optional, Tuple, Union
|
||||
|
||||
import ctranslate2
|
||||
import numpy as np
|
||||
import tokenizers
|
||||
@@ -44,12 +46,12 @@ class TranscriptionOptions(
|
||||
class WhisperModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
device="auto",
|
||||
device_index=0,
|
||||
compute_type="default",
|
||||
cpu_threads=0,
|
||||
num_workers=1,
|
||||
model_path: str,
|
||||
device: str = "auto",
|
||||
device_index: int = 0,
|
||||
compute_type: str = "default",
|
||||
cpu_threads: int = 0,
|
||||
num_workers: int = 1,
|
||||
):
|
||||
"""Initializes the Whisper model.
|
||||
|
||||
@@ -90,20 +92,27 @@ class WhisperModel:
|
||||
|
||||
def transcribe(
|
||||
self,
|
||||
input_file,
|
||||
language=None,
|
||||
task="transcribe",
|
||||
beam_size=5,
|
||||
best_of=5,
|
||||
patience=1,
|
||||
length_penalty=1,
|
||||
temperature=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
|
||||
compression_ratio_threshold=2.4,
|
||||
log_prob_threshold=-1.0,
|
||||
no_speech_threshold=0.6,
|
||||
condition_on_previous_text=True,
|
||||
initial_prompt=None,
|
||||
without_timestamps=False,
|
||||
input_file: Union[str, BinaryIO],
|
||||
language: Optional[str] = None,
|
||||
task: str = "transcribe",
|
||||
beam_size: int = 5,
|
||||
best_of: int = 5,
|
||||
patience: float = 1,
|
||||
length_penalty: float = 1,
|
||||
temperature: Union[float, List[float], Tuple[float, ...]] = [
|
||||
0.0,
|
||||
0.2,
|
||||
0.4,
|
||||
0.6,
|
||||
0.8,
|
||||
1.0,
|
||||
],
|
||||
compression_ratio_threshold: float = 2.4,
|
||||
log_prob_threshold: float = -1.0,
|
||||
no_speech_threshold: float = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
without_timestamps: bool = False,
|
||||
):
|
||||
"""Transcribes an input file.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user