Add more typing annotations

This commit is contained in:
Guillaume Klein
2023-03-15 15:22:53 +01:00
parent 8bd013ea99
commit eafb2c79a3
2 changed files with 62 additions and 50 deletions

View File

@@ -6,13 +6,16 @@ system dependencies. FFmpeg does not need to be installed on the system.
However, the API is quite low-level so we need to manipulate audio frames directly.
""" """
import av
import io import io
import itertools import itertools
from typing import BinaryIO, Union
import av
import numpy as np import numpy as np
def decode_audio(input_file, sampling_rate=16000): def decode_audio(input_file: Union[str, BinaryIO], sampling_rate: int = 16000):
"""Decodes the audio. """Decodes the audio.
Args: Args:

View File

@@ -3,7 +3,7 @@ import itertools
import os import os
import zlib import zlib
from typing import BinaryIO, List, Optional, Tuple, Union from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
import ctranslate2 import ctranslate2
import numpy as np import numpy as np
@@ -14,46 +14,44 @@ from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import Tokenizer from faster_whisper.tokenizer import Tokenizer
class Segment(collections.namedtuple("Segment", ("start", "end", "text", "words"))): class Word(NamedTuple):
pass start: float
end: float
word: str
probability: float
class Word(collections.namedtuple("Word", ("start", "end", "word", "probability"))): class Segment(NamedTuple):
pass start: float
end: float
text: str
words: List[Word]
class AudioInfo( class AudioInfo(NamedTuple):
collections.namedtuple("AudioInfo", ("language", "language_probability")) language: str
): language_probability: float
pass
class TranscriptionOptions( class TranscriptionOptions(NamedTuple):
collections.namedtuple( beam_size: int
"TranscriptionOptions", best_of: int
( patience: float
"beam_size", length_penalty: float
"best_of", log_prob_threshold: Optional[float]
"patience", no_speech_threshold: Optional[float]
"length_penalty", compression_ratio_threshold: Optional[float]
"log_prob_threshold", condition_on_previous_text: bool
"no_speech_threshold", temperatures: List[float]
"compression_ratio_threshold", initial_prompt: Optional[str]
"condition_on_previous_text", prefix: Optional[str]
"temperatures", suppress_blank: bool
"initial_prompt", suppress_tokens: Optional[List[int]]
"prefix", without_timestamps: bool
"suppress_blank", max_initial_timestamp: float
"suppress_tokens", word_timestamps: bool
"without_timestamps", prepend_punctuations: str
"max_initial_timestamp", append_punctuations: str
"word_timestamps",
"prepend_punctuations",
"append_punctuations",
),
)
):
pass
class WhisperModel: class WhisperModel:
@@ -143,7 +141,7 @@ class WhisperModel:
word_timestamps: bool = False, word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-", prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、", append_punctuations: str = "\"'.。,!?::”)]}、",
): ) -> Tuple[Iterable[Segment], AudioInfo]:
"""Transcribes an input file. """Transcribes an input file.
Arguments: Arguments:
@@ -203,7 +201,7 @@ class WhisperModel:
language_probability = 1 language_probability = 1
else: else:
segment = features[:, : self.feature_extractor.nb_max_frames] segment = features[:, : self.feature_extractor.nb_max_frames]
input = get_input(segment) input = get_ctranslate2_storage(segment)
results = self.model.detect_language(input) results = self.model.detect_language(input)
language_token, language_probability = results[0][0] language_token, language_probability = results[0][0]
language = language_token[2:-2] language = language_token[2:-2]
@@ -249,7 +247,12 @@ class WhisperModel:
return segments, audio_info return segments, audio_info
def generate_segments(self, features, tokenizer, options): def generate_segments(
self,
features: np.ndarray,
tokenizer: Tokenizer,
options: TranscriptionOptions,
) -> Iterable[Segment]:
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
seek = 0 seek = 0
all_tokens = [] all_tokens = []
@@ -421,8 +424,14 @@ class WhisperModel:
), ),
) )
def generate_with_fallback(self, segment, prompt, tokenizer, options): def generate_with_fallback(
features = get_input(segment) self,
segment: np.ndarray,
prompt: List[int],
tokenizer: Tokenizer,
options: TranscriptionOptions,
) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
features = get_ctranslate2_storage(segment)
result = None result = None
avg_log_prob = None avg_log_prob = None
final_temperature = None final_temperature = None
@@ -490,11 +499,11 @@ class WhisperModel:
def get_prompt( def get_prompt(
self, self,
tokenizer, tokenizer: Tokenizer,
previous_tokens, previous_tokens: List[int],
without_timestamps=False, without_timestamps: bool = False,
prefix=None, prefix: Optional[str] = None,
): ) -> List[int]:
prompt = [] prompt = []
if previous_tokens: if previous_tokens:
@@ -582,7 +591,7 @@ class WhisperModel:
return [] return []
result = self.model.align( result = self.model.align(
get_input(mel), get_ctranslate2_storage(mel),
tokenizer.sot_sequence, tokenizer.sot_sequence,
[text_tokens], [text_tokens],
num_frames, num_frames,
@@ -635,14 +644,14 @@ class WhisperModel:
] ]
def get_input(segment): def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
segment = np.ascontiguousarray(segment) segment = np.ascontiguousarray(segment)
segment = np.expand_dims(segment, 0) segment = np.expand_dims(segment, 0)
segment = ctranslate2.StorageView.from_array(segment) segment = ctranslate2.StorageView.from_array(segment)
return segment return segment
def get_compression_ratio(text): def get_compression_ratio(text: str) -> float:
text_bytes = text.encode("utf-8") text_bytes = text.encode("utf-8")
return len(text_bytes) / len(zlib.compress(text_bytes)) return len(text_bytes) / len(zlib.compress(text_bytes))