Support initializing more whisper model args (#807)

This commit is contained in:
trungkienbkhn
2024-05-04 15:12:59 +07:00
committed by GitHub
parent 6eec07739e
commit 8d5e6d56d9

View File

@@ -93,6 +93,8 @@ class WhisperModel:
num_workers: int = 1, num_workers: int = 1,
download_root: Optional[str] = None, download_root: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,
files: dict = None,
**model_kwargs,
): ):
"""Initializes the Whisper model. """Initializes the Whisper model.
@@ -119,10 +121,18 @@ class WhisperModel:
are saved in the standard Hugging Face cache directory. are saved in the standard Hugging Face cache directory.
local_files_only: If True, avoid downloading the file and return the path to the local_files_only: If True, avoid downloading the file and return the path to the
local cached file if it exists. local cached file if it exists.
files: Load model files from the memory. This argument is a dictionary mapping file names
to file contents as file-like or bytes objects. If this is set, model_path acts as an
identifier for this model.
""" """
self.logger = get_logger() self.logger = get_logger()
if os.path.isdir(model_size_or_path): tokenizer_bytes, preprocessor_bytes = None, None
if files:
model_path = model_size_or_path
tokenizer_bytes = files.pop("tokenizer.json", None)
preprocessor_bytes = files.pop("preprocessor_config.json", None)
elif os.path.isdir(model_size_or_path):
model_path = model_size_or_path model_path = model_size_or_path
else: else:
model_path = download_model( model_path = download_model(
@@ -138,17 +148,20 @@ class WhisperModel:
compute_type=compute_type, compute_type=compute_type,
intra_threads=cpu_threads, intra_threads=cpu_threads,
inter_threads=num_workers, inter_threads=num_workers,
files=files,
**model_kwargs,
) )
tokenizer_file = os.path.join(model_path, "tokenizer.json") tokenizer_file = os.path.join(model_path, "tokenizer.json")
if os.path.isfile(tokenizer_file): if tokenizer_bytes:
self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes)
elif os.path.isfile(tokenizer_file):
self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file) self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
else: else:
self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained( self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en") "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
) )
self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes)
self.feat_kwargs = self._get_feature_kwargs(model_path)
self.feature_extractor = FeatureExtractor(**self.feat_kwargs) self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
self.num_samples_per_token = self.feature_extractor.hop_length * 2 self.num_samples_per_token = self.feature_extractor.hop_length * 2
self.frames_per_second = ( self.frames_per_second = (
@@ -166,19 +179,21 @@ class WhisperModel:
"""The languages supported by the model.""" """The languages supported by the model."""
return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"] return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
def _get_feature_kwargs(self, model_path) -> dict: def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
preprocessor_config_file = os.path.join(model_path, "preprocessor_config.json")
config = {} config = {}
if os.path.isfile(preprocessor_config_file): try:
try: config_path = os.path.join(model_path, "preprocessor_config.json")
with open(preprocessor_config_file, "r", encoding="utf-8") as json_file: if preprocessor_bytes:
config = json.load(json_file) config = json.loads(preprocessor_bytes)
valid_keys = signature(FeatureExtractor.__init__).parameters.keys() elif os.path.isfile(config_path):
config = {k: v for k, v in config.items() if k in valid_keys} with open(config_path, "r", encoding="utf-8") as file:
except json.JSONDecodeError as e: config = json.load(file)
self.logger.warning( else:
"Could not load preprocessor_config.json: %s", str(e) return config
) valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
return {k: v for k, v in config.items() if k in valid_keys}
except json.JSONDecodeError as e:
self.logger.warning("Could not load preprocessor config: %s", e)
return config return config