Make get_input a free function

This commit is contained in:
Guillaume Klein
2023-03-09 12:54:41 +01:00
parent c52adaca90
commit 3301dd9273

View File

@@ -179,7 +179,7 @@ class WhisperModel:
language_probability = 1
else:
segment = features[:, : self.feature_extractor.nb_max_frames]
input = self.get_input(segment)
input = get_input(segment)
results = self.model.detect_language(input)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
@@ -356,7 +356,7 @@ class WhisperModel:
)
def generate_with_fallback(self, segment, prompt, tokenizer, options):
features = self.get_input(segment)
features = get_input(segment)
result = None
avg_log_prob = None
final_temperature = None
@@ -448,11 +448,12 @@ class WhisperModel:
return prompt
def get_input(self, segment):
segment = np.ascontiguousarray(segment)
segment = np.expand_dims(segment, 0)
segment = ctranslate2.StorageView.from_array(segment)
return segment
def get_input(segment):
segment = np.ascontiguousarray(segment)
segment = np.expand_dims(segment, 0)
segment = ctranslate2.StorageView.from_array(segment)
return segment
def get_compression_ratio(text):