diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index cf5ece0..bc30b4c 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -93,6 +93,8 @@ class WhisperModel: num_workers: int = 1, download_root: Optional[str] = None, local_files_only: bool = False, + files: dict = None, + **model_kwargs, ): """Initializes the Whisper model. @@ -119,10 +121,18 @@ class WhisperModel: are saved in the standard Hugging Face cache directory. local_files_only: If True, avoid downloading the file and return the path to the local cached file if it exists. + files: Load model files from the memory. This argument is a dictionary mapping file names + to file contents as file-like or bytes objects. If this is set, model_path acts as an + identifier for this model. """ self.logger = get_logger() - if os.path.isdir(model_size_or_path): + tokenizer_bytes, preprocessor_bytes = None, None + if files: + model_path = model_size_or_path + tokenizer_bytes = files.pop("tokenizer.json", None) + preprocessor_bytes = files.pop("preprocessor_config.json", None) + elif os.path.isdir(model_size_or_path): model_path = model_size_or_path else: model_path = download_model( @@ -138,17 +148,20 @@ class WhisperModel: compute_type=compute_type, intra_threads=cpu_threads, inter_threads=num_workers, + files=files, + **model_kwargs, ) tokenizer_file = os.path.join(model_path, "tokenizer.json") - if os.path.isfile(tokenizer_file): + if tokenizer_bytes: + self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes) + elif os.path.isfile(tokenizer_file): self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file) else: self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained( "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en") ) - - self.feat_kwargs = self._get_feature_kwargs(model_path) + self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes) self.feature_extractor = FeatureExtractor(**self.feat_kwargs) 
def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
    """Collect the keyword arguments used to build the FeatureExtractor.

    The configuration is taken from ``preprocessor_bytes`` when it is
    provided and non-empty, otherwise from ``preprocessor_config.json``
    inside ``model_path``. Only keys that match a parameter of
    ``FeatureExtractor.__init__`` are kept.

    Args:
        model_path: Directory that may contain ``preprocessor_config.json``.
        preprocessor_bytes: Optional in-memory JSON configuration, given as
            ``str``/``bytes``/``bytearray`` or as a file-like object exposing
            ``read()``.

    Returns:
        The filtered configuration, or an empty dict when no configuration
        is available or it cannot be parsed as a JSON object (a warning is
        logged in the latter case).
    """
    # In-memory files may be handed over as file-like objects; json.loads
    # only accepts str/bytes/bytearray, so read the payload first.
    read = getattr(preprocessor_bytes, "read", None)
    if callable(read):
        preprocessor_bytes = read()

    config_path = os.path.join(model_path, "preprocessor_config.json")
    try:
        if preprocessor_bytes:
            config = json.loads(preprocessor_bytes)
        elif os.path.isfile(config_path):
            with open(config_path, "r", encoding="utf-8") as file:
                config = json.load(file)
        else:
            return {}
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # Best effort: a broken config must not prevent the model from
        # loading, the feature extractor falls back to its defaults.
        self.logger.warning("Could not load preprocessor config: %s", e)
        return {}

    if not isinstance(config, dict):
        # Valid JSON but not an object: there are no keys to filter.
        self.logger.warning(
            "Could not load preprocessor config: expected a JSON object, got %s",
            type(config).__name__,
        )
        return {}

    valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
    return {k: v for k, v in config.items() if k in valid_keys}