Make get_input a free function
This commit is contained in:
@@ -179,7 +179,7 @@ class WhisperModel:
|
|||||||
language_probability = 1
|
language_probability = 1
|
||||||
else:
|
else:
|
||||||
segment = features[:, : self.feature_extractor.nb_max_frames]
|
segment = features[:, : self.feature_extractor.nb_max_frames]
|
||||||
input = self.get_input(segment)
|
input = get_input(segment)
|
||||||
results = self.model.detect_language(input)
|
results = self.model.detect_language(input)
|
||||||
language_token, language_probability = results[0][0]
|
language_token, language_probability = results[0][0]
|
||||||
language = language_token[2:-2]
|
language = language_token[2:-2]
|
||||||
@@ -356,7 +356,7 @@ class WhisperModel:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def generate_with_fallback(self, segment, prompt, tokenizer, options):
|
def generate_with_fallback(self, segment, prompt, tokenizer, options):
|
||||||
features = self.get_input(segment)
|
features = get_input(segment)
|
||||||
result = None
|
result = None
|
||||||
avg_log_prob = None
|
avg_log_prob = None
|
||||||
final_temperature = None
|
final_temperature = None
|
||||||
@@ -448,11 +448,12 @@ class WhisperModel:
|
|||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def get_input(self, segment):
|
|
||||||
segment = np.ascontiguousarray(segment)
|
def get_input(segment):
|
||||||
segment = np.expand_dims(segment, 0)
|
segment = np.ascontiguousarray(segment)
|
||||||
segment = ctranslate2.StorageView.from_array(segment)
|
segment = np.expand_dims(segment, 0)
|
||||||
return segment
|
segment = ctranslate2.StorageView.from_array(segment)
|
||||||
|
return segment
|
||||||
|
|
||||||
|
|
||||||
def get_compression_ratio(text):
|
def get_compression_ratio(text):
|
||||||
|
|||||||
Reference in New Issue
Block a user