Merge remote-tracking branch 'upstream/master' into prompt

2024-07-10 10:16:35 +08:00
19 changed files with 2255 additions and 84 deletions


@@ -69,6 +69,7 @@ class TranscriptionOptions(NamedTuple):
max_new_tokens: Optional[int]
clip_timestamps: Union[str, List[float]]
hallucination_silence_threshold: Optional[float]
hotwords: Optional[str]
class TranscriptionInfo(NamedTuple):
@@ -92,12 +93,15 @@ class WhisperModel:
num_workers: int = 1,
download_root: Optional[str] = None,
local_files_only: bool = False,
files: dict = None,
**model_kwargs,
):
"""Initializes the Whisper model.
Args:
model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
small, small.en, medium, medium.en, large-v1, large-v2, large-v3, or large), a path to a
small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1,
large-v2, large-v3, large, distil-large-v2 or distil-large-v3), a path to a
converted model directory, or a CTranslate2-converted Whisper model ID from the HF Hub.
When a size or a model ID is configured, the converted model is downloaded
from the Hugging Face Hub.
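For reference, any of the names listed above can be passed directly as model_size_or_path. A minimal usage sketch; the device and compute_type values are illustrative, not required:

from faster_whisper import WhisperModel

# "distil-large-v3" is one of the newly documented sizes; device and
# compute_type are illustrative choices.
model = WhisperModel("distil-large-v3", device="cpu", compute_type="int8")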
@@ -118,10 +122,18 @@ class WhisperModel:
are saved in the standard Hugging Face cache directory.
local_files_only: If True, avoid downloading the file and return the path to the
local cached file if it exists.
files: Load model files from memory. This argument is a dictionary mapping file names
to file contents as file-like or bytes objects. If this is set, model_path acts as an
identifier for this model.
"""
self.logger = get_logger()
if os.path.isdir(model_size_or_path):
tokenizer_bytes, preprocessor_bytes = None, None
if files:
model_path = model_size_or_path
tokenizer_bytes = files.pop("tokenizer.json", None)
preprocessor_bytes = files.pop("preprocessor_config.json", None)
elif os.path.isdir(model_size_or_path):
model_path = model_size_or_path
else:
model_path = download_model(
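A minimal sketch of the in-memory loading path described by the files argument above. Only tokenizer.json and preprocessor_config.json are named in the diff; the other file names below are assumptions about what a CTranslate2-converted model directory contains.

from faster_whisper import WhisperModel

def read_bytes(path):
    with open(path, "rb") as f:
        return f.read()

# tokenizer.json and preprocessor_config.json are consumed by WhisperModel itself
# (see the code above); the remaining entries are forwarded to CTranslate2.
# model.bin, config.json and vocabulary.json are assumed names, not taken from the diff.
files = {
    "model.bin": read_bytes("converted-model/model.bin"),
    "config.json": read_bytes("converted-model/config.json"),
    "vocabulary.json": read_bytes("converted-model/vocabulary.json"),
    "tokenizer.json": read_bytes("converted-model/tokenizer.json"),
    "preprocessor_config.json": read_bytes("converted-model/preprocessor_config.json"),
}

# With files set, the first argument acts only as an identifier for this model.
model = WhisperModel("my-in-memory-model", files=files)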
@@ -137,17 +149,20 @@ class WhisperModel:
compute_type=compute_type,
intra_threads=cpu_threads,
inter_threads=num_workers,
files=files,
**model_kwargs,
)
tokenizer_file = os.path.join(model_path, "tokenizer.json")
if os.path.isfile(tokenizer_file):
if tokenizer_bytes:
self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes)
elif os.path.isfile(tokenizer_file):
self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
else:
self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
)
self.feat_kwargs = self._get_feature_kwargs(model_path)
self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes)
self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
self.num_samples_per_token = self.feature_extractor.hop_length * 2
self.frames_per_second = (
@@ -165,19 +180,21 @@ class WhisperModel:
"""The languages supported by the model."""
return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
def _get_feature_kwargs(self, model_path) -> dict:
preprocessor_config_file = os.path.join(model_path, "preprocessor_config.json")
def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
config = {}
if os.path.isfile(preprocessor_config_file):
try:
with open(preprocessor_config_file, "r", encoding="utf-8") as json_file:
config = json.load(json_file)
valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
config = {k: v for k, v in config.items() if k in valid_keys}
except json.JSONDecodeError as e:
self.logger.warning(
"Could not load preprocessor_config.json: %s", str(e)
)
try:
config_path = os.path.join(model_path, "preprocessor_config.json")
if preprocessor_bytes:
config = json.loads(preprocessor_bytes)
elif os.path.isfile(config_path):
with open(config_path, "r", encoding="utf-8") as file:
config = json.load(file)
else:
return config
valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
return {k: v for k, v in config.items() if k in valid_keys}
except json.JSONDecodeError as e:
self.logger.warning("Could not load preprocessor config: %s", e)
return config
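The rewritten helper keeps only the keys that FeatureExtractor.__init__ actually accepts. A standalone sketch of that filtering pattern, using a simplified stand-in for FeatureExtractor and a made-up config:

import json
from inspect import signature

class FeatureExtractor:
    # Simplified stand-in; only here to illustrate the signature-based filtering.
    def __init__(self, feature_size=80, sampling_rate=16000, hop_length=160):
        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.hop_length = hop_length

raw = json.loads(
    '{"feature_size": 128, "sampling_rate": 16000, "processor_class": "WhisperProcessor"}'
)
valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
kwargs = {k: v for k, v in raw.items() if k in valid_keys}

# kwargs == {"feature_size": 128, "sampling_rate": 16000}; unknown keys such as
# "processor_class" are dropped before FeatureExtractor(**kwargs) is constructed.
feature_extractor = FeatureExtractor(**kwargs)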
@@ -220,6 +237,7 @@ class WhisperModel:
chunk_length: Optional[int] = None,
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
hotwords: Optional[str] = None,
language_detection_threshold: Optional[float] = None,
language_detection_segments: int = 1,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
@@ -259,7 +277,7 @@ class WhisperModel:
prefix: Optional text to provide as a prefix for the first window.
suppress_blank: Suppress blank outputs at the beginning of the sampling.
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
of symbols as defined in the model config.json file.
of symbols as defined in `tokenizer.non_speech_tokens()`
without_timestamps: Only sample text tokens.
max_initial_timestamp: The initial timestamp cannot be later than this.
word_timestamps: Extract word-level timestamps using the cross-attention pattern
@@ -277,17 +295,18 @@ class WhisperModel:
the maximum will be set by the default max_length.
chunk_length: The length of audio segments. If it is not None, it will overwrite the
default chunk_length of the FeatureExtractor.
clip_timestamps: Union[str, List[float]]
clip_timestamps:
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
process. The last end timestamp defaults to the end of the file.
vad_filter will be ignored if clip_timestamps is used.
hallucination_silence_threshold: Optional[float]
hallucination_silence_threshold:
When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected
hotwords:
Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
language_detection_threshold: If the maximum probability of the language tokens is higher
than this value, the language is detected.
language_detection_segments: Number of segments to consider for the language detection.
Returns:
A tuple with:
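A short usage sketch of the new hotwords argument alongside the language detection parameters; the audio file, hotword string and threshold values are placeholders:

from faster_whisper import WhisperModel

model = WhisperModel("large-v3")

# hotwords only take effect when no prefix is given (see the docstring above);
# the values below are illustrative.
segments, info = model.transcribe(
    "meeting.wav",
    hotwords="ZyntriQix",
    language_detection_threshold=0.5,
    language_detection_segments=2,
)

for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")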
@@ -351,16 +370,27 @@ class WhisperModel:
or language_detection_segments < 1
):
language_detection_segments = 1
seek = 0
detected_language_info = {}
start_timestamp = (
float(clip_timestamps.split(",")[0])
if isinstance(clip_timestamps, str)
else clip_timestamps[0]
)
content_frames = (
features.shape[-1] - self.feature_extractor.nb_max_frames
)
while (
seek <= content_frames
and seek
< self.feature_extractor.nb_max_frames * language_detection_segments
):
seek = (
int(start_timestamp * self.frames_per_second)
if start_timestamp * self.frames_per_second < content_frames
else 0
)
end_frames = min(
seek
+ self.feature_extractor.nb_max_frames
* language_detection_segments,
content_frames,
)
detected_language_info = {}
while seek <= end_frames:
segment = features[
:, seek : seek + self.feature_extractor.nb_max_frames
]
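A small worked sketch of the new seek arithmetic, assuming the default feature extractor's 100 feature frames per second and 3000-frame (30 s) windows; the remaining numbers are made up:

frames_per_second = 100          # assumed: 16000 Hz sample rate / 160-sample hop
nb_max_frames = 3000             # assumed: one 30 s window
language_detection_segments = 2

content_frames = 12000           # made-up number of usable feature frames
start_timestamp = 45.0           # first value taken from clip_timestamps

seek = int(start_timestamp * frames_per_second)  # 4500, still inside the file
if seek >= content_frames:
    seek = 0                     # fall back to the start of the audio

end_frames = min(seek + nb_max_frames * language_detection_segments, content_frames)

# Language detection now scans 30 s windows from frame 4500 up to frame 10500
# instead of always starting at frame 0.
print(seek, end_frames)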
@@ -432,7 +462,11 @@ class WhisperModel:
initial_prompt=initial_prompt,
prefix=prefix,
suppress_blank=suppress_blank,
suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
suppress_tokens=(
get_suppressed_tokens(tokenizer, suppress_tokens)
if suppress_tokens
else suppress_tokens
),
without_timestamps=without_timestamps,
max_initial_timestamp=max_initial_timestamp,
word_timestamps=word_timestamps,
@@ -441,6 +475,7 @@ class WhisperModel:
max_new_tokens=max_new_tokens,
clip_timestamps=clip_timestamps,
hallucination_silence_threshold=hallucination_silence_threshold,
hotwords=hotwords,
)
segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -457,7 +492,6 @@ class WhisperModel:
vad_options=vad_parameters,
all_language_probs=all_language_probs,
)
return segments, info
def generate_segments(
@@ -471,14 +505,16 @@ class WhisperModel:
content_duration = float(content_frames * self.feature_extractor.time_per_frame)
if isinstance(options.clip_timestamps, str):
TranscriptionOptions.clip_timestamps = [
float(ts)
for ts in (
options.clip_timestamps.split(",")
if options.clip_timestamps
else []
)
]
options = options._replace(
clip_timestamps=[
float(ts)
for ts in (
options.clip_timestamps.split(",")
if options.clip_timestamps
else []
)
]
)
seek_points: List[int] = [
round(ts * self.frames_per_second) for ts in options.clip_timestamps
]
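The string form of clip_timestamps is now normalized through NamedTuple._replace instead of assigning the parsed list to the TranscriptionOptions class attribute, which leaked state onto the class rather than updating just this options instance. A minimal sketch of the pattern on a reduced stand-in:

from typing import List, NamedTuple, Union

class Options(NamedTuple):
    # Stand-in for TranscriptionOptions, reduced to the one relevant field.
    clip_timestamps: Union[str, List[float]]

options = Options(clip_timestamps="5.0,30.0")

# _replace returns a new instance; the fields of a NamedTuple instance are
# immutable, so reassignment is the idiomatic way to swap the value.
options = options._replace(
    clip_timestamps=[float(ts) for ts in options.clip_timestamps.split(",")]
)

print(options.clip_timestamps)  # [5.0, 30.0]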
@@ -548,6 +584,7 @@ class WhisperModel:
previous_tokens,
without_timestamps=options.without_timestamps,
prefix=options.prefix if seek == 0 else None,
hotwords=options.hotwords,
)
if seek > 0 or encoder_output is None:
@@ -948,12 +985,19 @@ class WhisperModel:
previous_tokens: List[int],
without_timestamps: bool = False,
prefix: Optional[str] = None,
hotwords: Optional[str] = None,
) -> List[int]:
prompt = []
if previous_tokens:
if previous_tokens or (hotwords and not prefix):
prompt.append(tokenizer.sot_prev)
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
if hotwords and not prefix:
hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
if len(hotwords_tokens) >= self.max_length // 2:
hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
prompt.extend(hotwords_tokens)
if previous_tokens:
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
prompt.extend(tokenizer.sot_sequence)
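A sketch of the prompt layout get_prompt now assembles when hotwords are set and no prefix is given; all token IDs below are made up, the real values come from the Whisper tokenizer:

# Illustrative only: shows the order of the pieces, not real token IDs.
sot_prev = 50361                        # assumed ID for <|startofprev|>
hotwords_tokens = [220, 57, 88]         # stand-in for tokenizer.encode(" ZyntriQix")
previous_tokens = [1012, 2013, 3014]    # tail of the previous segment's tokens
sot_sequence = [50258, 50259, 50359]    # stand-in for tokenizer.sot_sequence

prompt = [sot_prev] + hotwords_tokens + previous_tokens + sot_sequence

# The hotwords are injected into the <|startofprev|> context ahead of the
# previous tokens, so they bias decoding without being emitted in the output.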
@@ -1195,15 +1239,16 @@ def get_compression_ratio(text: str) -> float:
def get_suppressed_tokens(
tokenizer: Tokenizer,
suppress_tokens: Optional[List[int]],
suppress_tokens: Tuple[int],
) -> Optional[List[int]]:
if not suppress_tokens or -1 in suppress_tokens:
return suppress_tokens
if -1 in suppress_tokens:
suppress_tokens = [t for t in suppress_tokens if t >= 0]
suppress_tokens.extend(tokenizer.non_speech_tokens)
elif suppress_tokens is None or len(suppress_tokens) == 0:
suppress_tokens = [] # interpret empty string as an empty list
else:
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
suppress_tokens = list(suppress_tokens)
# Ensure the following special tokens are suppressed when the user does
# not use the default set (-1).
suppress_tokens.extend(
[
tokenizer.transcribe,
@@ -1214,7 +1259,7 @@ def get_suppressed_tokens(
]
)
return sorted(set(suppress_tokens))
return tuple(sorted(set(suppress_tokens)))
def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
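For the get_suppressed_tokens change above, a condensed sketch of what the rewritten helper returns for common inputs. The tokenizer is a stand-in with made-up IDs, and the real function appends further special tokens that are elided from the hunk shown here:

from typing import Tuple

class FakeTokenizer:
    # Minimal stand-in for the real Tokenizer; IDs are made up.
    transcribe = 50359

    @property
    def non_speech_tokens(self):
        return [1, 2, 7]  # made-up "symbol" token IDs

def get_suppressed(tokenizer, suppress_tokens) -> Tuple[int, ...]:
    # Condensed from the diff above; not the full implementation.
    if -1 in suppress_tokens:
        suppress_tokens = [t for t in suppress_tokens if t >= 0]
        suppress_tokens.extend(tokenizer.non_speech_tokens)
    elif suppress_tokens is None or len(suppress_tokens) == 0:
        suppress_tokens = []
    else:
        suppress_tokens = list(suppress_tokens)
    suppress_tokens.append(tokenizer.transcribe)  # the real list is longer
    return tuple(sorted(set(suppress_tokens)))

print(get_suppressed(FakeTokenizer(), (-1,)))  # (1, 2, 7, 50359)
print(get_suppressed(FakeTokenizer(), ()))     # (50359,)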