Feature/add hotwords (#731)
* add hotword params

---------

Co-authored-by: jax <jax_builder@gamil.com>
@@ -69,6 +69,7 @@ class TranscriptionOptions(NamedTuple):
     max_new_tokens: Optional[int]
     clip_timestamps: Union[str, List[float]]
     hallucination_silence_threshold: Optional[float]
+    hotwords: Optional[str]


 class TranscriptionInfo(NamedTuple):
@@ -220,6 +221,7 @@ class WhisperModel:
         chunk_length: Optional[int] = None,
         clip_timestamps: Union[str, List[float]] = "0",
         hallucination_silence_threshold: Optional[float] = None,
+        hotwords: Optional[str] = None,
         language_detection_threshold: Optional[float] = None,
         language_detection_segments: int = 1,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
@@ -284,10 +286,11 @@ class WhisperModel:
           hallucination_silence_threshold: Optional[float]
             When word_timestamps is True, skip silent periods longer than this threshold
             (in seconds) when a possible hallucination is detected
+          hotwords: Optional text to provide as hotwords/hint phrases to the model.
+            Has no effect if prefix is set.
           language_detection_threshold: If the maximum probability of the language tokens is higher
             than this value, the language is detected.
           language_detection_segments: Number of segments to consider for the language detection.

         Returns:
           A tuple with:

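For context, a minimal usage sketch of the new keyword argument. This is not part of the PR: the model size, device/compute_type values, audio path, and hotword string are illustrative placeholders; only the hotwords parameter itself comes from this change.

from faster_whisper import WhisperModel

# Placeholder model, audio file, and hotwords, purely to illustrate the new argument.
model = WhisperModel("small", device="cpu", compute_type="int8")
segments, info = model.transcribe(
    "meeting.wav",
    hotwords="Kubernetes gRPC Prometheus",  # hint phrases injected into the decoding prompt
    # prefix="...",  # if a prefix were set, the hotwords would be ignored (see docstring above)
)
for segment in segments:
    print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")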
@@ -441,6 +444,7 @@ class WhisperModel:
             max_new_tokens=max_new_tokens,
             clip_timestamps=clip_timestamps,
             hallucination_silence_threshold=hallucination_silence_threshold,
+            hotwords=hotwords,
         )

         segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -547,6 +551,7 @@ class WhisperModel:
                 previous_tokens,
                 without_timestamps=options.without_timestamps,
                 prefix=options.prefix if seek == 0 else None,
+                hotwords=options.hotwords,
             )

             if seek > 0 or encoder_output is None:
@@ -939,11 +944,18 @@ class WhisperModel:
         previous_tokens: List[int],
         without_timestamps: bool = False,
         prefix: Optional[str] = None,
+        hotwords: Optional[str] = None,
     ) -> List[int]:
         prompt = []

-        if previous_tokens:
+        if previous_tokens or (hotwords and not prefix):
             prompt.append(tokenizer.sot_prev)
-            prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
+            if hotwords and not prefix:
+                hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
+                if len(hotwords_tokens) >= self.max_length // 2:
+                    hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
+                prompt.extend(hotwords_tokens)
+            if previous_tokens:
+                prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])

         prompt.extend(tokenizer.sot_sequence)
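To make the resulting prompt layout concrete, here is a small self-contained sketch. It is not library code: DummyTokenizer and its token IDs are invented, and max_length is passed as a parameter instead of being read from self. Only the ordering mirrors the diff: sot_prev, then the encoded hotwords, then the tail of the previous context, then the SOT sequence.

from typing import List, Optional


class DummyTokenizer:
    """Placeholder stand-in for the real tokenizer; the IDs below are made up."""
    sot_prev = 50361
    sot_sequence = [50258, 50259, 50359]

    def encode(self, text: str) -> List[int]:
        # Fake BPE: one pseudo-token per whitespace-separated word.
        return [hash(word) % 1000 for word in text.split()]


def build_prompt(
    tokenizer: DummyTokenizer,
    previous_tokens: List[int],
    prefix: Optional[str] = None,
    hotwords: Optional[str] = None,
    max_length: int = 448,
) -> List[int]:
    # Mirrors the logic added in this diff.
    prompt = []
    if previous_tokens or (hotwords and not prefix):
        prompt.append(tokenizer.sot_prev)
        if hotwords and not prefix:
            hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
            # Cap the hotword block at half of the model context, as in the diff.
            if len(hotwords_tokens) >= max_length // 2:
                hotwords_tokens = hotwords_tokens[: max_length // 2 - 1]
            prompt.extend(hotwords_tokens)
        if previous_tokens:
            prompt.extend(previous_tokens[-(max_length // 2 - 1) :])
    prompt.extend(tokenizer.sot_sequence)
    return prompt


# Hotwords appear after sot_prev and before the previous-context tail:
print(build_prompt(DummyTokenizer(), previous_tokens=[7, 8, 9], hotwords="Kubernetes gRPC"))
# With a prefix supplied, the hotwords are skipped entirely:
print(build_prompt(DummyTokenizer(), previous_tokens=[7, 8, 9], prefix="Hello", hotwords="ignored"))

Note from the hunks above that prefix is only applied on the first window (seek == 0), while options.hotwords is passed to get_prompt on every call, so the hint phrases influence every decoding window.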