diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index de9a0f1..6291fa2 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -1,4 +1,5 @@
 import collections
+import os
 import zlib
 
 from typing import BinaryIO, List, Optional, Tuple, Union
@@ -81,10 +82,15 @@ class WhisperModel:
             inter_threads=num_workers,
         )
 
+        tokenizer_file = os.path.join(model_path, "tokenizer.json")
+        if os.path.isfile(tokenizer_file):
+            self.tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
+        else:
+            self.tokenizer = tokenizers.Tokenizer.from_pretrained(
+                "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
+            )
+
         self.feature_extractor = FeatureExtractor()
-        self.tokenizer = tokenizers.Tokenizer.from_pretrained(
-            "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
-        )
         self.eot_id = self.tokenizer.token_to_id("<|endoftext|>")
         self.timestamp_begin_id = self.tokenizer.token_to_id("<|notimestamps|>") + 1
         self.input_stride = 2