nocaptions -> nospeech to match the paper figure

This commit is contained in:
Jong Wook Kim
2022-09-23 15:45:32 +09:00
parent 61989529b7
commit 15ab548263
3 changed files with 27 additions and 39 deletions

View File

@@ -108,7 +108,7 @@ class DecodingResult:
tokens: List[int] = field(default_factory=list) tokens: List[int] = field(default_factory=list)
text: str = "" text: str = ""
avg_logprob: float = np.nan avg_logprob: float = np.nan
no_caption_prob: float = np.nan no_speech_prob: float = np.nan
temperature: float = np.nan temperature: float = np.nan
compression_ratio: float = np.nan compression_ratio: float = np.nan
@@ -543,9 +543,9 @@ class DecodingTask:
suppress_tokens.extend( suppress_tokens.extend(
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm] [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
) )
if self.tokenizer.no_captions is not None: if self.tokenizer.no_speech is not None:
# no-captions probability is collected separately # no-speech probability is collected separately
suppress_tokens.append(self.tokenizer.no_captions) suppress_tokens.append(self.tokenizer.no_speech)
return tuple(sorted(set(suppress_tokens))) return tuple(sorted(set(suppress_tokens)))
@@ -580,15 +580,15 @@ class DecodingTask:
assert audio_features.shape[0] == tokens.shape[0] assert audio_features.shape[0] == tokens.shape[0]
n_batch = tokens.shape[0] n_batch = tokens.shape[0]
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device) sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
no_caption_probs = [np.nan] * n_batch no_speech_probs = [np.nan] * n_batch
try: try:
for i in range(self.sample_len): for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features) logits = self.inference.logits(tokens, audio_features)
if i == 0 and self.tokenizer.no_captions is not None: # save no_caption_probs if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1) probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_caption_probs = probs_at_sot[:, self.tokenizer.no_captions].tolist() no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
# now we need to consider the logits at the last token only # now we need to consider the logits at the last token only
logits = logits[:, -1] logits = logits[:, -1]
@@ -605,7 +605,7 @@ class DecodingTask:
finally: finally:
self.inference.cleanup_caching() self.inference.cleanup_caching()
return tokens, sum_logprobs, no_caption_probs return tokens, sum_logprobs, no_speech_probs
@torch.no_grad() @torch.no_grad()
def run(self, mel: Tensor) -> List[DecodingResult]: def run(self, mel: Tensor) -> List[DecodingResult]:
@@ -629,12 +629,12 @@ class DecodingTask:
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
# call the main sampling loop # call the main sampling loop
tokens, sum_logprobs, no_caption_probs = self._main_loop(audio_features, tokens) tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
# reshape the tensors to have (n_audio, n_group) as the first two dimensions # reshape the tensors to have (n_audio, n_group) as the first two dimensions
audio_features = audio_features[:: self.n_group] audio_features = audio_features[:: self.n_group]
no_caption_probs = no_caption_probs[:: self.n_group] no_speech_probs = no_speech_probs[:: self.n_group]
assert audio_features.shape[0] == len(no_caption_probs) == n_audio assert audio_features.shape[0] == len(no_speech_probs) == n_audio
tokens = tokens.reshape(n_audio, self.n_group, -1) tokens = tokens.reshape(n_audio, self.n_group, -1)
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group) sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
@@ -653,7 +653,7 @@ class DecodingTask:
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)] avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_caption_probs) fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
if len(set(map(len, fields))) != 1: if len(set(map(len, fields))) != 1:
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}") raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
@@ -664,11 +664,11 @@ class DecodingTask:
tokens=tokens, tokens=tokens,
text=text, text=text,
avg_logprob=avg_logprob, avg_logprob=avg_logprob,
no_caption_prob=no_caption_prob, no_speech_prob=no_speech_prob,
temperature=self.options.temperature, temperature=self.options.temperature,
compression_ratio=compression_ratio(text), compression_ratio=compression_ratio(text),
) )
for text, language, tokens, features, avg_logprob, no_caption_prob in zip(*fields) for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
] ]

