Merge remote-tracking branch 'upstream/master' into prompt
This commit is contained in:
Binary file not shown.
@@ -105,6 +105,42 @@ class Tokenizer:
|
||||
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def non_speech_tokens(self) -> Tuple[int]:
|
||||
"""
|
||||
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
||||
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
|
||||
|
||||
- ♪♪♪
|
||||
- ( SPEAKING FOREIGN LANGUAGE )
|
||||
- [DAVID] Hey there,
|
||||
|
||||
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||
"""
|
||||
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
||||
symbols += (
|
||||
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||
)
|
||||
|
||||
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
|
||||
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
|
||||
miscellaneous = set("♩♪♫♬♭♮♯")
|
||||
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||
|
||||
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||
result = {self.encode(" -")[0], self.encode(" '")[0]}
|
||||
for symbol in symbols + list(miscellaneous):
|
||||
for tokens in [
|
||||
self.encode(symbol),
|
||||
self.encode(" " + symbol),
|
||||
]:
|
||||
if len(tokens) == 1 or symbol in miscellaneous:
|
||||
result.add(tokens[0])
|
||||
|
||||
return tuple(sorted(result))
|
||||
|
||||
def split_to_word_tokens(
|
||||
self, tokens: List[int]
|
||||
) -> Tuple[List[str], List[List[int]]]:
|
||||
|
||||
@@ -69,6 +69,7 @@ class TranscriptionOptions(NamedTuple):
|
||||
max_new_tokens: Optional[int]
|
||||
clip_timestamps: Union[str, List[float]]
|
||||
hallucination_silence_threshold: Optional[float]
|
||||
hotwords: Optional[str]
|
||||
|
||||
|
||||
class TranscriptionInfo(NamedTuple):
|
||||
@@ -92,12 +93,15 @@ class WhisperModel:
|
||||
num_workers: int = 1,
|
||||
download_root: Optional[str] = None,
|
||||
local_files_only: bool = False,
|
||||
files: dict = None,
|
||||
**model_kwargs,
|
||||
):
|
||||
"""Initializes the Whisper model.
|
||||
|
||||
Args:
|
||||
model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
|
||||
small, small.en, medium, medium.en, large-v1, large-v2, large-v3, or large), a path to a
|
||||
small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1,
|
||||
large-v2, large-v3, large, distil-large-v2 or distil-large-v3), a path to a
|
||||
converted model directory, or a CTranslate2-converted Whisper model ID from the HF Hub.
|
||||
When a size or a model ID is configured, the converted model is downloaded
|
||||
from the Hugging Face Hub.
|
||||
@@ -118,10 +122,18 @@ class WhisperModel:
|
||||
are saved in the standard Hugging Face cache directory.
|
||||
local_files_only: If True, avoid downloading the file and return the path to the
|
||||
local cached file if it exists.
|
||||
files: Load model files from the memory. This argument is a dictionary mapping file names
|
||||
to file contents as file-like or bytes objects. If this is set, model_path acts as an
|
||||
identifier for this model.
|
||||
"""
|
||||
self.logger = get_logger()
|
||||
|
||||
if os.path.isdir(model_size_or_path):
|
||||
tokenizer_bytes, preprocessor_bytes = None, None
|
||||
if files:
|
||||
model_path = model_size_or_path
|
||||
tokenizer_bytes = files.pop("tokenizer.json", None)
|
||||
preprocessor_bytes = files.pop("preprocessor_config.json", None)
|
||||
elif os.path.isdir(model_size_or_path):
|
||||
model_path = model_size_or_path
|
||||
else:
|
||||
model_path = download_model(
|
||||
@@ -137,17 +149,20 @@ class WhisperModel:
|
||||
compute_type=compute_type,
|
||||
intra_threads=cpu_threads,
|
||||
inter_threads=num_workers,
|
||||
files=files,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
tokenizer_file = os.path.join(model_path, "tokenizer.json")
|
||||
if os.path.isfile(tokenizer_file):
|
||||
if tokenizer_bytes:
|
||||
self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes)
|
||||
elif os.path.isfile(tokenizer_file):
|
||||
self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
|
||||
else:
|
||||
self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
|
||||
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
|
||||
)
|
||||
|
||||
self.feat_kwargs = self._get_feature_kwargs(model_path)
|
||||
self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes)
|
||||
self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
|
||||
self.num_samples_per_token = self.feature_extractor.hop_length * 2
|
||||
self.frames_per_second = (
|
||||
@@ -165,19 +180,21 @@ class WhisperModel:
|
||||
"""The languages supported by the model."""
|
||||
return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
|
||||
|
||||
def _get_feature_kwargs(self, model_path) -> dict:
|
||||
preprocessor_config_file = os.path.join(model_path, "preprocessor_config.json")
|
||||
def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
|
||||
config = {}
|
||||
if os.path.isfile(preprocessor_config_file):
|
||||
try:
|
||||
with open(preprocessor_config_file, "r", encoding="utf-8") as json_file:
|
||||
config = json.load(json_file)
|
||||
valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
|
||||
config = {k: v for k, v in config.items() if k in valid_keys}
|
||||
except json.JSONDecodeError as e:
|
||||
self.logger.warning(
|
||||
"Could not load preprocessor_config.json: %s", str(e)
|
||||
)
|
||||
try:
|
||||
config_path = os.path.join(model_path, "preprocessor_config.json")
|
||||
if preprocessor_bytes:
|
||||
config = json.loads(preprocessor_bytes)
|
||||
elif os.path.isfile(config_path):
|
||||
with open(config_path, "r", encoding="utf-8") as file:
|
||||
config = json.load(file)
|
||||
else:
|
||||
return config
|
||||
valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
|
||||
return {k: v for k, v in config.items() if k in valid_keys}
|
||||
except json.JSONDecodeError as e:
|
||||
self.logger.warning("Could not load preprocessor config: %s", e)
|
||||
|
||||
return config
|
||||
|
||||
@@ -220,6 +237,7 @@ class WhisperModel:
|
||||
chunk_length: Optional[int] = None,
|
||||
clip_timestamps: Union[str, List[float]] = "0",
|
||||
hallucination_silence_threshold: Optional[float] = None,
|
||||
hotwords: Optional[str] = None,
|
||||
language_detection_threshold: Optional[float] = None,
|
||||
language_detection_segments: int = 1,
|
||||
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
|
||||
@@ -259,7 +277,7 @@ class WhisperModel:
|
||||
prefix: Optional text to provide as a prefix for the first window.
|
||||
suppress_blank: Suppress blank outputs at the beginning of the sampling.
|
||||
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
|
||||
of symbols as defined in the model config.json file.
|
||||
of symbols as defined in `tokenizer.non_speech_tokens()`
|
||||
without_timestamps: Only sample text tokens.
|
||||
max_initial_timestamp: The initial timestamp cannot be later than this.
|
||||
word_timestamps: Extract word-level timestamps using the cross-attention pattern
|
||||
@@ -277,17 +295,18 @@ class WhisperModel:
|
||||
the maximum will be set by the default max_length.
|
||||
chunk_length: The length of audio segments. If it is not None, it will overwrite the
|
||||
default chunk_length of the FeatureExtractor.
|
||||
clip_timestamps: Union[str, List[float]]
|
||||
clip_timestamps:
|
||||
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
|
||||
process. The last end timestamp defaults to the end of the file.
|
||||
vad_filter will be ignored if clip_timestamps is used.
|
||||
hallucination_silence_threshold: Optional[float]
|
||||
hallucination_silence_threshold:
|
||||
When word_timestamps is True, skip silent periods longer than this threshold
|
||||
(in seconds) when a possible hallucination is detected
|
||||
hotwords:
|
||||
Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
|
||||
language_detection_threshold: If the maximum probability of the language tokens is higher
|
||||
than this value, the language is detected.
|
||||
language_detection_segments: Number of segments to consider for the language detection.
|
||||
|
||||
Returns:
|
||||
A tuple with:
|
||||
|
||||
@@ -351,16 +370,27 @@ class WhisperModel:
|
||||
or language_detection_segments < 1
|
||||
):
|
||||
language_detection_segments = 1
|
||||
seek = 0
|
||||
detected_language_info = {}
|
||||
start_timestamp = (
|
||||
float(clip_timestamps.split(",")[0])
|
||||
if isinstance(clip_timestamps, str)
|
||||
else clip_timestamps[0]
|
||||
)
|
||||
content_frames = (
|
||||
features.shape[-1] - self.feature_extractor.nb_max_frames
|
||||
)
|
||||
while (
|
||||
seek <= content_frames
|
||||
and seek
|
||||
< self.feature_extractor.nb_max_frames * language_detection_segments
|
||||
):
|
||||
seek = (
|
||||
int(start_timestamp * self.frames_per_second)
|
||||
if start_timestamp * self.frames_per_second < content_frames
|
||||
else 0
|
||||
)
|
||||
end_frames = min(
|
||||
seek
|
||||
+ self.feature_extractor.nb_max_frames
|
||||
* language_detection_segments,
|
||||
content_frames,
|
||||
)
|
||||
detected_language_info = {}
|
||||
while seek <= end_frames:
|
||||
segment = features[
|
||||
:, seek : seek + self.feature_extractor.nb_max_frames
|
||||
]
|
||||
@@ -432,7 +462,11 @@ class WhisperModel:
|
||||
initial_prompt=initial_prompt,
|
||||
prefix=prefix,
|
||||
suppress_blank=suppress_blank,
|
||||
suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
|
||||
suppress_tokens=(
|
||||
get_suppressed_tokens(tokenizer, suppress_tokens)
|
||||
if suppress_tokens
|
||||
else suppress_tokens
|
||||
),
|
||||
without_timestamps=without_timestamps,
|
||||
max_initial_timestamp=max_initial_timestamp,
|
||||
word_timestamps=word_timestamps,
|
||||
@@ -441,6 +475,7 @@ class WhisperModel:
|
||||
max_new_tokens=max_new_tokens,
|
||||
clip_timestamps=clip_timestamps,
|
||||
hallucination_silence_threshold=hallucination_silence_threshold,
|
||||
hotwords=hotwords,
|
||||
)
|
||||
|
||||
segments = self.generate_segments(features, tokenizer, options, encoder_output)
|
||||
@@ -457,7 +492,6 @@ class WhisperModel:
|
||||
vad_options=vad_parameters,
|
||||
all_language_probs=all_language_probs,
|
||||
)
|
||||
|
||||
return segments, info
|
||||
|
||||
def generate_segments(
|
||||
@@ -471,14 +505,16 @@ class WhisperModel:
|
||||
content_duration = float(content_frames * self.feature_extractor.time_per_frame)
|
||||
|
||||
if isinstance(options.clip_timestamps, str):
|
||||
TranscriptionOptions.clip_timestamps = [
|
||||
float(ts)
|
||||
for ts in (
|
||||
options.clip_timestamps.split(",")
|
||||
if options.clip_timestamps
|
||||
else []
|
||||
)
|
||||
]
|
||||
options = options._replace(
|
||||
clip_timestamps=[
|
||||
float(ts)
|
||||
for ts in (
|
||||
options.clip_timestamps.split(",")
|
||||
if options.clip_timestamps
|
||||
else []
|
||||
)
|
||||
]
|
||||
)
|
||||
seek_points: List[int] = [
|
||||
round(ts * self.frames_per_second) for ts in options.clip_timestamps
|
||||
]
|
||||
@@ -548,6 +584,7 @@ class WhisperModel:
|
||||
previous_tokens,
|
||||
without_timestamps=options.without_timestamps,
|
||||
prefix=options.prefix if seek == 0 else None,
|
||||
hotwords=options.hotwords,
|
||||
)
|
||||
|
||||
if seek > 0 or encoder_output is None:
|
||||
@@ -948,12 +985,19 @@ class WhisperModel:
|
||||
previous_tokens: List[int],
|
||||
without_timestamps: bool = False,
|
||||
prefix: Optional[str] = None,
|
||||
hotwords: Optional[str] = None,
|
||||
) -> List[int]:
|
||||
prompt = []
|
||||
|
||||
if previous_tokens:
|
||||
if previous_tokens or (hotwords and not prefix):
|
||||
prompt.append(tokenizer.sot_prev)
|
||||
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
|
||||
if hotwords and not prefix:
|
||||
hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
|
||||
if len(hotwords_tokens) >= self.max_length // 2:
|
||||
hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
|
||||
prompt.extend(hotwords_tokens)
|
||||
if previous_tokens:
|
||||
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
|
||||
|
||||
prompt.extend(tokenizer.sot_sequence)
|
||||
|
||||
@@ -1195,15 +1239,16 @@ def get_compression_ratio(text: str) -> float:
|
||||
|
||||
def get_suppressed_tokens(
|
||||
tokenizer: Tokenizer,
|
||||
suppress_tokens: Optional[List[int]],
|
||||
suppress_tokens: Tuple[int],
|
||||
) -> Optional[List[int]]:
|
||||
if not suppress_tokens or -1 in suppress_tokens:
|
||||
return suppress_tokens
|
||||
if -1 in suppress_tokens:
|
||||
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
||||
suppress_tokens.extend(tokenizer.non_speech_tokens)
|
||||
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
||||
suppress_tokens = [] # interpret empty string as an empty list
|
||||
else:
|
||||
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
||||
|
||||
suppress_tokens = list(suppress_tokens)
|
||||
|
||||
# Ensure the following special tokens are suppressed when the user does
|
||||
# not use the default set (-1).
|
||||
suppress_tokens.extend(
|
||||
[
|
||||
tokenizer.transcribe,
|
||||
@@ -1214,7 +1259,7 @@ def get_suppressed_tokens(
|
||||
]
|
||||
)
|
||||
|
||||
return sorted(set(suppress_tokens))
|
||||
return tuple(sorted(set(suppress_tokens)))
|
||||
|
||||
|
||||
def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
|
||||
|
||||
@@ -54,8 +54,9 @@ def download_model(
|
||||
|
||||
Args:
|
||||
size_or_id: Size of the model to download from https://huggingface.co/Systran
|
||||
(tiny, tiny.en, base, base.en, small, small.en medium, medium.en, large-v1, large-v2,
|
||||
large-v3, large), or a CTranslate2-converted model ID from the Hugging Face Hub
|
||||
(tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en,
|
||||
distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2,
|
||||
distil-large-v3), or a CTranslate2-converted model ID from the Hugging Face Hub
|
||||
(e.g. Systran/faster-whisper-large-v3).
|
||||
output_dir: Directory where the model should be saved. If not set, the model is saved in
|
||||
the cache directory.
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import bisect
|
||||
import functools
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
@@ -25,9 +24,6 @@ class VadOptions(NamedTuple):
|
||||
split aggressively just before max_speech_duration_s.
|
||||
min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
|
||||
before separating it
|
||||
window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
|
||||
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
|
||||
Values other than these may affect model performance!!
|
||||
speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
|
||||
"""
|
||||
|
||||
@@ -35,7 +31,6 @@ class VadOptions(NamedTuple):
|
||||
min_speech_duration_ms: int = 250
|
||||
max_speech_duration_s: float = float("inf")
|
||||
min_silence_duration_ms: int = 2000
|
||||
window_size_samples: int = 1024
|
||||
speech_pad_ms: int = 400
|
||||
|
||||
|
||||
@@ -61,15 +56,8 @@ def get_speech_timestamps(
|
||||
min_speech_duration_ms = vad_options.min_speech_duration_ms
|
||||
max_speech_duration_s = vad_options.max_speech_duration_s
|
||||
min_silence_duration_ms = vad_options.min_silence_duration_ms
|
||||
window_size_samples = vad_options.window_size_samples
|
||||
window_size_samples = 512
|
||||
speech_pad_ms = vad_options.speech_pad_ms
|
||||
|
||||
if window_size_samples not in [512, 1024, 1536]:
|
||||
warnings.warn(
|
||||
"Unusual window_size_samples! Supported window_size_samples:\n"
|
||||
" - [512, 1024, 1536] for 16000 sampling_rate"
|
||||
)
|
||||
|
||||
sampling_rate = 16000
|
||||
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
|
||||
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
||||
@@ -84,14 +72,14 @@ def get_speech_timestamps(
|
||||
audio_length_samples = len(audio)
|
||||
|
||||
model = get_vad_model()
|
||||
state = model.get_initial_state(batch_size=1)
|
||||
state, context = model.get_initial_states(batch_size=1)
|
||||
|
||||
speech_probs = []
|
||||
for current_start_sample in range(0, audio_length_samples, window_size_samples):
|
||||
chunk = audio[current_start_sample : current_start_sample + window_size_samples]
|
||||
if len(chunk) < window_size_samples:
|
||||
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
||||
speech_prob, state = model(chunk, state, sampling_rate)
|
||||
speech_prob, state, context = model(chunk, state, context, sampling_rate)
|
||||
speech_probs.append(speech_prob)
|
||||
|
||||
triggered = False
|
||||
@@ -261,12 +249,12 @@ class SileroVADModel:
|
||||
sess_options=opts,
|
||||
)
|
||||
|
||||
def get_initial_state(self, batch_size: int):
|
||||
h = np.zeros((2, batch_size, 64), dtype=np.float32)
|
||||
c = np.zeros((2, batch_size, 64), dtype=np.float32)
|
||||
return h, c
|
||||
def get_initial_states(self, batch_size: int):
|
||||
state = np.zeros((2, batch_size, 128), dtype=np.float32)
|
||||
context = np.zeros((batch_size, 64), dtype=np.float32)
|
||||
return state, context
|
||||
|
||||
def __call__(self, x, state, sr: int):
|
||||
def __call__(self, x, state, context, sr: int):
|
||||
if len(x.shape) == 1:
|
||||
x = np.expand_dims(x, 0)
|
||||
if len(x.shape) > 2:
|
||||
@@ -276,16 +264,15 @@ class SileroVADModel:
|
||||
if sr / x.shape[1] > 31.25:
|
||||
raise ValueError("Input audio chunk is too short")
|
||||
|
||||
h, c = state
|
||||
x = np.concatenate([context, x], axis=1)
|
||||
|
||||
ort_inputs = {
|
||||
"input": x,
|
||||
"h": h,
|
||||
"c": c,
|
||||
"state": state,
|
||||
"sr": np.array(sr, dtype="int64"),
|
||||
}
|
||||
|
||||
out, h, c = self.session.run(None, ort_inputs)
|
||||
state = (h, c)
|
||||
out, state = self.session.run(None, ort_inputs)
|
||||
context = x[..., -64:]
|
||||
|
||||
return out, state
|
||||
return out, state, context
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""Version information."""
|
||||
|
||||
__version__ = "1.0.1"
|
||||
__version__ = "1.0.3"
|
||||
|
||||
Reference in New Issue
Block a user