Feature/add hotwords (#731)

* add hotword params

---------

Co-authored-by: jax <jax_builder@gamil.com>
jax
2024-05-04 16:11:52 +08:00
committed by GitHub
parent 46080e584e
commit 847fec4492


@@ -69,6 +69,7 @@ class TranscriptionOptions(NamedTuple):
     max_new_tokens: Optional[int]
     clip_timestamps: Union[str, List[float]]
     hallucination_silence_threshold: Optional[float]
+    hotwords: Optional[str]

 class TranscriptionInfo(NamedTuple):
@@ -220,6 +221,7 @@ class WhisperModel:
         chunk_length: Optional[int] = None,
         clip_timestamps: Union[str, List[float]] = "0",
         hallucination_silence_threshold: Optional[float] = None,
+        hotwords: Optional[str] = None,
         language_detection_threshold: Optional[float] = None,
         language_detection_segments: int = 1,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
@@ -284,10 +286,11 @@ class WhisperModel:
           hallucination_silence_threshold: Optional[float]
             When word_timestamps is True, skip silent periods longer than this threshold
             (in seconds) when a possible hallucination is detected
+          hotwords:
+            Optional hotwords/hint phrases to add to the prompt. Has no effect if a prefix is set.
           language_detection_threshold: If the maximum probability of the language tokens is higher
             than this value, the language is detected.
           language_detection_segments: Number of segments to consider for the language detection.

         Returns:
           A tuple with:
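For reference, a minimal usage sketch of the new parameter from the caller's side; the model size ("small"), audio filename, and hotword strings below are placeholder assumptions for illustration, not part of this change:

from faster_whisper import WhisperModel

# Load a Whisper model; any supported size works, "small" is only an example.
model = WhisperModel("small")

# Bias recognition toward domain-specific terms by passing them as hotwords.
# They only affect the prompt when no `prefix` is given.
segments, info = model.transcribe(
    "meeting.wav",
    hotwords="CTranslate2 faster-whisper Kubernetes",
)

for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")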
@@ -441,6 +444,7 @@ class WhisperModel:
             max_new_tokens=max_new_tokens,
             clip_timestamps=clip_timestamps,
             hallucination_silence_threshold=hallucination_silence_threshold,
+            hotwords=hotwords,
         )

         segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -547,6 +551,7 @@ class WhisperModel:
                 previous_tokens,
                 without_timestamps=options.without_timestamps,
                 prefix=options.prefix if seek == 0 else None,
+                hotwords=options.hotwords,
             )

             if seek > 0 or encoder_output is None:
@@ -939,11 +944,18 @@ class WhisperModel:
         previous_tokens: List[int],
         without_timestamps: bool = False,
         prefix: Optional[str] = None,
+        hotwords: Optional[str] = None,
     ) -> List[int]:
         prompt = []

-        if previous_tokens:
+        if previous_tokens or (hotwords and not prefix):
             prompt.append(tokenizer.sot_prev)
-            prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
+            if hotwords and not prefix:
+                hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
+                if len(hotwords_tokens) >= self.max_length // 2:
+                    hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
+                prompt.extend(hotwords_tokens)
+            if previous_tokens:
+                prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])

         prompt.extend(tokenizer.sot_sequence)
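To make the resulting prompt layout concrete, here is a rough sketch of the token sequence get_prompt builds when hotwords are supplied and no prefix is set; the token IDs below are made-up placeholders, not real Whisper vocabulary entries:

# Placeholder IDs for illustration only; real values come from the Whisper tokenizer.
SOT_PREV = 90001                       # stand-in for tokenizer.sot_prev
SOT_SEQUENCE = [90002, 90003, 90004]   # stand-in for tokenizer.sot_sequence

hotwords_tokens = [101, 102, 103]      # pretend tokenizer.encode(" CTranslate2 Kubernetes")
previous_tokens = [201, 202]           # context carried over from the previous window

# Layout: [sot_prev] + (truncated) hotwords + (truncated) previous context + sot_sequence.
# Both the hotwords and the previous context are capped at max_length // 2 - 1 tokens.
prompt = [SOT_PREV] + hotwords_tokens + previous_tokens + SOT_SEQUENCE
print(prompt)  # [90001, 101, 102, 103, 201, 202, 90002, 90003, 90004]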