From 3301dd9273ae4720e3c7bc40ccdef038f6748da7 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Thu, 9 Mar 2023 12:54:41 +0100 Subject: [PATCH] Make get_input a free function --- faster_whisper/transcribe.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index e27d3a8..419111b 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -179,7 +179,7 @@ class WhisperModel: language_probability = 1 else: segment = features[:, : self.feature_extractor.nb_max_frames] - input = self.get_input(segment) + input = get_input(segment) results = self.model.detect_language(input) language_token, language_probability = results[0][0] language = language_token[2:-2] @@ -356,7 +356,7 @@ class WhisperModel: ) def generate_with_fallback(self, segment, prompt, tokenizer, options): - features = self.get_input(segment) + features = get_input(segment) result = None avg_log_prob = None final_temperature = None @@ -448,11 +448,12 @@ class WhisperModel: return prompt - def get_input(self, segment): - segment = np.ascontiguousarray(segment) - segment = np.expand_dims(segment, 0) - segment = ctranslate2.StorageView.from_array(segment) - return segment + +def get_input(segment): + segment = np.ascontiguousarray(segment) + segment = np.expand_dims(segment, 0) + segment = ctranslate2.StorageView.from_array(segment) + return segment def get_compression_ratio(text):