Allow users to input an Iterable of token ids into initial_prompt (#306)

* Allow users to input an Iterable of token ids into initial_prompt

* Need to check for str first, because a str is itself an Iterable
This commit is contained in:
FlippFuzz
2023-06-21 20:46:20 +08:00
committed by GitHub
parent efc4f61d85
commit fee52c9229

View File

@@ -52,7 +52,7 @@ class TranscriptionOptions(NamedTuple):
     compression_ratio_threshold: Optional[float]
     condition_on_previous_text: bool
     temperatures: List[float]
-    initial_prompt: Optional[str]
+    initial_prompt: Optional[Union[str, Iterable[int]]]
     prefix: Optional[str]
     suppress_blank: bool
     suppress_tokens: Optional[List[int]]
@@ -170,7 +170,7 @@ class WhisperModel:
         log_prob_threshold: Optional[float] = -1.0,
         no_speech_threshold: Optional[float] = 0.6,
         condition_on_previous_text: bool = True,
-        initial_prompt: Optional[str] = None,
+        initial_prompt: Optional[Union[str, Iterable[int]]] = None,
         prefix: Optional[str] = None,
         suppress_blank: bool = True,
         suppress_tokens: Optional[List[int]] = [-1],
@@ -208,7 +208,8 @@ class WhisperModel:
             as a prompt for the next window; disabling may make the text inconsistent across
             windows, but the model becomes less prone to getting stuck in a failure loop,
             such as repetition looping or timestamps going out of sync.
-          initial_prompt: Optional text to provide as a prompt for the first window.
+          initial_prompt: Optional text string or iterable of token ids to provide as a
+            prompt for the first window.
           prefix: Optional text to provide as a prefix for the first window.
           suppress_blank: Suppress blank outputs at the beginning of the sampling.
           suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
@@ -361,9 +362,12 @@ class WhisperModel:
         prompt_reset_since = 0
 
         if options.initial_prompt is not None:
-            initial_prompt = " " + options.initial_prompt.strip()
-            initial_prompt_tokens = tokenizer.encode(initial_prompt)
-            all_tokens.extend(initial_prompt_tokens)
+            if isinstance(options.initial_prompt, str):
+                initial_prompt = " " + options.initial_prompt.strip()
+                initial_prompt_tokens = tokenizer.encode(initial_prompt)
+                all_tokens.extend(initial_prompt_tokens)
+            else:
+                all_tokens.extend(options.initial_prompt)
 
         while seek < content_frames:
             time_offset = seek * self.feature_extractor.time_per_frame