View File

@@ -178,8 +178,8 @@ class Tokenizer:
@property @property
@lru_cache() @lru_cache()
def no_captions(self) -> int: def no_speech(self) -> int:
return self._get_single_token_id("<|nocaptions|>") return self._get_single_token_id("<|nospeech|>")
@property @property
@lru_cache() @lru_cache()
@@ -283,7 +283,7 @@ def build_tokenizer(name: str = "gpt2"):
"<|transcribe|>", "<|transcribe|>",
"<|startoflm|>", "<|startoflm|>",
"<|startofprev|>", "<|startofprev|>",
"<|nocaptions|>", "<|nospeech|>",
"<|notimestamps|>", "<|notimestamps|>",
] ]

View File

@@ -23,7 +23,7 @@ def transcribe(
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4, compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0, logprob_threshold: Optional[float] = -1.0,
no_captions_threshold: Optional[float] = 0.6, no_speech_threshold: Optional[float] = 0.6,
**decode_options, **decode_options,
): ):
""" """
@@ -50,8 +50,8 @@ def transcribe(
logprob_threshold: float logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed If the average log probability over sampled tokens is below this value, treat as failed
no_captions_threshold: float no_speech_threshold: float
If the no_captions probability is higher than this value AND the average log probability If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent over sampled tokens is below `logprob_threshold`, consider the segment as silent
decode_options: dict decode_options: dict
@@ -148,7 +148,7 @@ def transcribe(
"temperature": result.temperature, "temperature": result.temperature,
"avg_logprob": result.avg_logprob, "avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio, "compression_ratio": result.compression_ratio,
"no_caption_prob": result.no_caption_prob, "no_speech_prob": result.no_speech_prob,
} }
) )
if verbose: if verbose:
@@ -163,11 +163,11 @@ def transcribe(
result = decode_with_fallback(segment)[0] result = decode_with_fallback(segment)[0]
tokens = torch.tensor(result.tokens) tokens = torch.tensor(result.tokens)
if no_captions_threshold is not None: if no_speech_threshold is not None:
# no voice activity check # no voice activity check
should_skip = result.no_caption_prob > no_captions_threshold should_skip = result.no_speech_prob > no_speech_threshold
if logprob_threshold is not None and result.avg_logprob > logprob_threshold: if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
# don't skip if the logprob is high enough, despite the no_captions_prob # don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False should_skip = False
if should_skip: if should_skip:
@@ -249,7 +249,7 @@ def cli():
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_caption_threshold", type=optional_float, default=0.6, help="if the probability of the <|nocaptions|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
args = parser.parse_args().__dict__ args = parser.parse_args().__dict__
model_name: str = args.pop("model") model_name: str = args.pop("model")
@@ -261,12 +261,8 @@ def cli():
warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.") warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.")
args["language"] = "en" args["language"] = "en"
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
compression_ratio_threshold = args.pop("compression_ratio_threshold")
logprob_threshold = args.pop("logprob_threshold")
no_caption_threshold = args.pop("no_caption_threshold")
temperature = args.pop("temperature") temperature = args.pop("temperature")
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
if temperature_increment_on_fallback is not None: if temperature_increment_on_fallback is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback)) temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
else: else:
@@ -276,15 +272,7 @@ def cli():
model = load_model(model_name, device=device) model = load_model(model_name, device=device)
for audio_path in args.pop("audio"): for audio_path in args.pop("audio"):
result = transcribe( result = transcribe(model, audio_path, temperature=temperature, **args)
model,
audio_path,
temperature=temperature,
compression_ratio_threshold=compression_ratio_threshold,
logprob_threshold=logprob_threshold,
no_captions_threshold=no_caption_threshold,
**args,
)
audio_basename = os.path.basename(audio_path) audio_basename = os.path.basename(audio_path)