Feature/add hotwords (#731)

* add hotword params

---------

Co-authored-by: jax <jax_builder@gamil.com>
This commit is contained in:
jax
2024-05-04 16:11:52 +08:00
committed by GitHub
parent 46080e584e
commit 847fec4492

View File

@@ -69,6 +69,7 @@ class TranscriptionOptions(NamedTuple):
max_new_tokens: Optional[int]
clip_timestamps: Union[str, List[float]]
hallucination_silence_threshold: Optional[float]
hotwords: Optional[str]
class TranscriptionInfo(NamedTuple):
@@ -220,6 +221,7 @@ class WhisperModel:
chunk_length: Optional[int] = None,
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
hotwords: Optional[str] = None,
language_detection_threshold: Optional[float] = None,
language_detection_segments: int = 1,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
@@ -284,10 +286,11 @@ class WhisperModel:
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected
hotwords: Optional[str]
Hint words/phrases that are encoded and added to the decoding prompt as context.
Ignored whenever a prefix is applied.
language_detection_threshold: If the maximum probability of the language tokens is higher
than this value, the language is detected.
language_detection_segments: Number of segments to consider for the language detection.
Returns:
A tuple with:
@@ -441,6 +444,7 @@ class WhisperModel:
max_new_tokens=max_new_tokens,
clip_timestamps=clip_timestamps,
hallucination_silence_threshold=hallucination_silence_threshold,
hotwords=hotwords,
)
segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -547,6 +551,7 @@ class WhisperModel:
previous_tokens,
without_timestamps=options.without_timestamps,
prefix=options.prefix if seek == 0 else None,
hotwords=options.hotwords,
)
if seek > 0 or encoder_output is None:
@@ -939,12 +944,19 @@ class WhisperModel:
previous_tokens: List[int],
without_timestamps: bool = False,
prefix: Optional[str] = None,
hotwords: Optional[str] = None,
) -> List[int]:
prompt = []
if previous_tokens:
if previous_tokens or (hotwords and not prefix):
prompt.append(tokenizer.sot_prev)
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
if hotwords and not prefix:
hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
if len(hotwords_tokens) >= self.max_length // 2:
hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
prompt.extend(hotwords_tokens)
if previous_tokens:
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
prompt.extend(tokenizer.sot_sequence)