Allow users to input an Iterable of token ids into initial_prompt (#306)
* Allow users to input an Iterable of token ids into initial_prompt. * Need to check for str first, because a string is also an Iterable.
This commit is contained in:
@@ -52,7 +52,7 @@ class TranscriptionOptions(NamedTuple):
|
|||||||
compression_ratio_threshold: Optional[float]
|
compression_ratio_threshold: Optional[float]
|
||||||
condition_on_previous_text: bool
|
condition_on_previous_text: bool
|
||||||
temperatures: List[float]
|
temperatures: List[float]
|
||||||
initial_prompt: Optional[str]
|
initial_prompt: Optional[Union[str, Iterable[int]]]
|
||||||
prefix: Optional[str]
|
prefix: Optional[str]
|
||||||
suppress_blank: bool
|
suppress_blank: bool
|
||||||
suppress_tokens: Optional[List[int]]
|
suppress_tokens: Optional[List[int]]
|
||||||
@@ -170,7 +170,7 @@ class WhisperModel:
|
|||||||
log_prob_threshold: Optional[float] = -1.0,
|
log_prob_threshold: Optional[float] = -1.0,
|
||||||
no_speech_threshold: Optional[float] = 0.6,
|
no_speech_threshold: Optional[float] = 0.6,
|
||||||
condition_on_previous_text: bool = True,
|
condition_on_previous_text: bool = True,
|
||||||
initial_prompt: Optional[str] = None,
|
initial_prompt: Optional[Union[str, Iterable[int]]] = None,
|
||||||
prefix: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
suppress_blank: bool = True,
|
suppress_blank: bool = True,
|
||||||
suppress_tokens: Optional[List[int]] = [-1],
|
suppress_tokens: Optional[List[int]] = [-1],
|
||||||
@@ -208,7 +208,8 @@ class WhisperModel:
|
|||||||
as a prompt for the next window; disabling may make the text inconsistent across
|
as a prompt for the next window; disabling may make the text inconsistent across
|
||||||
windows, but the model becomes less prone to getting stuck in a failure loop,
|
windows, but the model becomes less prone to getting stuck in a failure loop,
|
||||||
such as repetition looping or timestamps going out of sync.
|
such as repetition looping or timestamps going out of sync.
|
||||||
initial_prompt: Optional text to provide as a prompt for the first window.
|
initial_prompt: Optional text string or iterable of token ids to provide as a
|
||||||
|
prompt for the first window.
|
||||||
prefix: Optional text to provide as a prefix for the first window.
|
prefix: Optional text to provide as a prefix for the first window.
|
||||||
suppress_blank: Suppress blank outputs at the beginning of the sampling.
|
suppress_blank: Suppress blank outputs at the beginning of the sampling.
|
||||||
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
|
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
|
||||||
@@ -361,9 +362,12 @@ class WhisperModel:
|
|||||||
prompt_reset_since = 0
|
prompt_reset_since = 0
|
||||||
|
|
||||||
if options.initial_prompt is not None:
|
if options.initial_prompt is not None:
|
||||||
initial_prompt = " " + options.initial_prompt.strip()
|
if isinstance(options.initial_prompt, str):
|
||||||
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
initial_prompt = " " + options.initial_prompt.strip()
|
||||||
all_tokens.extend(initial_prompt_tokens)
|
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
||||||
|
all_tokens.extend(initial_prompt_tokens)
|
||||||
|
else:
|
||||||
|
all_tokens.extend(options.initial_prompt)
|
||||||
|
|
||||||
while seek < content_frames:
|
while seek < content_frames:
|
||||||
time_offset = seek * self.feature_extractor.time_per_frame
|
time_offset = seek * self.feature_extractor.time_per_frame
|
||||||
|
|||||||
Reference in New Issue
Block a user