Add prefix parameter
This commit is contained in:
@@ -36,6 +36,7 @@ class TranscriptionOptions(
|
||||
"condition_on_previous_text",
|
||||
"temperatures",
|
||||
"initial_prompt",
|
||||
"prefix",
|
||||
"without_timestamps",
|
||||
),
|
||||
)
|
||||
@@ -112,6 +113,7 @@ class WhisperModel:
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
without_timestamps: bool = False,
|
||||
):
|
||||
"""Transcribes an input file.
|
||||
@@ -141,6 +143,7 @@ class WhisperModel:
|
||||
windows, but the model becomes less prone to getting stuck in a failure loop,
|
||||
such as repetition looping or timestamps going out of sync.
|
||||
initial_prompt: Optional text to provide as a prompt for the first window.
|
||||
prefix: Optional text to provide as a prefix for the first window.
|
||||
without_timestamps: Only sample text tokens.
|
||||
|
||||
Returns:
|
||||
@@ -183,6 +186,7 @@ class WhisperModel:
|
||||
temperature if isinstance(temperature, (list, tuple)) else [temperature]
|
||||
),
|
||||
initial_prompt=initial_prompt,
|
||||
prefix=prefix,
|
||||
without_timestamps=without_timestamps,
|
||||
)
|
||||
|
||||
@@ -219,10 +223,8 @@ class WhisperModel:
|
||||
|
||||
if options.initial_prompt is not None:
|
||||
initial_prompt = " " + options.initial_prompt.strip()
|
||||
initial_prompt_tokens = self.tokenizer.encode(
|
||||
initial_prompt, add_special_tokens=False
|
||||
)
|
||||
all_tokens.extend(initial_prompt_tokens.ids)
|
||||
initial_prompt_tokens = self.encode_text(initial_prompt)
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
|
||||
while offset < num_frames:
|
||||
time_offset = offset * self.feature_extractor.time_per_frame
|
||||
@@ -235,6 +237,7 @@ class WhisperModel:
|
||||
previous_tokens,
|
||||
task=options.task,
|
||||
without_timestamps=options.without_timestamps,
|
||||
prefix=options.prefix,
|
||||
)
|
||||
|
||||
result, avg_log_prob, temperature = self.generate_with_fallback(
|
||||
@@ -314,6 +317,9 @@ class WhisperModel:
|
||||
if not options.condition_on_previous_text or temperature > 0.5:
|
||||
prompt_reset_since = len(all_tokens)
|
||||
|
||||
def encode_text(self, text):
    """Tokenize ``text`` and return the resulting token ids.

    The text is handed to ``self.tokenizer.encode`` with
    ``add_special_tokens=False``, and the ``ids`` attribute of the returned
    encoding is passed back to the caller.

    Args:
      text: The string to tokenize.

    Returns:
      The ``ids`` of the encoding produced by ``self.tokenizer`` for ``text``.
    """
    encoding = self.tokenizer.encode(text, add_special_tokens=False)
    return encoding.ids
|
||||
|
||||
def decode_text_tokens(self, tokens):
    """Decode a token sequence to text, ignoring ids at or above ``eot_id``.

    Only token ids strictly smaller than ``self.eot_id`` are kept; the
    remaining ids are decoded with ``self.tokenizer.decode``.
    NOTE(review): ids >= ``eot_id`` are presumably the special/timestamp
    tokens -- confirm against the tokenizer vocabulary.

    Args:
      tokens: Iterable of integer token ids.

    Returns:
      The value returned by ``self.tokenizer.decode`` for the kept ids.
    """
    eot_id = self.eot_id
    text_tokens = [token for token in tokens if token < eot_id]
    return self.tokenizer.decode(text_tokens)
|
||||
@@ -384,6 +390,7 @@ class WhisperModel:
|
||||
previous_tokens,
|
||||
task="transcribe",
|
||||
without_timestamps=False,
|
||||
prefix=None,
|
||||
):
|
||||
prompt = []
|
||||
|
||||
@@ -404,6 +411,12 @@ class WhisperModel:
|
||||
if without_timestamps:
|
||||
prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))
|
||||
|
||||
if prefix:
|
||||
prefix_tokens = self.encode_text(" " + prefix.strip())
|
||||
if len(prefix_tokens) >= self.max_length // 2:
|
||||
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
|
||||
prompt.extend(prefix_tokens)
|
||||
|
||||
return prompt
|
||||
|
||||
def get_segment(self, features, offset=0):
|
||||
|
||||
Reference in New Issue
Block a user