support distil-whisper (#557)
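For context, a minimal loading sketch of what this change enables. The model path is an assumption: any CTranslate2 conversion of a Distil-Whisper checkpoint (e.g. distil-whisper/distil-large-v2) should work the same way.

    from faster_whisper import WhisperModel

    # Assumed path to a local CTranslate2 conversion of distil-whisper/distil-large-v2.
    model = WhisperModel(
        "path/to/ct2-distil-large-v2",
        device="cuda",
        compute_type="float16",
    )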
@@ -66,6 +66,7 @@ class TranscriptionOptions(NamedTuple):
     word_timestamps: bool
     prepend_punctuations: str
     append_punctuations: str
+    max_new_tokens: Optional[int]


 class TranscriptionInfo(NamedTuple):
@@ -213,6 +214,8 @@ class WhisperModel:
         append_punctuations: str = "\"'.。,,!!??::”)]}、",
         vad_filter: bool = False,
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
+        max_new_tokens: Optional[int] = None,
+        chunk_length: Optional[int] = None,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """Transcribes an input file.

@@ -264,6 +267,10 @@ class WhisperModel:
             https://github.com/snakers4/silero-vad.
           vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
             parameters and default values in the class `VadOptions`).
+          max_new_tokens: Maximum number of new tokens to generate per chunk. If not set, the
+            maximum is determined by the model's default max_length.
+          chunk_length: The length of audio segments in seconds. If not None, it overrides the
+            default chunk_length of the FeatureExtractor.

         Returns:
           A tuple with:
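A usage sketch of the two new arguments; the file name and values are illustrative, and the 15-second chunk follows the Distil-Whisper authors' recommendation for chunked long-form decoding.

    # Both arguments are optional and default to None.
    segments, info = model.transcribe(
        "audio.mp3",
        language="en",       # distil-large-v2 is English-only
        chunk_length=15,     # override the FeatureExtractor's default 30 s window
        max_new_tokens=128,  # cap the tokens generated per chunk
    )
    for segment in segments:
        print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))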
@@ -313,7 +320,7 @@ class WhisperModel:
         else:
             speech_chunks = None

-        features = self.feature_extractor(audio)
+        features = self.feature_extractor(audio, chunk_length=chunk_length)

         encoder_output = None
         all_language_probs = None
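For intuition, the chunk_length override flows straight into the mel feature computation. A sketch of the frame arithmetic, assuming the standard Whisper front-end constants (16 kHz sampling rate, 160-sample hop):

    SAMPLING_RATE = 16000  # assumed Whisper input rate
    HOP_LENGTH = 160       # assumed STFT hop; each hop yields one mel frame

    def frames_for_chunk(chunk_length_s: int) -> int:
        return chunk_length_s * SAMPLING_RATE // HOP_LENGTH

    print(frames_for_chunk(30))  # 3000 frames: the default Whisper window
    print(frames_for_chunk(15))  # 1500 frames: a shorter window for distilled models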
@@ -379,6 +386,7 @@ class WhisperModel:
             word_timestamps=word_timestamps,
             prepend_punctuations=prepend_punctuations,
             append_punctuations=append_punctuations,
+            max_new_tokens=max_new_tokens,
         )

         segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -642,6 +650,21 @@ class WhisperModel:
         max_initial_timestamp_index = int(
             round(options.max_initial_timestamp / self.time_precision)
         )
+        if options.max_new_tokens is not None:
+            max_length = len(prompt) + options.max_new_tokens
+        else:
+            max_length = self.max_length
+
+        if max_length > self.max_length:
+            raise ValueError(
+                f"The length of the prompt is {len(prompt)}, and `max_new_tokens` "
+                f"is {max_length - len(prompt)}. Thus, the combined length of the "
+                f"prompt and `max_new_tokens` is {max_length}. This exceeds the "
+                f"`max_length` of the Whisper model: {self.max_length}. "
+                "You should either reduce the length of your prompt, or "
+                "reduce the value of `max_new_tokens`, "
+                f"so that their combined length is less than {self.max_length}."
+            )

         for temperature in options.temperatures:
             if temperature > 0:
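To make the new guard concrete, a worked check with illustrative numbers; 448 is the usual Whisper decoder limit that faster-whisper stores as self.max_length.

    model_max_length = 448  # typical Whisper decoder max_length
    prompt_len = 32         # tokens in the prompt (illustrative)
    max_new_tokens = 440    # requested generation budget (illustrative)

    max_length = prompt_len + max_new_tokens  # 32 + 440 = 472
    # 472 > 448, so transcribe() would raise the ValueError above;
    # shrink the prompt or lower max_new_tokens to stay within the limit.
    assert max_length > model_max_length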
@@ -663,7 +686,7 @@ class WhisperModel:
|
||||
length_penalty=options.length_penalty,
|
||||
repetition_penalty=options.repetition_penalty,
|
||||
no_repeat_ngram_size=options.no_repeat_ngram_size,
|
||||
max_length=self.max_length,
|
||||
max_length=max_length,
|
||||
return_scores=True,
|
||||
return_no_speech_prob=True,
|
||||
suppress_blank=options.suppress_blank,
|
||||
|
||||