Make get_input a free function

This commit is contained in:
Guillaume Klein
2023-03-09 12:54:41 +01:00
parent c52adaca90
commit 3301dd9273

View File

@@ -179,7 +179,7 @@ class WhisperModel:
language_probability = 1 language_probability = 1
else: else:
segment = features[:, : self.feature_extractor.nb_max_frames] segment = features[:, : self.feature_extractor.nb_max_frames]
input = self.get_input(segment) input = get_input(segment)
results = self.model.detect_language(input) results = self.model.detect_language(input)
language_token, language_probability = results[0][0] language_token, language_probability = results[0][0]
language = language_token[2:-2] language = language_token[2:-2]
@@ -356,7 +356,7 @@ class WhisperModel:
) )
def generate_with_fallback(self, segment, prompt, tokenizer, options): def generate_with_fallback(self, segment, prompt, tokenizer, options):
features = self.get_input(segment) features = get_input(segment)
result = None result = None
avg_log_prob = None avg_log_prob = None
final_temperature = None final_temperature = None
@@ -448,11 +448,12 @@ class WhisperModel:
return prompt return prompt
def get_input(self, segment):
segment = np.ascontiguousarray(segment) def get_input(segment):
segment = np.expand_dims(segment, 0) segment = np.ascontiguousarray(segment)
segment = ctranslate2.StorageView.from_array(segment) segment = np.expand_dims(segment, 0)
return segment segment = ctranslate2.StorageView.from_array(segment)
return segment
def get_compression_ratio(text): def get_compression_ratio(text):