Support initializing more whisper model args (#807)
This commit is contained in:
@@ -93,6 +93,8 @@ class WhisperModel:
|
|||||||
num_workers: int = 1,
|
num_workers: int = 1,
|
||||||
download_root: Optional[str] = None,
|
download_root: Optional[str] = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
|
files: dict = None,
|
||||||
|
**model_kwargs,
|
||||||
):
|
):
|
||||||
"""Initializes the Whisper model.
|
"""Initializes the Whisper model.
|
||||||
|
|
||||||
@@ -119,10 +121,18 @@ class WhisperModel:
|
|||||||
are saved in the standard Hugging Face cache directory.
|
are saved in the standard Hugging Face cache directory.
|
||||||
local_files_only: If True, avoid downloading the file and return the path to the
|
local_files_only: If True, avoid downloading the file and return the path to the
|
||||||
local cached file if it exists.
|
local cached file if it exists.
|
||||||
|
files: Load model files from the memory. This argument is a dictionary mapping file names
|
||||||
|
to file contents as file-like or bytes objects. If this is set, model_path acts as an
|
||||||
|
identifier for this model.
|
||||||
"""
|
"""
|
||||||
self.logger = get_logger()
|
self.logger = get_logger()
|
||||||
|
|
||||||
if os.path.isdir(model_size_or_path):
|
tokenizer_bytes, preprocessor_bytes = None, None
|
||||||
|
if files:
|
||||||
|
model_path = model_size_or_path
|
||||||
|
tokenizer_bytes = files.pop("tokenizer.json", None)
|
||||||
|
preprocessor_bytes = files.pop("preprocessor_config.json", None)
|
||||||
|
elif os.path.isdir(model_size_or_path):
|
||||||
model_path = model_size_or_path
|
model_path = model_size_or_path
|
||||||
else:
|
else:
|
||||||
model_path = download_model(
|
model_path = download_model(
|
||||||
@@ -138,17 +148,20 @@ class WhisperModel:
|
|||||||
compute_type=compute_type,
|
compute_type=compute_type,
|
||||||
intra_threads=cpu_threads,
|
intra_threads=cpu_threads,
|
||||||
inter_threads=num_workers,
|
inter_threads=num_workers,
|
||||||
|
files=files,
|
||||||
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
tokenizer_file = os.path.join(model_path, "tokenizer.json")
|
tokenizer_file = os.path.join(model_path, "tokenizer.json")
|
||||||
if os.path.isfile(tokenizer_file):
|
if tokenizer_bytes:
|
||||||
|
self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes)
|
||||||
|
elif os.path.isfile(tokenizer_file):
|
||||||
self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
|
self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
|
||||||
else:
|
else:
|
||||||
self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
|
self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
|
||||||
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
|
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
|
||||||
)
|
)
|
||||||
|
self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes)
|
||||||
self.feat_kwargs = self._get_feature_kwargs(model_path)
|
|
||||||
self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
|
self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
|
||||||
self.num_samples_per_token = self.feature_extractor.hop_length * 2
|
self.num_samples_per_token = self.feature_extractor.hop_length * 2
|
||||||
self.frames_per_second = (
|
self.frames_per_second = (
|
||||||
@@ -166,19 +179,21 @@ class WhisperModel:
|
|||||||
"""The languages supported by the model."""
|
"""The languages supported by the model."""
|
||||||
return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
|
return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
|
||||||
|
|
||||||
def _get_feature_kwargs(self, model_path) -> dict:
|
def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
|
||||||
preprocessor_config_file = os.path.join(model_path, "preprocessor_config.json")
|
|
||||||
config = {}
|
config = {}
|
||||||
if os.path.isfile(preprocessor_config_file):
|
try:
|
||||||
try:
|
config_path = os.path.join(model_path, "preprocessor_config.json")
|
||||||
with open(preprocessor_config_file, "r", encoding="utf-8") as json_file:
|
if preprocessor_bytes:
|
||||||
config = json.load(json_file)
|
config = json.loads(preprocessor_bytes)
|
||||||
valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
|
elif os.path.isfile(config_path):
|
||||||
config = {k: v for k, v in config.items() if k in valid_keys}
|
with open(config_path, "r", encoding="utf-8") as file:
|
||||||
except json.JSONDecodeError as e:
|
config = json.load(file)
|
||||||
self.logger.warning(
|
else:
|
||||||
"Could not load preprocessor_config.json: %s", str(e)
|
return config
|
||||||
)
|
valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
|
||||||
|
return {k: v for k, v in config.items() if k in valid_keys}
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
self.logger.warning("Could not load preprocessor config: %s", e)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user