invoking __call__ instead of forward()

This commit is contained in:
Jong Wook Kim
2022-11-16 04:18:50 -08:00
parent 02aa851a49
commit eff383b27b
2 changed files with 6 additions and 2 deletions

View File

@@ -214,10 +214,10 @@ class Whisper(nn.Module):
)
def embed_audio(self, mel: torch.Tensor):
return self.encoder.forward(mel)
return self.encoder(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder.forward(tokens, audio_features)
return self.decoder(tokens, audio_features)
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))