diff --git a/faster_whisper/audio.py b/faster_whisper/audio.py index 0b7dfae..8d176d7 100644 --- a/faster_whisper/audio.py +++ b/faster_whisper/audio.py @@ -6,13 +6,16 @@ system dependencies. FFmpeg does not need to be installed on the system. However, the API is quite low-level so we need to manipulate audio frames directly. """ -import av import io import itertools + +from typing import BinaryIO, Union + +import av import numpy as np -def decode_audio(input_file, sampling_rate=16000): +def decode_audio(input_file: Union[str, BinaryIO], sampling_rate: int = 16000): """Decodes the audio. Args: diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index dba3402..25ef989 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -3,7 +3,7 @@ import itertools import os import zlib -from typing import BinaryIO, List, Optional, Tuple, Union +from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union import ctranslate2 import numpy as np @@ -14,46 +14,44 @@ from faster_whisper.feature_extractor import FeatureExtractor from faster_whisper.tokenizer import Tokenizer -class Segment(collections.namedtuple("Segment", ("start", "end", "text", "words"))): - pass +class Word(NamedTuple): + start: float + end: float + word: str + probability: float -class Word(collections.namedtuple("Word", ("start", "end", "word", "probability"))): - pass +class Segment(NamedTuple): + start: float + end: float + text: str + words: List[Word] -class AudioInfo( - collections.namedtuple("AudioInfo", ("language", "language_probability")) -): - pass +class AudioInfo(NamedTuple): + language: str + language_probability: float -class TranscriptionOptions( - collections.namedtuple( - "TranscriptionOptions", - ( - "beam_size", - "best_of", - "patience", - "length_penalty", - "log_prob_threshold", - "no_speech_threshold", - "compression_ratio_threshold", - "condition_on_previous_text", - "temperatures", - "initial_prompt", - "prefix", - "suppress_blank", - "suppress_tokens", - "without_timestamps", - "max_initial_timestamp", - "word_timestamps", - "prepend_punctuations", - "append_punctuations", - ), - ) -): - pass +class TranscriptionOptions(NamedTuple): + beam_size: int + best_of: int + patience: float + length_penalty: float + log_prob_threshold: Optional[float] + no_speech_threshold: Optional[float] + compression_ratio_threshold: Optional[float] + condition_on_previous_text: bool + temperatures: List[float] + initial_prompt: Optional[str] + prefix: Optional[str] + suppress_blank: bool + suppress_tokens: Optional[List[int]] + without_timestamps: bool + max_initial_timestamp: float + word_timestamps: bool + prepend_punctuations: str + append_punctuations: str class WhisperModel: @@ -143,7 +141,7 @@ class WhisperModel: word_timestamps: bool = False, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", - ): + ) -> Tuple[Iterable[Segment], AudioInfo]: """Transcribes an input file. Arguments: @@ -203,7 +201,7 @@ class WhisperModel: language_probability = 1 else: segment = features[:, : self.feature_extractor.nb_max_frames] - input = get_input(segment) + input = get_ctranslate2_storage(segment) results = self.model.detect_language(input) language_token, language_probability = results[0][0] language = language_token[2:-2] @@ -249,7 +247,12 @@ class WhisperModel: return segments, audio_info - def generate_segments(self, features, tokenizer, options): + def generate_segments( + self, + features: np.ndarray, + tokenizer: Tokenizer, + options: TranscriptionOptions, + ) -> Iterable[Segment]: content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames seek = 0 all_tokens = [] @@ -421,8 +424,14 @@ class WhisperModel: ), ) - def generate_with_fallback(self, segment, prompt, tokenizer, options): - features = get_input(segment) + def generate_with_fallback( + self, + segment: np.ndarray, + prompt: List[int], + tokenizer: Tokenizer, + options: TranscriptionOptions, + ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]: + features = get_ctranslate2_storage(segment) result = None avg_log_prob = None final_temperature = None @@ -490,11 +499,11 @@ class WhisperModel: def get_prompt( self, - tokenizer, - previous_tokens, - without_timestamps=False, - prefix=None, - ): + tokenizer: Tokenizer, + previous_tokens: List[int], + without_timestamps: bool = False, + prefix: Optional[str] = None, + ) -> List[int]: prompt = [] if previous_tokens: @@ -582,7 +591,7 @@ class WhisperModel: return [] result = self.model.align( - get_input(mel), + get_ctranslate2_storage(mel), tokenizer.sot_sequence, [text_tokens], num_frames, @@ -635,14 +644,14 @@ class WhisperModel: ] -def get_input(segment): +def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView: segment = np.ascontiguousarray(segment) segment = np.expand_dims(segment, 0) segment = ctranslate2.StorageView.from_array(segment) return segment -def get_compression_ratio(text): +def get_compression_ratio(text: str) -> float: text_bytes = text.encode("utf-8") return len(text_bytes) / len(zlib.compress(text_bytes))