Add task parameter

This commit is contained in:
Guillaume Klein
2023-02-13 21:26:25 +01:00
parent f56dfc6491
commit c86353d323

View File

@@ -24,6 +24,7 @@ class TranscriptionOptions(
collections.namedtuple(
"TranscriptionOptions",
(
"task",
"beam_size",
"best_of",
"patience",
@@ -93,6 +94,7 @@ class WhisperModel:
self,
input_file,
language=None,
task="transcribe",
beam_size=5,
best_of=5,
patience=1,
@@ -110,6 +112,7 @@ class WhisperModel:
input_file: Path to the input file or a file-like object.
language: The language spoken in the audio. If not set, the language will be
detected in the first 30 seconds of audio.
task: Task to execute (transcribe or translate).
beam_size: Beam size to use for decoding.
best_of: Number of candidates when sampling with non-zero temperature.
patience: Beam search patience factor.
@@ -151,6 +154,7 @@ class WhisperModel:
language_probability = 1
options = TranscriptionOptions(
task=task,
beam_size=beam_size,
best_of=best_of,
patience=patience,
@@ -210,8 +214,12 @@ class WhisperModel:
previous_tokens = all_tokens[prompt_reset_since:]
prompt = self.get_prompt(
language, previous_tokens, without_timestamps=options.without_timestamps
language,
previous_tokens,
task=options.task,
without_timestamps=options.without_timestamps,
)
result, temperature = self.generate_with_fallback(segment, prompt, options)
if (
@@ -323,7 +331,13 @@ class WhisperModel:
return result, final_temperature
def get_prompt(self, language, previous_tokens, without_timestamps=False):
def get_prompt(
self,
language,
previous_tokens,
task="transcribe",
without_timestamps=False,
):
prompt = []
if previous_tokens:
@@ -333,7 +347,7 @@ class WhisperModel:
prompt += [
self.tokenizer.token_to_id("<|startoftranscript|>"),
self.tokenizer.token_to_id("<|%s|>" % language),
self.tokenizer.token_to_id("<|transcribe|>"),
self.tokenizer.token_to_id("<|%s|>" % task),
]
if without_timestamps: