Add without_timestamps parameter

Guillaume Klein
2023-02-13 21:22:05 +01:00
parent 5e938cba4e
commit f56dfc6491

@@ -33,6 +33,7 @@ class TranscriptionOptions(
             "condition_on_previous_text",
             "temperatures",
             "initial_prompt",
+            "without_timestamps",
         ),
     )
 ):
@@ -101,6 +102,7 @@ class WhisperModel:
         no_speech_threshold=0.6,
         condition_on_previous_text=True,
         initial_prompt=None,
+        without_timestamps=False,
     ):
         """Transcribes an input file.
@@ -126,6 +128,7 @@ class WhisperModel:
             windows, but the model becomes less prone to getting stuck in a failure loop,
             such as repetition looping or timestamps going out of sync.
           initial_prompt: Optional text to provide as a prompt for the first window.
+          without_timestamps: Only sample text tokens.
 
         Returns:
           A tuple with:
@@ -159,6 +162,7 @@ class WhisperModel:
                 temperature if isinstance(temperature, (list, tuple)) else [temperature]
             ),
             initial_prompt=initial_prompt,
+            without_timestamps=without_timestamps,
         )
 
         segments = self.generate_segments(features, language, options)
@@ -205,7 +209,9 @@ class WhisperModel:
             segment_duration = segment.shape[-1] * self.feature_extractor.time_per_frame
 
             previous_tokens = all_tokens[prompt_reset_since:]
-            prompt = self.get_prompt(language, previous_tokens)
+            prompt = self.get_prompt(
+                language, previous_tokens, without_timestamps=options.without_timestamps
+            )
             result, temperature = self.generate_with_fallback(segment, prompt, options)
 
             if (
@@ -317,7 +323,7 @@ class WhisperModel:
 
         return result, final_temperature
 
-    def get_prompt(self, language, previous_tokens):
+    def get_prompt(self, language, previous_tokens, without_timestamps=False):
         prompt = []
 
         if previous_tokens:
@@ -330,6 +336,9 @@ class WhisperModel:
             self.tokenizer.token_to_id("<|transcribe|>"),
         ]
 
+        if without_timestamps:
+            prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))
+
         return prompt
 
     def get_segment(self, features, offset=0):
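
For reference, a minimal usage sketch of the option added by this commit. The model path, the audio file name, and the (segments, info) return order are illustrative assumptions based on the library's typical usage, not something shown in this diff:

    from faster_whisper import WhisperModel

    # Hypothetical converted model directory and audio file; substitute your own.
    model = WhisperModel("whisper-small-ct2")

    # without_timestamps=True adds <|notimestamps|> to the decoder prompt,
    # so only text tokens are sampled (no per-segment timestamp tokens).
    segments, info = model.transcribe("audio.wav", without_timestamps=True)

    for segment in segments:
        print(segment.text)

Appending <|notimestamps|> to the prompt is how Whisper-style models are told to skip timestamp prediction, which is what the new get_prompt branch above does.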