From fee52c922904d18a143faa9d89f3b42caf029e9e Mon Sep 17 00:00:00 2001 From: FlippFuzz <41221030+FlippFuzz@users.noreply.github.com> Date: Wed, 21 Jun 2023 20:46:20 +0800 Subject: [PATCH] Allow users to input an Iterable of token ids into initial_prompt (#306) * Allow users to input an Iterable of token ids into initial_prompt * Need to check for String first because string is also an Iterable --- faster_whisper/transcribe.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index aee13b5..71b0ea1 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -52,7 +52,7 @@ class TranscriptionOptions(NamedTuple): compression_ratio_threshold: Optional[float] condition_on_previous_text: bool temperatures: List[float] - initial_prompt: Optional[str] + initial_prompt: Optional[Union[str, Iterable[int]]] prefix: Optional[str] suppress_blank: bool suppress_tokens: Optional[List[int]] @@ -170,7 +170,7 @@ class WhisperModel: log_prob_threshold: Optional[float] = -1.0, no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, - initial_prompt: Optional[str] = None, + initial_prompt: Optional[Union[str, Iterable[int]]] = None, prefix: Optional[str] = None, suppress_blank: bool = True, suppress_tokens: Optional[List[int]] = [-1], @@ -208,7 +208,8 @@ class WhisperModel: as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. - initial_prompt: Optional text to provide as a prompt for the first window. + initial_prompt: Optional text string or iterable of token ids to provide as a + prompt for the first window. prefix: Optional text to provide as a prefix for the first window. suppress_blank: Suppress blank outputs at the beginning of the sampling. suppress_tokens: List of token IDs to suppress. -1 will suppress a default set @@ -361,9 +362,12 @@ class WhisperModel: prompt_reset_since = 0 if options.initial_prompt is not None: - initial_prompt = " " + options.initial_prompt.strip() - initial_prompt_tokens = tokenizer.encode(initial_prompt) - all_tokens.extend(initial_prompt_tokens) + if isinstance(options.initial_prompt, str): + initial_prompt = " " + options.initial_prompt.strip() + initial_prompt_tokens = tokenizer.encode(initial_prompt) + all_tokens.extend(initial_prompt_tokens) + else: + all_tokens.extend(options.initial_prompt) while seek < content_frames: time_offset = seek * self.feature_extractor.time_per_frame