saving the qk matrix in the attention module for convenience
This commit is contained in:
@@ -62,6 +62,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
self.key = Linear(n_state, n_state, bias=False)
|
self.key = Linear(n_state, n_state, bias=False)
|
||||||
self.value = Linear(n_state, n_state)
|
self.value = Linear(n_state, n_state)
|
||||||
self.out = Linear(n_state, n_state)
|
self.out = Linear(n_state, n_state)
|
||||||
|
self.last_qk = None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -96,6 +97,8 @@ class MultiHeadAttention(nn.Module):
|
|||||||
if mask is not None:
|
if mask is not None:
|
||||||
qk = qk + mask[:n_ctx, :n_ctx]
|
qk = qk + mask[:n_ctx, :n_ctx]
|
||||||
|
|
||||||
|
self.last_qk = qk.detach()
|
||||||
|
|
||||||
w = F.softmax(qk.float(), dim=-1).to(q.dtype)
|
w = F.softmax(qk.float(), dim=-1).to(q.dtype)
|
||||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user