Add prefix parameter
This commit is contained in:
@@ -36,6 +36,7 @@ class TranscriptionOptions(
|
|||||||
"condition_on_previous_text",
|
"condition_on_previous_text",
|
||||||
"temperatures",
|
"temperatures",
|
||||||
"initial_prompt",
|
"initial_prompt",
|
||||||
|
"prefix",
|
||||||
"without_timestamps",
|
"without_timestamps",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -112,6 +113,7 @@ class WhisperModel:
|
|||||||
no_speech_threshold: Optional[float] = 0.6,
|
no_speech_threshold: Optional[float] = 0.6,
|
||||||
condition_on_previous_text: bool = True,
|
condition_on_previous_text: bool = True,
|
||||||
initial_prompt: Optional[str] = None,
|
initial_prompt: Optional[str] = None,
|
||||||
|
prefix: Optional[str] = None,
|
||||||
without_timestamps: bool = False,
|
without_timestamps: bool = False,
|
||||||
):
|
):
|
||||||
"""Transcribes an input file.
|
"""Transcribes an input file.
|
||||||
@@ -141,6 +143,7 @@ class WhisperModel:
|
|||||||
windows, but the model becomes less prone to getting stuck in a failure loop,
|
windows, but the model becomes less prone to getting stuck in a failure loop,
|
||||||
such as repetition looping or timestamps going out of sync.
|
such as repetition looping or timestamps going out of sync.
|
||||||
initial_prompt: Optional text to provide as a prompt for the first window.
|
initial_prompt: Optional text to provide as a prompt for the first window.
|
||||||
|
prefix: Optional text to provide as a prefix for the first window.
|
||||||
without_timestamps: Only sample text tokens.
|
without_timestamps: Only sample text tokens.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -183,6 +186,7 @@ class WhisperModel:
|
|||||||
temperature if isinstance(temperature, (list, tuple)) else [temperature]
|
temperature if isinstance(temperature, (list, tuple)) else [temperature]
|
||||||
),
|
),
|
||||||
initial_prompt=initial_prompt,
|
initial_prompt=initial_prompt,
|
||||||
|
prefix=prefix,
|
||||||
without_timestamps=without_timestamps,
|
without_timestamps=without_timestamps,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -219,10 +223,8 @@ class WhisperModel:
|
|||||||
|
|
||||||
if options.initial_prompt is not None:
|
if options.initial_prompt is not None:
|
||||||
initial_prompt = " " + options.initial_prompt.strip()
|
initial_prompt = " " + options.initial_prompt.strip()
|
||||||
initial_prompt_tokens = self.tokenizer.encode(
|
initial_prompt_tokens = self.encode_text(initial_prompt)
|
||||||
initial_prompt, add_special_tokens=False
|
all_tokens.extend(initial_prompt_tokens)
|
||||||
)
|
|
||||||
all_tokens.extend(initial_prompt_tokens.ids)
|
|
||||||
|
|
||||||
while offset < num_frames:
|
while offset < num_frames:
|
||||||
time_offset = offset * self.feature_extractor.time_per_frame
|
time_offset = offset * self.feature_extractor.time_per_frame
|
||||||
@@ -235,6 +237,7 @@ class WhisperModel:
|
|||||||
previous_tokens,
|
previous_tokens,
|
||||||
task=options.task,
|
task=options.task,
|
||||||
without_timestamps=options.without_timestamps,
|
without_timestamps=options.without_timestamps,
|
||||||
|
prefix=options.prefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
result, avg_log_prob, temperature = self.generate_with_fallback(
|
result, avg_log_prob, temperature = self.generate_with_fallback(
|
||||||
@@ -314,6 +317,9 @@ class WhisperModel:
|
|||||||
if not options.condition_on_previous_text or temperature > 0.5:
|
if not options.condition_on_previous_text or temperature > 0.5:
|
||||||
prompt_reset_since = len(all_tokens)
|
prompt_reset_since = len(all_tokens)
|
||||||
|
|
||||||
|
def encode_text(self, text):
|
||||||
|
return self.tokenizer.encode(text, add_special_tokens=False).ids
|
||||||
|
|
||||||
def decode_text_tokens(self, tokens):
|
def decode_text_tokens(self, tokens):
|
||||||
text_tokens = [token for token in tokens if token < self.eot_id]
|
text_tokens = [token for token in tokens if token < self.eot_id]
|
||||||
return self.tokenizer.decode(text_tokens)
|
return self.tokenizer.decode(text_tokens)
|
||||||
@@ -384,6 +390,7 @@ class WhisperModel:
|
|||||||
previous_tokens,
|
previous_tokens,
|
||||||
task="transcribe",
|
task="transcribe",
|
||||||
without_timestamps=False,
|
without_timestamps=False,
|
||||||
|
prefix=None,
|
||||||
):
|
):
|
||||||
prompt = []
|
prompt = []
|
||||||
|
|
||||||
@@ -404,6 +411,12 @@ class WhisperModel:
|
|||||||
if without_timestamps:
|
if without_timestamps:
|
||||||
prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))
|
prompt.append(self.tokenizer.token_to_id("<|notimestamps|>"))
|
||||||
|
|
||||||
|
if prefix:
|
||||||
|
prefix_tokens = self.encode_text(" " + prefix.strip())
|
||||||
|
if len(prefix_tokens) >= self.max_length // 2:
|
||||||
|
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
|
||||||
|
prompt.extend(prefix_tokens)
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def get_segment(self, features, offset=0):
|
def get_segment(self, features, offset=0):
|
||||||
|
|||||||
Reference in New Issue
Block a user