Add more typing annotations
faster_whisper/audio.py
@@ -6,13 +6,16 @@ system dependencies. FFmpeg does not need to be installed on the system.
 However, the API is quite low-level so we need to manipulate audio frames directly.
 """
 
-import av
 import io
 import itertools
+
+from typing import BinaryIO, Union
+
+import av
 import numpy as np
 
 
-def decode_audio(input_file, sampling_rate=16000):
+def decode_audio(input_file: Union[str, BinaryIO], sampling_rate: int = 16000):
     """Decodes the audio.
 
     Args:
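Note on the typed decode_audio signature above: the Union[str, BinaryIO] annotation reflects that the function accepts either a file path or an open binary stream, and the 16000 default matches the existing behaviour. A minimal usage sketch, not part of this commit; the file name is a placeholder:

import io

from faster_whisper.audio import decode_audio

# Decode from a path; the result is a float32 NumPy array resampled to 16 kHz.
audio = decode_audio("speech.wav", sampling_rate=16000)

# Decode from an in-memory binary stream, matching the BinaryIO part of the Union.
with open("speech.wav", "rb") as f:
    audio_from_stream = decode_audio(io.BytesIO(f.read()))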
faster_whisper/transcribe.py
@@ -3,7 +3,7 @@ import itertools
 import os
 import zlib
 
-from typing import BinaryIO, List, Optional, Tuple, Union
+from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
 
 import ctranslate2
 import numpy as np
@@ -14,46 +14,44 @@ from faster_whisper.feature_extractor import FeatureExtractor
 from faster_whisper.tokenizer import Tokenizer
 
 
-class Segment(collections.namedtuple("Segment", ("start", "end", "text", "words"))):
-    pass
+class Word(NamedTuple):
+    start: float
+    end: float
+    word: str
+    probability: float
 
 
-class Word(collections.namedtuple("Word", ("start", "end", "word", "probability"))):
-    pass
+class Segment(NamedTuple):
+    start: float
+    end: float
+    text: str
+    words: List[Word]
 
 
-class AudioInfo(
-    collections.namedtuple("AudioInfo", ("language", "language_probability"))
-):
-    pass
+class AudioInfo(NamedTuple):
+    language: str
+    language_probability: float
 
 
-class TranscriptionOptions(
-    collections.namedtuple(
-        "TranscriptionOptions",
-        (
-            "beam_size",
-            "best_of",
-            "patience",
-            "length_penalty",
-            "log_prob_threshold",
-            "no_speech_threshold",
-            "compression_ratio_threshold",
-            "condition_on_previous_text",
-            "temperatures",
-            "initial_prompt",
-            "prefix",
-            "suppress_blank",
-            "suppress_tokens",
-            "without_timestamps",
-            "max_initial_timestamp",
-            "word_timestamps",
-            "prepend_punctuations",
-            "append_punctuations",
-        ),
-    )
-):
-    pass
+class TranscriptionOptions(NamedTuple):
+    beam_size: int
+    best_of: int
+    patience: float
+    length_penalty: float
+    log_prob_threshold: Optional[float]
+    no_speech_threshold: Optional[float]
+    compression_ratio_threshold: Optional[float]
+    condition_on_previous_text: bool
+    temperatures: List[float]
+    initial_prompt: Optional[str]
+    prefix: Optional[str]
+    suppress_blank: bool
+    suppress_tokens: Optional[List[int]]
+    without_timestamps: bool
+    max_initial_timestamp: float
+    word_timestamps: bool
+    prepend_punctuations: str
+    append_punctuations: str
 
 
 class WhisperModel:
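The classes above move from collections.namedtuple subclasses to typing.NamedTuple, which keeps the tuple behaviour while attaching an explicit type to every field. A small illustrative sketch, not part of the commit, using the Word definition introduced here:

from typing import NamedTuple


class Word(NamedTuple):
    start: float
    end: float
    word: str
    probability: float


w = Word(start=0.0, end=0.5, word="hello", probability=0.98)

# Attribute access, indexing and tuple unpacking behave exactly like a namedtuple.
assert w.word == "hello"
assert w[2] == "hello"
start, end, word, probability = w

# Type checkers and runtime introspection now see the per-field annotations.
assert Word.__annotations__["probability"] is float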
@@ -143,7 +141,7 @@ class WhisperModel:
         word_timestamps: bool = False,
         prepend_punctuations: str = "\"'“¿([{-",
         append_punctuations: str = "\"'.。,,!!??::”)]}、",
-    ):
+    ) -> Tuple[Iterable[Segment], AudioInfo]:
         """Transcribes an input file.
 
         Arguments:
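The new -> Tuple[Iterable[Segment], AudioInfo] annotation makes explicit that transcribe() returns a lazily evaluated iterable of segments together with the detected-language info. A hedged usage sketch; the model argument and audio file are placeholders, and whether a size name or a converted CTranslate2 model path is accepted depends on the installed version:

from faster_whisper import WhisperModel

model = WhisperModel("tiny")  # or a path to a converted model, depending on the version

# transcribe() returns (Iterable[Segment], AudioInfo); segments are only decoded
# while iterating, so the loop below drives the actual transcription.
segments, audio_info = model.transcribe("speech.wav", word_timestamps=True)

print(audio_info.language, audio_info.language_probability)

for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
    for word in segment.words or []:
        print("  %.2fs %s (p=%.2f)" % (word.start, word.word, word.probability))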
@@ -203,7 +201,7 @@ class WhisperModel:
             language_probability = 1
         else:
             segment = features[:, : self.feature_extractor.nb_max_frames]
-            input = get_input(segment)
+            input = get_ctranslate2_storage(segment)
             results = self.model.detect_language(input)
             language_token, language_probability = results[0][0]
             language = language_token[2:-2]
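For context on the unchanged language_token[2:-2] line: Whisper language tokens are special tokens of the form "<|xx|>", so slicing two characters off each end leaves the bare language code stored in AudioInfo.language. A tiny illustration; the token string is an assumed example of that format:

# detect_language() returns tokens such as "<|en|>"; stripping "<|" and "|>"
# leaves the language code.
language_token = "<|en|>"
language = language_token[2:-2]
assert language == "en"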
@@ -249,7 +247,12 @@ class WhisperModel:
 
         return segments, audio_info
 
-    def generate_segments(self, features, tokenizer, options):
+    def generate_segments(
+        self,
+        features: np.ndarray,
+        tokenizer: Tokenizer,
+        options: TranscriptionOptions,
+    ) -> Iterable[Segment]:
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
         seek = 0
         all_tokens = []
@@ -421,8 +424,14 @@ class WhisperModel:
             ),
         )
 
-    def generate_with_fallback(self, segment, prompt, tokenizer, options):
-        features = get_input(segment)
+    def generate_with_fallback(
+        self,
+        segment: np.ndarray,
+        prompt: List[int],
+        tokenizer: Tokenizer,
+        options: TranscriptionOptions,
+    ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
+        features = get_ctranslate2_storage(segment)
         result = None
         avg_log_prob = None
         final_temperature = None
@@ -490,11 +499,11 @@ class WhisperModel:
 
     def get_prompt(
         self,
-        tokenizer,
-        previous_tokens,
-        without_timestamps=False,
-        prefix=None,
-    ):
+        tokenizer: Tokenizer,
+        previous_tokens: List[int],
+        without_timestamps: bool = False,
+        prefix: Optional[str] = None,
+    ) -> List[int]:
         prompt = []
 
         if previous_tokens:
@@ -582,7 +591,7 @@ class WhisperModel:
             return []
 
         result = self.model.align(
-            get_input(mel),
+            get_ctranslate2_storage(mel),
             tokenizer.sot_sequence,
             [text_tokens],
             num_frames,
@@ -635,14 +644,14 @@ class WhisperModel:
         ]
 
 
-def get_input(segment):
+def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
     segment = np.ascontiguousarray(segment)
     segment = np.expand_dims(segment, 0)
     segment = ctranslate2.StorageView.from_array(segment)
     return segment
 
 
-def get_compression_ratio(text):
+def get_compression_ratio(text: str) -> float:
     text_bytes = text.encode("utf-8")
     return len(text_bytes) / len(zlib.compress(text_bytes))
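The typed get_compression_ratio() above is the helper behind compression_ratio_threshold in TranscriptionOptions: repetitive, degenerate output compresses very well under zlib and therefore yields a high ratio. A self-contained sketch of the same computation; 2.4 is mentioned only as the commonly used default threshold:

import zlib


def get_compression_ratio(text: str) -> float:
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))


normal = "The quick brown fox jumps over the lazy dog."
repetitive = "buy now " * 50

print(get_compression_ratio(normal))      # close to 1: barely compressible
print(get_compression_ratio(repetitive))  # far above 2.4, the usual threshold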