Add more typing annotations

Guillaume Klein
2023-03-15 15:22:53 +01:00
parent 8bd013ea99
commit eafb2c79a3
2 changed files with 62 additions and 50 deletions


@@ -6,13 +6,16 @@ system dependencies. FFmpeg does not need to be installed on the system.
 However, the API is quite low-level so we need to manipulate audio frames directly.
 """

-import av
 import io
 import itertools
+
+from typing import BinaryIO, Union
+
+import av
 import numpy as np


-def decode_audio(input_file, sampling_rate=16000):
+def decode_audio(input_file: Union[str, BinaryIO], sampling_rate: int = 16000):
     """Decodes the audio.

     Args:
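The new Union[str, BinaryIO] annotation makes explicit that decode_audio accepts either a file path or an already opened binary stream. A minimal usage sketch, assuming the import path faster_whisper.audio and a placeholder file name:

    # Sketch only: "speech.wav" is a placeholder, import path assumed.
    from faster_whisper.audio import decode_audio

    waveform = decode_audio("speech.wav", sampling_rate=16000)  # str path

    with open("speech.wav", "rb") as audio_file:  # binary file-like object
        waveform = decode_audio(audio_file, sampling_rate=16000)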


@@ -3,7 +3,7 @@ import itertools
 import os
 import zlib

-from typing import BinaryIO, List, Optional, Tuple, Union
+from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union

 import ctranslate2
 import numpy as np
@@ -14,46 +14,44 @@ from faster_whisper.feature_extractor import FeatureExtractor
 from faster_whisper.tokenizer import Tokenizer


-class Segment(collections.namedtuple("Segment", ("start", "end", "text", "words"))):
-    pass
+class Word(NamedTuple):
+    start: float
+    end: float
+    word: str
+    probability: float


-class Word(collections.namedtuple("Word", ("start", "end", "word", "probability"))):
-    pass
+class Segment(NamedTuple):
+    start: float
+    end: float
+    text: str
+    words: List[Word]


-class AudioInfo(
-    collections.namedtuple("AudioInfo", ("language", "language_probability"))
-):
-    pass
+class AudioInfo(NamedTuple):
+    language: str
+    language_probability: float


-class TranscriptionOptions(
-    collections.namedtuple(
-        "TranscriptionOptions",
-        (
-            "beam_size",
-            "best_of",
-            "patience",
-            "length_penalty",
-            "log_prob_threshold",
-            "no_speech_threshold",
-            "compression_ratio_threshold",
-            "condition_on_previous_text",
-            "temperatures",
-            "initial_prompt",
-            "prefix",
-            "suppress_blank",
-            "suppress_tokens",
-            "without_timestamps",
-            "max_initial_timestamp",
-            "word_timestamps",
-            "prepend_punctuations",
-            "append_punctuations",
-        ),
-    )
-):
-    pass
+class TranscriptionOptions(NamedTuple):
+    beam_size: int
+    best_of: int
+    patience: float
+    length_penalty: float
+    log_prob_threshold: Optional[float]
+    no_speech_threshold: Optional[float]
+    compression_ratio_threshold: Optional[float]
+    condition_on_previous_text: bool
+    temperatures: List[float]
+    initial_prompt: Optional[str]
+    prefix: Optional[str]
+    suppress_blank: bool
+    suppress_tokens: Optional[List[int]]
+    without_timestamps: bool
+    max_initial_timestamp: float
+    word_timestamps: bool
+    prepend_punctuations: str
+    append_punctuations: str


 class WhisperModel:
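Replacing the collections.namedtuple subclasses with typing.NamedTuple keeps the same tuple behaviour while declaring a type for each field. A short sketch with made-up values, assuming Word and Segment are in scope:

    # Sketch only: the values below are illustrative, not taken from the diff.
    word = Word(start=0.0, end=0.4, word=" Hello", probability=0.98)
    segment = Segment(start=0.0, end=1.2, text=" Hello world", words=[word])

    start, end, text, words = segment     # still unpacks like a plain tuple
    print(segment.words[0].probability)   # field access is now visible to type checkers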
@@ -143,7 +141,7 @@ class WhisperModel:
         word_timestamps: bool = False,
         prepend_punctuations: str = "\"'“¿([{-",
         append_punctuations: str = "\"'.。,!?::”)]}、",
-    ):
+    ) -> Tuple[Iterable[Segment], AudioInfo]:
        """Transcribes an input file.

        Arguments:
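The added return annotation documents that transcribe produces a lazily consumed iterable of Segment values together with an AudioInfo. A hedged usage sketch, where the model path and audio file are placeholders:

    # Sketch only: "path/to/model" and "speech.wav" are placeholders.
    model = WhisperModel("path/to/model")

    segments, audio_info = model.transcribe("speech.wav", beam_size=5)
    print(audio_info.language, audio_info.language_probability)

    for segment in segments:  # Iterable[Segment]; decoding happens as you iterate
        print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))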
@@ -203,7 +201,7 @@ class WhisperModel:
                 language_probability = 1
             else:
                 segment = features[:, : self.feature_extractor.nb_max_frames]
-                input = get_input(segment)
+                input = get_ctranslate2_storage(segment)
                 results = self.model.detect_language(input)
                 language_token, language_probability = results[0][0]
                 language = language_token[2:-2]
@@ -249,7 +247,12 @@ class WhisperModel:
         return segments, audio_info

-    def generate_segments(self, features, tokenizer, options):
+    def generate_segments(
+        self,
+        features: np.ndarray,
+        tokenizer: Tokenizer,
+        options: TranscriptionOptions,
+    ) -> Iterable[Segment]:
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames

         seek = 0
         all_tokens = []
@@ -421,8 +424,14 @@ class WhisperModel:
                ),
            )

-    def generate_with_fallback(self, segment, prompt, tokenizer, options):
-        features = get_input(segment)
+    def generate_with_fallback(
+        self,
+        segment: np.ndarray,
+        prompt: List[int],
+        tokenizer: Tokenizer,
+        options: TranscriptionOptions,
+    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
+        features = get_ctranslate2_storage(segment)
         result = None
         avg_log_prob = None
         final_temperature = None
@@ -490,11 +499,11 @@ class WhisperModel:
     def get_prompt(
         self,
-        tokenizer,
-        previous_tokens,
-        without_timestamps=False,
-        prefix=None,
-    ):
+        tokenizer: Tokenizer,
+        previous_tokens: List[int],
+        without_timestamps: bool = False,
+        prefix: Optional[str] = None,
+    ) -> List[int]:
         prompt = []

         if previous_tokens:
@@ -582,7 +591,7 @@ class WhisperModel:
            return []

        result = self.model.align(
-            get_input(mel),
+            get_ctranslate2_storage(mel),
            tokenizer.sot_sequence,
            [text_tokens],
            num_frames,
@@ -635,14 +644,14 @@ class WhisperModel:
        ]


-def get_input(segment):
+def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
     segment = np.ascontiguousarray(segment)
     segment = np.expand_dims(segment, 0)
     segment = ctranslate2.StorageView.from_array(segment)
     return segment


-def get_compression_ratio(text):
+def get_compression_ratio(text: str) -> float:
     text_bytes = text.encode("utf-8")
     return len(text_bytes) / len(zlib.compress(text_bytes))
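For reference, a sketch of what the two renamed and annotated module-level helpers compute, assuming the module's own imports (numpy as np, ctranslate2) are in scope and using illustrative inputs:

    # Sketch only: the mel array and the text below are illustrative.
    mel = np.zeros((80, 3000), dtype=np.float32)   # fake log-Mel features
    storage = get_ctranslate2_storage(mel)         # adds a batch axis, wraps as StorageView
    print(storage.shape)                           # batch axis added: (1, 80, 3000)

    repetitive = "the quick brown fox " * 20
    print(get_compression_ratio(repetitive))       # highly repetitive text gives a ratio well above 1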