Add the initial_prompt parameter (#2)

* Add the initial_prompt parameter

* Add docstring
Author: Guillaume Klein
Date: 2023-02-12 11:42:21 +01:00
Committed by: GitHub
Parent: 23d2d64259
Commit: 7d1d0541c8

@@ -32,6 +32,7 @@ class TranscriptionOptions(
             "compression_ratio_threshold",
             "condition_on_previous_text",
             "temperatures",
+            "initial_prompt",
         ),
     )
 ):
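
For context, the field list above suggests TranscriptionOptions is built on collections.namedtuple, so the new field simply rides along with the existing decoding options. A minimal sketch of that pattern, with the field list abbreviated and values made up for illustration:

import collections

# Sketch only: the real class lists many more fields; "initial_prompt" is the
# one added by this commit.
class TranscriptionOptions(
    collections.namedtuple(
        "TranscriptionOptions",
        (
            "condition_on_previous_text",
            "temperatures",
            "initial_prompt",
        ),
    )
):
    pass

options = TranscriptionOptions(
    condition_on_previous_text=True,
    temperatures=[0.0],
    initial_prompt="Names: CTranslate2, Guillaume.",
)
print(options.initial_prompt)
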
@@ -64,16 +65,9 @@ class WhisperModel:
         )

         self.feature_extractor = FeatureExtractor()
-        self.decoder = tokenizers.decoders.ByteLevel()
-
-        with open(os.path.join(model_path, "vocabulary.txt")) as vocab_file:
-            self.ids_to_tokens = [line.rstrip("\n") for line in vocab_file]
-            self.tokens_to_ids = {
-                token: i for i, token in enumerate(self.ids_to_tokens)
-            }
-
-        self.eot_id = self.tokens_to_ids["<|endoftext|>"]
-        self.timestamp_begin_id = self.tokens_to_ids["<|notimestamps|>"] + 1
+        self.tokenizer = tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny")
+        self.eot_id = self.tokenizer.token_to_id("<|endoftext|>")
+        self.timestamp_begin_id = self.tokenizer.token_to_id("<|notimestamps|>") + 1
         self.input_stride = 2
         self.time_precision = 0.02
         self.max_length = 448
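
This hunk replaces the bundled vocabulary.txt lookup tables with the Hugging Face tokenizers package, which provides encoding (needed to turn the new initial_prompt text into ids) as well as decoding. A quick sketch of the special-token lookups, assuming the multilingual Whisper tokenizer is shared across model sizes so that "openai/whisper-tiny" serves for all of them:

import tokenizers

tokenizer = tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny")

eot_id = tokenizer.token_to_id("<|endoftext|>")
timestamp_begin_id = tokenizer.token_to_id("<|notimestamps|>") + 1

# For the multilingual vocabulary these should come out as 50257 and 50364.
print(eot_id, timestamp_begin_id)
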
@@ -90,6 +84,7 @@ class WhisperModel:
         log_prob_threshold=-1.0,
         no_speech_threshold=0.6,
         condition_on_previous_text=True,
+        initial_prompt=None,
     ):
         """Transcribes an input file.

@@ -114,6 +109,7 @@ class WhisperModel:
             as a prompt for the next window; disabling may make the text inconsistent across
             windows, but the model becomes less prone to getting stuck in a failure loop,
             such as repetition looping or timestamps going out of sync.
+          initial_prompt: Optional text to provide as a prompt for the first window.

         Returns:
           A tuple with:
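
With the new keyword argument, a call might look like the following. This is a hypothetical sketch: the model directory, audio file, and prompt text are made up, and the two-element return value is assumed to unpack as (segments, info) per the docstring above.

model = WhisperModel("path/to/whisper-ct2")  # hypothetical converted model directory
segments, info = model.transcribe(           # assuming the tuple is (segments, info)
    "interview.wav",                         # hypothetical audio file
    initial_prompt="Speakers: Alice Martin, Bob Chen. Topic: CTranslate2.",
)
for segment in segments:
    print(segment)
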
@@ -146,6 +142,7 @@ class WhisperModel:
             temperatures=(
                 temperature if isinstance(temperature, (list, tuple)) else [temperature]
             ),
+            initial_prompt=initial_prompt,
         )

         segments = self.generate_segments(features, language, options)
@@ -179,6 +176,13 @@ class WhisperModel:
         all_tokens = []
         prompt_reset_since = 0

+        if options.initial_prompt is not None:
+            initial_prompt = " " + options.initial_prompt.strip()
+            initial_prompt_tokens = self.tokenizer.encode(
+                initial_prompt, add_special_tokens=False
+            )
+            all_tokens.extend(initial_prompt_tokens.ids)
+
         while offset < num_frames:
             time_offset = offset * self.feature_extractor.time_per_frame
             segment = self.get_segment(features, offset)
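
This hunk is the core of the feature: before the first window is decoded, the prompt text is tokenized and appended to all_tokens, so get_prompt later picks it up as <|startofprev|> context exactly as if it had been transcribed from a previous window. A standalone sketch of the encoding step (example text is made up):

import tokenizers

tokenizer = tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny")

# The leading space matters: Whisper's byte-level BPE folds it into the first
# token, matching how words are tokenized mid-stream.
encoding = tokenizer.encode(" Hello world", add_special_tokens=False)
print(encoding.ids)     # the ids that seed all_tokens before the first window
print(encoding.tokens)  # e.g. ['ĠHello', 'Ġworld'], with 'Ġ' marking the space
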
@@ -253,11 +257,8 @@ class WhisperModel:
             prompt_reset_since = len(all_tokens)

     def decode_text_tokens(self, tokens):
-        text_tokens = [
-            self.ids_to_tokens[token] for token in tokens if token < self.eot_id
-        ]
-
-        return self.decoder.decode(text_tokens)
+        text_tokens = [token for token in tokens if token < self.eot_id]
+        return self.tokenizer.decode(text_tokens)

     def generate_with_fallback(self, segment, prompt, options):
         features = self.get_input(segment)
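
decode_text_tokens can now filter on raw ids because every special token (<|endoftext|>, language tags, timestamps, ...) sits at or above eot_id in the vocabulary; anything below it is plain text. Roughly, with hypothetical ids standing in for decoder output:

import tokenizers

tokenizer = tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny")
eot_id = tokenizer.token_to_id("<|endoftext|>")

# Hypothetical decoder output: ids below eot_id are text; the timestamp and
# end-of-text ids are dropped before decoding.
tokens = [tokenizer.token_to_id("<|notimestamps|>") + 1, 15947, 11, 2787, eot_id]
print(tokenizer.decode([t for t in tokens if t < eot_id]))
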
@@ -304,13 +305,13 @@ class WhisperModel:
         prompt = []

         if previous_tokens:
-            prompt.append(self.tokens_to_ids["<|startofprev|>"])
+            prompt.append(self.tokenizer.token_to_id("<|startofprev|>"))
             prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])

         prompt += [
-            self.tokens_to_ids["<|startoftranscript|>"],
-            self.tokens_to_ids["<|%s|>" % language],
-            self.tokens_to_ids["<|transcribe|>"],
+            self.tokenizer.token_to_id("<|startoftranscript|>"),
+            self.tokenizer.token_to_id("<|%s|>" % language),
+            self.tokenizer.token_to_id("<|transcribe|>"),
         ]

         return prompt
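
Taken together, the prompt handed to the decoder for a window is shaped like the sketch below, mirroring get_prompt with "en" assumed as the language and 223 == 448 // 2 - 1 as the rolling context budget:

import tokenizers

tokenizer = tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny")

previous_tokens = tokenizer.encode(" Hello world", add_special_tokens=False).ids
prompt = (
    [tokenizer.token_to_id("<|startofprev|>")]
    + previous_tokens[-223:]  # last max_length // 2 - 1 tokens of context
    + [
        tokenizer.token_to_id("<|startoftranscript|>"),
        tokenizer.token_to_id("<|en|>"),  # language tag; "en" assumed here
        tokenizer.token_to_id("<|transcribe|>"),
    ]
)
print(prompt)
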