From 9c8183a1790efa6f8c7e47fc9b73131c4b66e153 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 26 Sep 2022 13:54:26 -0400 Subject: [PATCH] Use PyTorch as logits transpose for ONNX support (#141) --- whisper/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisper/model.py b/whisper/model.py index 2221570..1b5890f 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -189,7 +189,7 @@ class TextDecoder(nn.Module): x = block(x, xa, mask=self.mask, kv_cache=kv_cache) x = self.ln(x) - logits = (x @ self.token_embedding.weight.to(x.dtype).T).float() + logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() return logits