Add more typing annotations
This commit is contained in:
@@ -6,13 +6,16 @@ system dependencies. FFmpeg does not need to be installed on the system.
|
|||||||
However, the API is quite low-level so we need to manipulate audio frames directly.
|
However, the API is quite low-level so we need to manipulate audio frames directly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import av
|
|
||||||
import io
|
import io
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
|
from typing import BinaryIO, Union
|
||||||
|
|
||||||
|
import av
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def decode_audio(input_file, sampling_rate=16000):
|
def decode_audio(input_file: Union[str, BinaryIO], sampling_rate: int = 16000):
|
||||||
"""Decodes the audio.
|
"""Decodes the audio.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import itertools
|
|||||||
import os
|
import os
|
||||||
import zlib
|
import zlib
|
||||||
|
|
||||||
from typing import BinaryIO, List, Optional, Tuple, Union
|
from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
|
||||||
|
|
||||||
import ctranslate2
|
import ctranslate2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -14,46 +14,44 @@ from faster_whisper.feature_extractor import FeatureExtractor
|
|||||||
from faster_whisper.tokenizer import Tokenizer
|
from faster_whisper.tokenizer import Tokenizer
|
||||||
|
|
||||||
|
|
||||||
class Segment(collections.namedtuple("Segment", ("start", "end", "text", "words"))):
|
class Word(NamedTuple):
|
||||||
pass
|
start: float
|
||||||
|
end: float
|
||||||
|
word: str
|
||||||
|
probability: float
|
||||||
|
|
||||||
|
|
||||||
class Word(collections.namedtuple("Word", ("start", "end", "word", "probability"))):
|
class Segment(NamedTuple):
|
||||||
pass
|
start: float
|
||||||
|
end: float
|
||||||
|
text: str
|
||||||
|
words: List[Word]
|
||||||
|
|
||||||
|
|
||||||
class AudioInfo(
|
class AudioInfo(NamedTuple):
|
||||||
collections.namedtuple("AudioInfo", ("language", "language_probability"))
|
language: str
|
||||||
):
|
language_probability: float
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionOptions(
|
class TranscriptionOptions(NamedTuple):
|
||||||
collections.namedtuple(
|
beam_size: int
|
||||||
"TranscriptionOptions",
|
best_of: int
|
||||||
(
|
patience: float
|
||||||
"beam_size",
|
length_penalty: float
|
||||||
"best_of",
|
log_prob_threshold: Optional[float]
|
||||||
"patience",
|
no_speech_threshold: Optional[float]
|
||||||
"length_penalty",
|
compression_ratio_threshold: Optional[float]
|
||||||
"log_prob_threshold",
|
condition_on_previous_text: bool
|
||||||
"no_speech_threshold",
|
temperatures: List[float]
|
||||||
"compression_ratio_threshold",
|
initial_prompt: Optional[str]
|
||||||
"condition_on_previous_text",
|
prefix: Optional[str]
|
||||||
"temperatures",
|
suppress_blank: bool
|
||||||
"initial_prompt",
|
suppress_tokens: Optional[List[int]]
|
||||||
"prefix",
|
without_timestamps: bool
|
||||||
"suppress_blank",
|
max_initial_timestamp: float
|
||||||
"suppress_tokens",
|
word_timestamps: bool
|
||||||
"without_timestamps",
|
prepend_punctuations: str
|
||||||
"max_initial_timestamp",
|
append_punctuations: str
|
||||||
"word_timestamps",
|
|
||||||
"prepend_punctuations",
|
|
||||||
"append_punctuations",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class WhisperModel:
|
class WhisperModel:
|
||||||
@@ -143,7 +141,7 @@ class WhisperModel:
|
|||||||
word_timestamps: bool = False,
|
word_timestamps: bool = False,
|
||||||
prepend_punctuations: str = "\"'“¿([{-",
|
prepend_punctuations: str = "\"'“¿([{-",
|
||||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||||
):
|
) -> Tuple[Iterable[Segment], AudioInfo]:
|
||||||
"""Transcribes an input file.
|
"""Transcribes an input file.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
@@ -203,7 +201,7 @@ class WhisperModel:
|
|||||||
language_probability = 1
|
language_probability = 1
|
||||||
else:
|
else:
|
||||||
segment = features[:, : self.feature_extractor.nb_max_frames]
|
segment = features[:, : self.feature_extractor.nb_max_frames]
|
||||||
input = get_input(segment)
|
input = get_ctranslate2_storage(segment)
|
||||||
results = self.model.detect_language(input)
|
results = self.model.detect_language(input)
|
||||||
language_token, language_probability = results[0][0]
|
language_token, language_probability = results[0][0]
|
||||||
language = language_token[2:-2]
|
language = language_token[2:-2]
|
||||||
@@ -249,7 +247,12 @@ class WhisperModel:
|
|||||||
|
|
||||||
return segments, audio_info
|
return segments, audio_info
|
||||||
|
|
||||||
def generate_segments(self, features, tokenizer, options):
|
def generate_segments(
|
||||||
|
self,
|
||||||
|
features: np.ndarray,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
options: TranscriptionOptions,
|
||||||
|
) -> Iterable[Segment]:
|
||||||
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
||||||
seek = 0
|
seek = 0
|
||||||
all_tokens = []
|
all_tokens = []
|
||||||
@@ -421,8 +424,14 @@ class WhisperModel:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_with_fallback(self, segment, prompt, tokenizer, options):
|
def generate_with_fallback(
|
||||||
features = get_input(segment)
|
self,
|
||||||
|
segment: np.ndarray,
|
||||||
|
prompt: List[int],
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
options: TranscriptionOptions,
|
||||||
|
) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float]:
|
||||||
|
features = get_ctranslate2_storage(segment)
|
||||||
result = None
|
result = None
|
||||||
avg_log_prob = None
|
avg_log_prob = None
|
||||||
final_temperature = None
|
final_temperature = None
|
||||||
@@ -490,11 +499,11 @@ class WhisperModel:
|
|||||||
|
|
||||||
def get_prompt(
|
def get_prompt(
|
||||||
self,
|
self,
|
||||||
tokenizer,
|
tokenizer: Tokenizer,
|
||||||
previous_tokens,
|
previous_tokens: List[int],
|
||||||
without_timestamps=False,
|
without_timestamps: bool = False,
|
||||||
prefix=None,
|
prefix: Optional[str] = None,
|
||||||
):
|
) -> List[int]:
|
||||||
prompt = []
|
prompt = []
|
||||||
|
|
||||||
if previous_tokens:
|
if previous_tokens:
|
||||||
@@ -582,7 +591,7 @@ class WhisperModel:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
result = self.model.align(
|
result = self.model.align(
|
||||||
get_input(mel),
|
get_ctranslate2_storage(mel),
|
||||||
tokenizer.sot_sequence,
|
tokenizer.sot_sequence,
|
||||||
[text_tokens],
|
[text_tokens],
|
||||||
num_frames,
|
num_frames,
|
||||||
@@ -635,14 +644,14 @@ class WhisperModel:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_input(segment):
|
def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
|
||||||
segment = np.ascontiguousarray(segment)
|
segment = np.ascontiguousarray(segment)
|
||||||
segment = np.expand_dims(segment, 0)
|
segment = np.expand_dims(segment, 0)
|
||||||
segment = ctranslate2.StorageView.from_array(segment)
|
segment = ctranslate2.StorageView.from_array(segment)
|
||||||
return segment
|
return segment
|
||||||
|
|
||||||
|
|
||||||
def get_compression_ratio(text):
|
def get_compression_ratio(text: str) -> float:
|
||||||
text_bytes = text.encode("utf-8")
|
text_bytes = text.encode("utf-8")
|
||||||
return len(text_bytes) / len(zlib.compress(text_bytes))
|
return len(text_bytes) / len(zlib.compress(text_bytes))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user