Make get_input a free function
This commit is contained in:
@@ -179,7 +179,7 @@ class WhisperModel:
|
||||
language_probability = 1
|
||||
else:
|
||||
segment = features[:, : self.feature_extractor.nb_max_frames]
|
||||
input = self.get_input(segment)
|
||||
input = get_input(segment)
|
||||
results = self.model.detect_language(input)
|
||||
language_token, language_probability = results[0][0]
|
||||
language = language_token[2:-2]
|
||||
@@ -356,7 +356,7 @@ class WhisperModel:
|
||||
)
|
||||
|
||||
def generate_with_fallback(self, segment, prompt, tokenizer, options):
|
||||
features = self.get_input(segment)
|
||||
features = get_input(segment)
|
||||
result = None
|
||||
avg_log_prob = None
|
||||
final_temperature = None
|
||||
@@ -448,11 +448,12 @@ class WhisperModel:
|
||||
|
||||
return prompt
|
||||
|
||||
def get_input(self, segment):
|
||||
segment = np.ascontiguousarray(segment)
|
||||
segment = np.expand_dims(segment, 0)
|
||||
segment = ctranslate2.StorageView.from_array(segment)
|
||||
return segment
|
||||
|
||||
def get_input(segment):
|
||||
segment = np.ascontiguousarray(segment)
|
||||
segment = np.expand_dims(segment, 0)
|
||||
segment = ctranslate2.StorageView.from_array(segment)
|
||||
return segment
|
||||
|
||||
|
||||
def get_compression_ratio(text):
|
||||
|
||||
Reference in New Issue
Block a user