invoking __call__ instead of forward()
This commit is contained in:
@@ -214,10 +214,10 @@ class Whisper(nn.Module):
|
||||
)
|
||||
|
||||
def embed_audio(self, mel: torch.Tensor):
|
||||
return self.encoder.forward(mel)
|
||||
return self.encoder(mel)
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||
return self.decoder.forward(tokens, audio_features)
|
||||
return self.decoder(tokens, audio_features)
|
||||
|
||||
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
return self.decoder(tokens, self.encoder(mel))
|
||||
|
||||
Reference in New Issue
Block a user