Add task parameter
This commit is contained in:
@@ -24,6 +24,7 @@ class TranscriptionOptions(
|
||||
collections.namedtuple(
|
||||
"TranscriptionOptions",
|
||||
(
|
||||
"task",
|
||||
"beam_size",
|
||||
"best_of",
|
||||
"patience",
|
||||
@@ -93,6 +94,7 @@ class WhisperModel:
|
||||
self,
|
||||
input_file,
|
||||
language=None,
|
||||
task="transcribe",
|
||||
beam_size=5,
|
||||
best_of=5,
|
||||
patience=1,
|
||||
@@ -110,6 +112,7 @@ class WhisperModel:
|
||||
input_file: Path to the input file or a file-like object.
|
||||
language: The language spoken in the audio. If not set, the language will be
|
||||
detected in the first 30 seconds of audio.
|
||||
task: Task to execute (transcribe or translate).
|
||||
beam_size: Beam size to use for decoding.
|
||||
best_of: Number of candidates when sampling with non-zero temperature.
|
||||
patience: Beam search patience factor.
|
||||
@@ -151,6 +154,7 @@ class WhisperModel:
|
||||
language_probability = 1
|
||||
|
||||
options = TranscriptionOptions(
|
||||
task=task,
|
||||
beam_size=beam_size,
|
||||
best_of=best_of,
|
||||
patience=patience,
|
||||
@@ -210,8 +214,12 @@ class WhisperModel:
|
||||
|
||||
previous_tokens = all_tokens[prompt_reset_since:]
|
||||
prompt = self.get_prompt(
|
||||
- language, previous_tokens, without_timestamps=options.without_timestamps
+ language,
+ previous_tokens,
+ task=options.task,
+ without_timestamps=options.without_timestamps,
|
||||
)
|
||||
|
||||
result, temperature = self.generate_with_fallback(segment, prompt, options)
|
||||
|
||||
if (
|
||||
@@ -323,7 +331,13 @@ class WhisperModel:
|
||||
|
||||
return result, final_temperature
|
||||
|
||||
- def get_prompt(self, language, previous_tokens, without_timestamps=False):
+ def get_prompt(
+ self,
+ language,
+ previous_tokens,
+ task="transcribe",
+ without_timestamps=False,
+ ):
|
||||
prompt = []
|
||||
|
||||
if previous_tokens:
|
||||
@@ -333,7 +347,7 @@ class WhisperModel:
|
||||
prompt += [
|
||||
self.tokenizer.token_to_id("<|startoftranscript|>"),
|
||||
self.tokenizer.token_to_id("<|%s|>" % language),
|
||||
- self.tokenizer.token_to_id("<|transcribe|>"),
+ self.tokenizer.token_to_id("<|%s|>" % task),
|
||||
]
|
||||
|
||||
if without_timestamps:
|
||||
|
||||
Reference in New Issue
Block a user