Allow users to input an Iterable of token ids into initial_prompt (#306)
* Allow users to input an Iterable of token ids into initial_prompt * Need to check for String first because string is also an Iterable
This commit is contained in:
@@ -52,7 +52,7 @@ class TranscriptionOptions(NamedTuple):
|
||||
compression_ratio_threshold: Optional[float]
|
||||
condition_on_previous_text: bool
|
||||
temperatures: List[float]
|
||||
initial_prompt: Optional[str]
|
||||
initial_prompt: Optional[Union[str, Iterable[int]]]
|
||||
prefix: Optional[str]
|
||||
suppress_blank: bool
|
||||
suppress_tokens: Optional[List[int]]
|
||||
@@ -170,7 +170,7 @@ class WhisperModel:
|
||||
log_prob_threshold: Optional[float] = -1.0,
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
initial_prompt: Optional[Union[str, Iterable[int]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
suppress_blank: bool = True,
|
||||
suppress_tokens: Optional[List[int]] = [-1],
|
||||
@@ -208,7 +208,8 @@ class WhisperModel:
|
||||
as a prompt for the next window; disabling may make the text inconsistent across
|
||||
windows, but the model becomes less prone to getting stuck in a failure loop,
|
||||
such as repetition looping or timestamps going out of sync.
|
||||
initial_prompt: Optional text to provide as a prompt for the first window.
|
||||
initial_prompt: Optional text string or iterable of token ids to provide as a
|
||||
prompt for the first window.
|
||||
prefix: Optional text to provide as a prefix for the first window.
|
||||
suppress_blank: Suppress blank outputs at the beginning of the sampling.
|
||||
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
|
||||
@@ -361,9 +362,12 @@ class WhisperModel:
|
||||
prompt_reset_since = 0
|
||||
|
||||
if options.initial_prompt is not None:
|
||||
initial_prompt = " " + options.initial_prompt.strip()
|
||||
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
if isinstance(options.initial_prompt, str):
|
||||
initial_prompt = " " + options.initial_prompt.strip()
|
||||
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
else:
|
||||
all_tokens.extend(options.initial_prompt)
|
||||
|
||||
while seek < content_frames:
|
||||
time_offset = seek * self.feature_extractor.time_per_frame
|
||||
|
||||
Reference in New Issue
Block a user