Add task parameter

This commit is contained in:
Guillaume Klein
2023-02-13 21:26:25 +01:00
parent f56dfc6491
commit c86353d323

View File

@@ -24,6 +24,7 @@ class TranscriptionOptions(
     collections.namedtuple(
         "TranscriptionOptions",
         (
+            "task",
             "beam_size",
             "best_of",
             "patience",
@@ -93,6 +94,7 @@ class WhisperModel:
         self,
         input_file,
         language=None,
+        task="transcribe",
         beam_size=5,
         best_of=5,
         patience=1,
@@ -110,6 +112,7 @@ class WhisperModel:
           input_file: Path to the input file or a file-like object.
           language: The language spoken in the audio. If not set, the language will be
             detected in the first 30 seconds of audio.
+          task: Task to execute (transcribe or translate).
           beam_size: Beam size to use for decoding.
           best_of: Number of candidates when sampling with non-zero temperature.
           patience: Beam search patience factor.
@@ -151,6 +154,7 @@ class WhisperModel:
             language_probability = 1

         options = TranscriptionOptions(
+            task=task,
             beam_size=beam_size,
             best_of=best_of,
             patience=patience,
@@ -210,8 +214,12 @@ class WhisperModel:
             previous_tokens = all_tokens[prompt_reset_since:]
             prompt = self.get_prompt(
-                language, previous_tokens, without_timestamps=options.without_timestamps
+                language,
+                previous_tokens,
+                task=options.task,
+                without_timestamps=options.without_timestamps,
             )

             result, temperature = self.generate_with_fallback(segment, prompt, options)

             if (
@@ -323,7 +331,13 @@ class WhisperModel:
         return result, final_temperature

-    def get_prompt(self, language, previous_tokens, without_timestamps=False):
+    def get_prompt(
+        self,
+        language,
+        previous_tokens,
+        task="transcribe",
+        without_timestamps=False,
+    ):
         prompt = []

         if previous_tokens:
@@ -333,7 +347,7 @@ class WhisperModel:
         prompt += [
             self.tokenizer.token_to_id("<|startoftranscript|>"),
             self.tokenizer.token_to_id("<|%s|>" % language),
-            self.tokenizer.token_to_id("<|transcribe|>"),
+            self.tokenizer.token_to_id("<|%s|>" % task),
         ]

         if without_timestamps: