[Bugfix] Fix RMSNorm forward in InternViT attention qk_layernorm (#6992)

Isotr0py authored on 2024-08-02 03:00:28 +08:00 (committed by GitHub)
parent 7e0861bd0b
commit 2dd34371a6


@@ -113,10 +113,10 @@ class InternAttention(nn.Module):
         if self.qk_normalization:
             B_, H_, N_, D_ = q.shape
-            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(
-                B_, N_, H_, D_).transpose(1, 2)
-            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(
-                B_, N_, H_, D_).transpose(1, 2)
+            q = self.q_norm.forward_native(q.transpose(1, 2).flatten(
+                -2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+            k = self.k_norm.forward_native(k.transpose(1, 2).flatten(
+                -2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
         x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
         x = x.transpose(1, 2).reshape(B, N, C)
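
For context, RMSNorm in vLLM is a CustomOp: calling the module dispatches to a fused, device-specific kernel, while forward_native is the plain-PyTorch reference path, which is what the patch switches the qk_layernorm to. The sketch below shows the equivalent computation on the reshaped q activations; the rms_norm_native helper and the concrete shapes are illustrative assumptions, not vLLM's actual class.

import torch

def rms_norm_native(x: torch.Tensor, weight: torch.Tensor,
                    eps: float = 1e-6) -> torch.Tensor:
    # Plain-PyTorch RMSNorm over the last dimension, mirroring what a
    # forward_native-style reference path computes (illustrative helper,
    # not vLLM's RMSNorm class).
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return x.to(orig_dtype) * weight

# Normalize q over all heads at once, as in the patched attention code:
# (B, H, N, D) -> (B, N, H*D) -> RMSNorm -> back to (B, H, N, D).
B_, H_, N_, D_ = 2, 4, 8, 16  # example shapes only
q = torch.randn(B_, H_, N_, D_)
weight = torch.ones(H_ * D_)
q = rms_norm_native(q.transpose(1, 2).flatten(-2, -1), weight).view(
    B_, N_, H_, D_).transpose(1, 2)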