Add task parameter
@@ -24,6 +24,7 @@ class TranscriptionOptions(
     collections.namedtuple(
         "TranscriptionOptions",
         (
+            "task",
             "beam_size",
             "best_of",
             "patience",
@@ -93,6 +94,7 @@ class WhisperModel:
         self,
         input_file,
         language=None,
+        task="transcribe",
         beam_size=5,
         best_of=5,
         patience=1,
@@ -110,6 +112,7 @@ class WhisperModel:
           input_file: Path to the input file or a file-like object.
           language: The language spoken in the audio. If not set, the language will be
             detected in the first 30 seconds of audio.
+          task: Task to execute (transcribe or translate).
           beam_size: Beam size to use for decoding.
           best_of: Number of candidates when sampling with non-zero temperature.
           patience: Beam search patience factor.
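
For reference, a minimal usage sketch of the new keyword (not part of this commit): the model path "whisper-small-ct2" and the way the result is consumed are assumptions; only the transcribe() keywords come from the signature above.

from faster_whisper import WhisperModel

# Hypothetical model directory: any Whisper model converted to CTranslate2.
model = WhisperModel("whisper-small-ct2")

# task="translate" asks the model to translate the audio into English;
# the default, task="transcribe", keeps the spoken language.
result = model.transcribe("audio.mp3", task="translate", beam_size=5)
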
@@ -151,6 +154,7 @@ class WhisperModel:
             language_probability = 1

         options = TranscriptionOptions(
+            task=task,
             beam_size=beam_size,
             best_of=best_of,
             patience=patience,
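
A standalone sketch of the pattern used here, with made-up fields: the options namedtuple carries the new value, so options.task is readable wherever the options object travels, with no extra plumbing.

import collections

# Toy version of the TranscriptionOptions pattern: a namedtuple bundles
# decoding options so later code can simply read options.task.
Options = collections.namedtuple("Options", ("task", "beam_size"))
options = Options(task="translate", beam_size=5)
print(options.task)  # translate
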
@@ -210,8 +214,12 @@ class WhisperModel:

             previous_tokens = all_tokens[prompt_reset_since:]
             prompt = self.get_prompt(
-                language, previous_tokens, without_timestamps=options.without_timestamps
+                language,
+                previous_tokens,
+                task=options.task,
+                without_timestamps=options.without_timestamps,
             )

             result, temperature = self.generate_with_fallback(segment, prompt, options)

             if (
@@ -323,7 +331,13 @@ class WhisperModel:

         return result, final_temperature

-    def get_prompt(self, language, previous_tokens, without_timestamps=False):
+    def get_prompt(
+        self,
+        language,
+        previous_tokens,
+        task="transcribe",
+        without_timestamps=False,
+    ):
         prompt = []

         if previous_tokens:
@@ -333,7 +347,7 @@ class WhisperModel:
         prompt += [
             self.tokenizer.token_to_id("<|startoftranscript|>"),
             self.tokenizer.token_to_id("<|%s|>" % language),
-            self.tokenizer.token_to_id("<|transcribe|>"),
+            self.tokenizer.token_to_id("<|%s|>" % task),
         ]

         if without_timestamps:
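
To illustrate the effect of the last two hunks: get_prompt now interpolates the task into the special-token prefix instead of hard-coding <|transcribe|>. A self-contained sketch of the resulting prefix, in string form only; the real code maps each token to its id via the tokenizer.

# With language="en" and task="translate", the decoder prompt now starts with
# <|startoftranscript|><|en|><|translate|>; before this commit the third token
# was always <|transcribe|>, so translation could never be requested.
language, task = "en", "translate"
prefix = ["<|startoftranscript|>", "<|%s|>" % language, "<|%s|>" % task]
print("".join(prefix))  # <|startoftranscript|><|en|><|translate|>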