Add without_timestamps parameter

Guillaume Klein
2023-02-13 21:22:05 +01:00
parent 5e938cba4e
commit f56dfc6491


@@ -33,6 +33,7 @@ class TranscriptionOptions(
"condition_on_previous_text", "condition_on_previous_text",
"temperatures", "temperatures",
"initial_prompt", "initial_prompt",
"without_timestamps",
), ),
) )
): ):
@@ -101,6 +102,7 @@ class WhisperModel:
         no_speech_threshold=0.6,
         condition_on_previous_text=True,
         initial_prompt=None,
+        without_timestamps=False,
     ):
         """Transcribes an input file.
@@ -126,6 +128,7 @@ class WhisperModel:
             windows, but the model becomes less prone to getting stuck in a failure loop,
             such as repetition looping or timestamps going out of sync.
           initial_prompt: Optional text to provide as a prompt for the first window.
+          without_timestamps: Only sample text tokens.
 
         Returns:
           A tuple with:
@@ -159,6 +162,7 @@ class WhisperModel:
                 temperature if isinstance(temperature, (list, tuple)) else [temperature]
             ),
             initial_prompt=initial_prompt,
+            without_timestamps=without_timestamps,
         )
 
         segments = self.generate_segments(features, language, options)
@@ -205,7 +209,9 @@ class WhisperModel:
             segment_duration = segment.shape[-1] * self.feature_extractor.time_per_frame
 
             previous_tokens = all_tokens[prompt_reset_since:]
-            prompt = self.get_prompt(language, previous_tokens)
+            prompt = self.get_prompt(
+                language, previous_tokens, without_timestamps=options.without_timestamps
+            )
 
             result, temperature = self.generate_with_fallback(segment, prompt, options)
 
             if (
@@ -317,7 +323,7 @@ class WhisperModel:
         return result, final_temperature
 
-    def get_prompt(self, language, previous_tokens):
+    def get_prompt(self, language, previous_tokens, without_timestamps=False):
         prompt = []
 
         if previous_tokens:
@@ -330,6 +336,9 @@ class WhisperModel:
             self.tokenizer.token_to_id("<|transcribe|>"),
         ]
 
+        if without_timestamps:
+            prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))
+
         return prompt
 
     def get_segment(self, features, offset=0):
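
Usage sketch (not part of this commit): roughly how a caller could exercise the
new parameter. The model path, audio file, and constructor defaults below are
placeholder assumptions; only the without_timestamps keyword comes from this
change.

    from faster_whisper import WhisperModel

    # Placeholder path to a CTranslate2-converted Whisper model.
    model = WhisperModel("whisper-small-ct2")

    # With without_timestamps=True the decoder prompt ends with <|notimestamps|>,
    # so the model only samples text tokens for each audio window.
    segments, _ = model.transcribe("audio.wav", without_timestamps=True)

    for segment in segments:
        print(segment.text)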