[Bugfix] Fix RMSNorm forward in InternViT attention qk_layernorm (#6992)
parent 7e0861bd0b
commit 2dd34371a6
@@ -113,10 +113,10 @@ class InternAttention(nn.Module):
         if self.qk_normalization:
             B_, H_, N_, D_ = q.shape
-            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(
-                B_, N_, H_, D_).transpose(1, 2)
-            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(
-                B_, N_, H_, D_).transpose(1, 2)
+            q = self.q_norm.forward_native(q.transpose(1, 2).flatten(
+                -2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+            k = self.k_norm.forward_native(k.transpose(1, 2).flatten(
+                -2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
 
         x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
         x = x.transpose(1, 2).reshape(B, N, C)
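For context, the patch calls RMSNorm's explicit forward_native method on the flattened per-head q/k activations instead of invoking the layer directly. Below is a minimal, self-contained sketch of the tensor reshaping this code path performs; it uses a plain-PyTorch RMSNorm stand-in (SimpleRMSNorm) rather than vLLM's RMSNorm layer, and the class name and tensor sizes are illustrative assumptions, not repository code.

import torch
import torch.nn.functional as F


class SimpleRMSNorm(torch.nn.Module):
    # Stand-in for vLLM's RMSNorm; forward_native here mirrors a plain
    # RMS normalization over the last dimension with a learned scale.

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + self.eps) * self.weight


# Illustrative sizes: batch B, heads H, tokens N, head dim D (assumed values).
B, H, N, D = 2, 16, 256, 64
q = torch.randn(B, H, N, D)
k = torch.randn(B, H, N, D)
v = torch.randn(B, H, N, D)

q_norm = SimpleRMSNorm(H * D)
k_norm = SimpleRMSNorm(H * D)
scale = D ** -0.5

# Same shape handling as the patched code:
# (B, H, N, D) -> (B, N, H, D) -> (B, N, H*D) -> RMSNorm -> back to (B, H, N, D).
B_, H_, N_, D_ = q.shape
q = q_norm.forward_native(q.transpose(1, 2).flatten(-2, -1)).view(
    B_, N_, H_, D_).transpose(1, 2)
k = k_norm.forward_native(k.transpose(1, 2).flatten(-2, -1)).view(
    B_, N_, H_, D_).transpose(1, 2)

x = F.scaled_dot_product_attention(q, k, v, scale=scale)
print(x.shape)  # torch.Size([2, 16, 256, 64])

Normalizing the flattened (H * D) dimension applies one RMSNorm across all heads at once, which is why the q/k tensors are transposed and flattened before the call and restored to (B, H, N, D) afterwards for scaled_dot_product_attention.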