nocaptions -> nospeech to match the paper figure
@@ -108,7 +108,7 @@ class DecodingResult:
     tokens: List[int] = field(default_factory=list)
     text: str = ""
     avg_logprob: float = np.nan
-    no_caption_prob: float = np.nan
+    no_speech_prob: float = np.nan
     temperature: float = np.nan
     compression_ratio: float = np.nan
 
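For orientation, `no_speech_prob` is the field downstream callers read to judge whether a segment contains speech. A minimal sketch of inspecting it after this commit, using the package's public entry points; the model size and file path below are illustrative, not part of the diff:

import whisper

model = whisper.load_model("base")                            # illustrative model size
audio = whisper.pad_or_trim(whisper.load_audio("audio.mp3"))  # illustrative path
mel = whisper.log_mel_spectrogram(audio).to(model.device)

result = whisper.decode(model, mel, whisper.DecodingOptions())
print(result.text)
print(result.no_speech_prob)  # formerly no_caption_prob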
@@ -543,9 +543,9 @@ class DecodingTask:
         suppress_tokens.extend(
             [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
         )
-        if self.tokenizer.no_captions is not None:
-            # no-captions probability is collected separately
-            suppress_tokens.append(self.tokenizer.no_captions)
+        if self.tokenizer.no_speech is not None:
+            # no-speech probability is collected separately
+            suppress_tokens.append(self.tokenizer.no_speech)
 
         return tuple(sorted(set(suppress_tokens)))
 
@@ -580,15 +580,15 @@ class DecodingTask:
         assert audio_features.shape[0] == tokens.shape[0]
         n_batch = tokens.shape[0]
         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
-        no_caption_probs = [np.nan] * n_batch
+        no_speech_probs = [np.nan] * n_batch
 
         try:
             for i in range(self.sample_len):
                 logits = self.inference.logits(tokens, audio_features)
 
-                if i == 0 and self.tokenizer.no_captions is not None:  # save no_caption_probs
+                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                     probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
-                    no_caption_probs = probs_at_sot[:, self.tokenizer.no_captions].tolist()
+                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
 
                 # now we need to consider the logits at the last token only
                 logits = logits[:, -1]
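As the hunk above shows, the renamed probability is read once, on the first decoding step, from the softmax over the logits at the <|startoftranscript|> position. A standalone sketch of that computation, with the helper name and tensor shapes invented here purely for illustration:

import torch

def collect_no_speech_probs(logits: torch.Tensor, sot_index: int, no_speech_id: int):
    # logits: (n_batch, n_prompt_tokens, n_vocab) from the decoder's first forward pass
    probs_at_sot = logits[:, sot_index].float().softmax(dim=-1)  # (n_batch, n_vocab)
    return probs_at_sot[:, no_speech_id].tolist()                # one probability per batch item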
@@ -605,7 +605,7 @@ class DecodingTask:
         finally:
             self.inference.cleanup_caching()
 
-        return tokens, sum_logprobs, no_caption_probs
+        return tokens, sum_logprobs, no_speech_probs
 
     @torch.no_grad()
     def run(self, mel: Tensor) -> List[DecodingResult]:
@@ -629,12 +629,12 @@ class DecodingTask:
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
 
         # call the main sampling loop
-        tokens, sum_logprobs, no_caption_probs = self._main_loop(audio_features, tokens)
+        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
 
         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
         audio_features = audio_features[:: self.n_group]
-        no_caption_probs = no_caption_probs[:: self.n_group]
-        assert audio_features.shape[0] == len(no_caption_probs) == n_audio
+        no_speech_probs = no_speech_probs[:: self.n_group]
+        assert audio_features.shape[0] == len(no_speech_probs) == n_audio
 
         tokens = tokens.reshape(n_audio, self.n_group, -1)
         sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
@@ -653,7 +653,7 @@ class DecodingTask:
         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
         avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
 
-        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_caption_probs)
+        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
         if len(set(map(len, fields))) != 1:
             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
 
@@ -664,11 +664,11 @@ class DecodingTask:
                 tokens=tokens,
                 text=text,
                 avg_logprob=avg_logprob,
-                no_caption_prob=no_caption_prob,
+                no_speech_prob=no_speech_prob,
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
             )
-            for text, language, tokens, features, avg_logprob, no_caption_prob in zip(*fields)
+            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
         ]
 
 
@@ -178,8 +178,8 @@ class Tokenizer:
 
     @property
     @lru_cache()
-    def no_captions(self) -> int:
-        return self._get_single_token_id("<|nocaptions|>")
+    def no_speech(self) -> int:
+        return self._get_single_token_id("<|nospeech|>")
 
     @property
     @lru_cache()
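A small sketch of querying the renamed property from user code; `get_tokenizer` is the module's usual constructor, and the multilingual flag here is just an example:

from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=True)
print(tokenizer.no_speech)  # token id of "<|nospeech|>", formerly tokenizer.no_captions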
@@ -283,7 +283,7 @@ def build_tokenizer(name: str = "gpt2"):
         "<|transcribe|>",
         "<|startoflm|>",
         "<|startofprev|>",
-        "<|nocaptions|>",
+        "<|nospeech|>",
         "<|notimestamps|>",
     ]
 
@@ -23,7 +23,7 @@ def transcribe(
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
     logprob_threshold: Optional[float] = -1.0,
-    no_captions_threshold: Optional[float] = 0.6,
+    no_speech_threshold: Optional[float] = 0.6,
     **decode_options,
 ):
     """
@@ -50,8 +50,8 @@ def transcribe(
     logprob_threshold: float
         If the average log probability over sampled tokens is below this value, treat as failed
 
-    no_captions_threshold: float
-        If the no_captions probability is higher than this value AND the average log probability
+    no_speech_threshold: float
+        If the no_speech probability is higher than this value AND the average log probability
         over sampled tokens is below `logprob_threshold`, consider the segment as silent
 
     decode_options: dict
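The docstring describes a two-part rule: a segment is only treated as silent when the no-speech probability is high and the decoded text also looks unreliable. A hedged sketch of that rule as a standalone predicate (the helper name is hypothetical; the real logic lives inside transcribe() and appears in a later hunk):

def is_silent_segment(no_speech_prob: float, avg_logprob: float,
                      no_speech_threshold: float = 0.6, logprob_threshold: float = -1.0) -> bool:
    if no_speech_threshold is None:
        return False
    should_skip = no_speech_prob > no_speech_threshold
    if logprob_threshold is not None and avg_logprob > logprob_threshold:
        # confident text overrides the no-speech signal
        should_skip = False
    return should_skip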
@@ -148,7 +148,7 @@ def transcribe(
                     "temperature": result.temperature,
                     "avg_logprob": result.avg_logprob,
                     "compression_ratio": result.compression_ratio,
-                    "no_caption_prob": result.no_caption_prob,
+                    "no_speech_prob": result.no_speech_prob,
                 }
             )
         if verbose:
@@ -163,11 +163,11 @@ def transcribe(
         result = decode_with_fallback(segment)[0]
         tokens = torch.tensor(result.tokens)
 
-        if no_captions_threshold is not None:
+        if no_speech_threshold is not None:
             # no voice activity check
-            should_skip = result.no_caption_prob > no_captions_threshold
+            should_skip = result.no_speech_prob > no_speech_threshold
             if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
-                # don't skip if the logprob is high enough, despite the no_captions_prob
+                # don't skip if the logprob is high enough, despite the no_speech_prob
                 should_skip = False
 
         if should_skip:
@@ -249,7 +249,7 @@ def cli():
     parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
     parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
-    parser.add_argument("--no_caption_threshold", type=optional_float, default=0.6, help="if the probability of the <|nocaptions|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
+    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
 
     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
@@ -261,12 +261,8 @@ def cli():
         warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
         args["language"] = "en"
 
-    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
-    compression_ratio_threshold = args.pop("compression_ratio_threshold")
-    logprob_threshold = args.pop("logprob_threshold")
-    no_caption_threshold = args.pop("no_caption_threshold")
-
     temperature = args.pop("temperature")
+    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
     if temperature_increment_on_fallback is not None:
         temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
     else:
@@ -276,15 +272,7 @@ def cli():
     model = load_model(model_name, device=device)
 
     for audio_path in args.pop("audio"):
-        result = transcribe(
-            model,
-            audio_path,
-            temperature=temperature,
-            compression_ratio_threshold=compression_ratio_threshold,
-            logprob_threshold=logprob_threshold,
-            no_captions_threshold=no_caption_threshold,
-            **args,
-        )
+        result = transcribe(model, audio_path, temperature=temperature, **args)
 
         audio_basename = os.path.basename(audio_path)
 
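With the call simplified, the threshold options are no longer popped out of `args`; they stay there and flow into transcribe() as ordinary keyword arguments. A hedged usage sketch of the renamed option, equivalent on the command line and in Python (paths and values are illustrative):

# command line (flag renamed from --no_caption_threshold):
#   whisper audio.mp3 --model base --no_speech_threshold 0.6 --logprob_threshold -1.0

import whisper

model = whisper.load_model("base")
result = whisper.transcribe(
    model,
    "audio.mp3",
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    no_speech_threshold=0.6,   # renamed from no_captions_threshold
    logprob_threshold=-1.0,
)
print(result["text"])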