MultiHeadAttention to return qk as well
@@ -82,8 +82,8 @@ class MultiHeadAttention(nn.Module):
             k = kv_cache[self.key]
             v = kv_cache[self.value]
 
-        wv = self.qkv_attention(q, k, v, mask)
-        return self.out(wv)
+        wv, qk = self.qkv_attention(q, k, v, mask)
+        return self.out(wv), qk
 
     def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
         n_batch, n_ctx, n_state = q.shape
@@ -95,9 +95,10 @@ class MultiHeadAttention(nn.Module):
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
+        qk = qk.float()
 
-        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
-        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+        w = F.softmax(qk, dim=-1).to(q.dtype)
+        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
 
 
 class ResidualAttentionBlock(nn.Module):
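
With this change, MultiHeadAttention.forward returns a tuple: the projected attention output plus the detached, pre-softmax attention scores computed in qkv_attention. A minimal sketch of calling the module directly; the import path, constructor arguments, and the 512-dim / 8-head / 100-frame sizes are illustrative assumptions, not part of this diff:

import torch

from whisper.model import MultiHeadAttention  # assumed import path

attn = MultiHeadAttention(n_state=512, n_head=8)  # sizes are illustrative
x = torch.randn(1, 100, 512)                      # (batch, ctx, state)

out, qk = attn(x)  # previously: out = attn(x)
print(out.shape)   # torch.Size([1, 100, 512]) -- projected output, unchanged
print(qk.shape)    # torch.Size([1, 8, 100, 100]) -- per-head scores before softmax, detached
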
@@ -121,9 +122,9 @@ class ResidualAttentionBlock(nn.Module):
         mask: Optional[Tensor] = None,
         kv_cache: Optional[dict] = None,
     ):
-        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
+        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
         if self.cross_attn:
-            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
         x = x + self.mlp(self.mlp_ln(x))
         return x
 
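
ResidualAttentionBlock now indexes [0] to keep only the attention output, so the block's own return value is unchanged. If a caller still wants the qk scores that the block discards, a generic PyTorch forward hook on the inner attention module is one way to capture them; a sketch under the same illustrative assumptions as above:

import torch

from whisper.model import ResidualAttentionBlock  # assumed import path

block = ResidualAttentionBlock(n_state=512, n_head=8)
scores = []

# block.attn now returns (output, qk), so the hook can record the second
# element that the block itself drops via [0].
handle = block.attn.register_forward_hook(
    lambda module, inputs, outputs: scores.append(outputs[1])
)

x = torch.randn(1, 100, 512)
y = block(x)
handle.remove()

print(y.shape)          # torch.Size([1, 100, 512]) -- block output unchanged
print(scores[0].shape)  # torch.Size([1, 8, 100, 100]) -- captured attention scores
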