Prepend prefix tokens with the initial timestamp token (#358)
This commit is contained in:
@@ -686,6 +686,8 @@ class WhisperModel:
|
||||
prefix_tokens = tokenizer.encode(" " + prefix.strip())
|
||||
if len(prefix_tokens) >= self.max_length // 2:
|
||||
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
|
||||
if not without_timestamps:
|
||||
prompt.append(tokenizer.timestamp_begin)
|
||||
prompt.extend(prefix_tokens)
|
||||
|
||||
return prompt
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
av==10.*
|
||||
-ctranslate2>=3.10,<4
|
||||
+ctranslate2>=3.17,<4
|
||||
huggingface_hub>=0.13
|
||||
tokenizers==0.13.*
|
||||
onnxruntime>=1.14,<2
|
||||
|
||||
@@ -34,6 +34,24 @@ def test_transcribe(jfk_path):
|
||||
assert segment.end == segment.words[-1].end
|
||||
|
||||
|
||||
def test_prefix_with_timestamps(jfk_path):
    """Check transcription output when a text prefix is supplied.

    Runs the "tiny" model on the file at ``jfk_path`` with a prefix and
    expects exactly one segment whose text begins with that prefix, whose
    start is 0 and whose end lies strictly between 10 and 11.
    """
    expected_text = (
        " And so my fellow Americans ask not what your country can do for you, "
        "ask what you can do for your country."
    )

    whisper = WhisperModel("tiny")
    result, _ = whisper.transcribe(jfk_path, prefix="And so my fellow Americans")
    # transcribe() output is materialised with list() before being inspected.
    collected = list(result)

    assert len(collected) == 1

    only_segment = collected[0]

    assert only_segment.text == expected_text
    assert only_segment.start == 0
    assert 10 < only_segment.end < 11
|
||||
|
||||
|
||||
def test_vad(jfk_path):
|
||||
model = WhisperModel("tiny")
|
||||
segments, info = model.transcribe(
|
||||
|
||||
Reference in New Issue
Block a user