Add prefix parameter

This commit is contained in:
Guillaume Klein
2023-02-27 12:09:40 +01:00
parent 528aa3e784
commit a4f1cc8f11

View File

@@ -36,6 +36,7 @@ class TranscriptionOptions(
"condition_on_previous_text",
"temperatures",
"initial_prompt",
"prefix",
"without_timestamps",
),
)
@@ -112,6 +113,7 @@ class WhisperModel:
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
prefix: Optional[str] = None,
without_timestamps: bool = False,
):
"""Transcribes an input file.
@@ -141,6 +143,7 @@ class WhisperModel:
windows, but the model becomes less prone to getting stuck in a failure loop,
such as repetition looping or timestamps going out of sync.
initial_prompt: Optional text to provide as a prompt for the first window.
prefix: Optional text to provide as a prefix for the first window.
without_timestamps: Only sample text tokens.
Returns:
@@ -183,6 +186,7 @@ class WhisperModel:
temperature if isinstance(temperature, (list, tuple)) else [temperature]
),
initial_prompt=initial_prompt,
prefix=prefix,
without_timestamps=without_timestamps,
)
@@ -219,10 +223,8 @@ class WhisperModel:
if options.initial_prompt is not None:
initial_prompt = " " + options.initial_prompt.strip()
initial_prompt_tokens = self.tokenizer.encode(
initial_prompt, add_special_tokens=False
)
all_tokens.extend(initial_prompt_tokens.ids)
initial_prompt_tokens = self.encode_text(initial_prompt)
all_tokens.extend(initial_prompt_tokens)
while offset < num_frames:
time_offset = offset * self.feature_extractor.time_per_frame
@@ -235,6 +237,7 @@ class WhisperModel:
previous_tokens,
task=options.task,
without_timestamps=options.without_timestamps,
prefix=options.prefix,
)
result, avg_log_prob, temperature = self.generate_with_fallback(
@@ -314,6 +317,9 @@ class WhisperModel:
if not options.condition_on_previous_text or temperature > 0.5:
prompt_reset_since = len(all_tokens)
def encode_text(self, text):
    """Tokenize *text* into a plain list of token IDs.

    Special tokens are not added; only the raw text tokens are returned.
    """
    encoding = self.tokenizer.encode(text, add_special_tokens=False)
    return encoding.ids
def decode_text_tokens(self, tokens):
    """Decode token IDs into text, keeping only IDs below ``self.eot_id``.

    IDs at or above ``eot_id`` (end-of-text and other special markers) are
    filtered out before decoding.
    """
    eot = self.eot_id
    kept = [token_id for token_id in tokens if token_id < eot]
    return self.tokenizer.decode(kept)
@@ -384,6 +390,7 @@ class WhisperModel:
previous_tokens,
task="transcribe",
without_timestamps=False,
prefix=None,
):
prompt = []
@@ -404,6 +411,12 @@ class WhisperModel:
if without_timestamps:
prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))
if prefix:
prefix_tokens = self.encode_text(" " + prefix.strip())
if len(prefix_tokens) >= self.max_length // 2:
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
prompt.extend(prefix_tokens)
return prompt
def get_segment(self, features, offset=0):