Add the initial_prompt parameter (#2)
* Add the initial_prompt parameter
* Add docstring
@@ -32,6 +32,7 @@ class TranscriptionOptions(
             "compression_ratio_threshold",
             "condition_on_previous_text",
             "temperatures",
+            "initial_prompt",
         ),
     )
 ):
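For orientation, a minimal sketch of the container this hunk extends. The field list is abridged to the options visible in the hunk (the real class declares more), and the subclass wrapper is omitted:

    import collections

    # Abridged stand-in for the TranscriptionOptions namedtuple in transcribe.py.
    TranscriptionOptions = collections.namedtuple(
        "TranscriptionOptions",
        (
            "compression_ratio_threshold",
            "condition_on_previous_text",
            "temperatures",
            "initial_prompt",  # new field added by this commit
        ),
    )

    options = TranscriptionOptions(
        compression_ratio_threshold=2.4,
        condition_on_previous_text=True,
        temperatures=[0.0],
        initial_prompt="Glossary: CTranslate2, Whisper.",
    )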
@@ -64,16 +65,9 @@ class WhisperModel:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.feature_extractor = FeatureExtractor()
|
self.feature_extractor = FeatureExtractor()
|
||||||
self.decoder = tokenizers.decoders.ByteLevel()
|
self.tokenizer = tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny")
|
||||||
|
self.eot_id = self.tokenizer.token_to_id("<|endoftext|>")
|
||||||
with open(os.path.join(model_path, "vocabulary.txt")) as vocab_file:
|
self.timestamp_begin_id = self.tokenizer.token_to_id("<|notimestamps|>") + 1
|
||||||
self.ids_to_tokens = [line.rstrip("\n") for line in vocab_file]
|
|
||||||
self.tokens_to_ids = {
|
|
||||||
token: i for i, token in enumerate(self.ids_to_tokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
self.eot_id = self.tokens_to_ids["<|endoftext|>"]
|
|
||||||
self.timestamp_begin_id = self.tokens_to_ids["<|notimestamps|>"] + 1
|
|
||||||
self.input_stride = 2
|
self.input_stride = 2
|
||||||
self.time_precision = 0.02
|
self.time_precision = 0.02
|
||||||
self.max_length = 448
|
self.max_length = 448
|
||||||
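The hand-rolled vocabulary file and ByteLevel decoder give way to the Hugging Face tokenizers library. A standalone sketch of the calls the new code relies on; it needs the tokenizers package and access to the Hugging Face Hub (or a local cache of openai/whisper-tiny):

    import tokenizers

    tokenizer = tokenizers.Tokenizer.from_pretrained("openai/whisper-tiny")

    eot_id = tokenizer.token_to_id("<|endoftext|>")  # 50257 in the multilingual vocabulary
    timestamp_begin_id = tokenizer.token_to_id("<|notimestamps|>") + 1

    encoding = tokenizer.encode(" Hello world", add_special_tokens=False)
    print(encoding.ids)                    # plain token ids, no special tokens
    print(tokenizer.decode(encoding.ids))  # decodes back to the original text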
@@ -90,6 +84,7 @@ class WhisperModel:
         log_prob_threshold=-1.0,
         no_speech_threshold=0.6,
         condition_on_previous_text=True,
+        initial_prompt=None,
     ):
         """Transcribes an input file.
 
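Caller-side, the new argument is passed straight through transcribe(). A hypothetical usage sketch; the model directory and audio file names are placeholders, not files shipped with this commit:

    model = WhisperModel("whisper-tiny-ct2")  # placeholder path to a converted model

    segments, info = model.transcribe(
        "audio.wav",
        initial_prompt="Names in the recording: Guillaume, CTranslate2.",
    )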
@@ -114,6 +109,7 @@ class WhisperModel:
             as a prompt for the next window; disabling may make the text inconsistent across
             windows, but the model becomes less prone to getting stuck in a failure loop,
             such as repetition looping or timestamps going out of sync.
+          initial_prompt: Optional text to provide as a prompt for the first window.
 
         Returns:
           A tuple with:
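Per the docstring, the two options are independent: initial_prompt seeds only the first window, while condition_on_previous_text decides whether later windows see earlier output. A hedged sketch combining them, reusing the placeholder model from above:

    # initial_prompt still seeds window one even when conditioning on
    # previous text is disabled for the windows that follow.
    segments, info = model.transcribe(
        "audio.wav",
        condition_on_previous_text=False,
        initial_prompt="Spelling glossary: faster-whisper, CTranslate2.",
    )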
@@ -146,6 +142,7 @@ class WhisperModel:
             temperatures=(
                 temperature if isinstance(temperature, (list, tuple)) else [temperature]
             ),
+            initial_prompt=initial_prompt,
         )
 
         segments = self.generate_segments(features, language, options)
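As context for the surrounding lines: the temperatures=(...) expression normalizes a scalar into a list so the fallback loop can always iterate. The same logic as a free-standing helper (as_temperature_list is an illustrative name, not part of the code):

    def as_temperature_list(temperature):
        # Scalars become one-element lists; lists and tuples pass through unchanged.
        return temperature if isinstance(temperature, (list, tuple)) else [temperature]

    assert as_temperature_list(0.2) == [0.2]
    assert as_temperature_list((0.0, 0.2, 0.4)) == (0.0, 0.2, 0.4)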
@@ -179,6 +176,13 @@ class WhisperModel:
         all_tokens = []
         prompt_reset_since = 0
 
+        if options.initial_prompt is not None:
+            initial_prompt = " " + options.initial_prompt.strip()
+            initial_prompt_tokens = self.tokenizer.encode(
+                initial_prompt, add_special_tokens=False
+            )
+            all_tokens.extend(initial_prompt_tokens.ids)
+
         while offset < num_frames:
             time_offset = offset * self.feature_extractor.time_per_frame
             segment = self.get_segment(features, offset)
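The new block in isolation: the leading space matches how the byte-level vocabulary encodes words inside running text, and add_special_tokens=False keeps markers such as <|startoftranscript|> out of the prompt history. A sketch, reusing the tokenizer from above:

    user_prompt = "Glossary: CTranslate2, Whisper."
    initial_prompt = " " + user_prompt.strip()
    initial_prompt_tokens = tokenizer.encode(initial_prompt, add_special_tokens=False)

    all_tokens = []
    all_tokens.extend(initial_prompt_tokens.ids)  # becomes the "previous text" for window one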
@@ -253,11 +257,8 @@ class WhisperModel:
             prompt_reset_since = len(all_tokens)
 
     def decode_text_tokens(self, tokens):
-        text_tokens = [
-            self.ids_to_tokens[token] for token in tokens if token < self.eot_id
-        ]
-
-        return self.decoder.decode(text_tokens)
+        text_tokens = [token for token in tokens if token < self.eot_id]
+        return self.tokenizer.decode(text_tokens)
 
     def generate_with_fallback(self, segment, prompt, options):
         features = self.get_input(segment)
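decode_text_tokens now filters on raw ids instead of string lookups: in the Whisper vocabulary every special and timestamp token id is at or above the <|endoftext|> id, so `token < self.eot_id` keeps exactly the text tokens. The same filter in isolation, continuing the sketch above:

    tokens = initial_prompt_tokens.ids + [eot_id]     # pretend the model emitted <|endoftext|>
    text_tokens = [token for token in tokens if token < eot_id]
    print(tokenizer.decode(text_tokens))              # special tokens are gone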
@@ -304,13 +305,13 @@ class WhisperModel:
         prompt = []
 
         if previous_tokens:
-            prompt.append(self.tokens_to_ids["<|startofprev|>"])
+            prompt.append(self.tokenizer.token_to_id("<|startofprev|>"))
             prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
 
         prompt += [
-            self.tokens_to_ids["<|startoftranscript|>"],
-            self.tokens_to_ids["<|%s|>" % language],
-            self.tokens_to_ids["<|transcribe|>"],
+            self.tokenizer.token_to_id("<|startoftranscript|>"),
+            self.tokenizer.token_to_id("<|%s|>" % language),
+            self.tokenizer.token_to_id("<|transcribe|>"),
         ]
 
         return prompt
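The prompt layout this helper produces is: <|startofprev|>, the tail of the previous tokens (capped at max_length // 2 - 1 ids), then <|startoftranscript|>, the language token, and <|transcribe|>. A free-standing reconstruction for experimentation; build_prompt is an illustrative name, not part of the code:

    def build_prompt(tokenizer, previous_tokens, language="en", max_length=448):
        prompt = []

        if previous_tokens:
            # Condition on the most recent history, leaving room for the output.
            prompt.append(tokenizer.token_to_id("<|startofprev|>"))
            prompt.extend(previous_tokens[-(max_length // 2 - 1):])

        # Task preamble: start marker, language token, transcription task token.
        prompt += [
            tokenizer.token_to_id("<|startoftranscript|>"),
            tokenizer.token_to_id("<|%s|>" % language),
            tokenizer.token_to_id("<|transcribe|>"),
        ]

        return prompt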