Merge pull request #255 from beginlner/main

Fix a bug
This commit is contained in:
Tri Dao 2023-06-02 02:23:25 -04:00 committed by GitHub
commit 9818f85fee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -119,7 +119,7 @@ class Block(nn.Module):
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
"""
fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.norm1, RMSNorm)
fused_add_norm_fn = (dropout_add_rms_norm if RMSNorm and isinstance(self.norm1, RMSNorm)
else dropout_add_layer_norm)
if self.prenorm:
if not self.fused_dropout_add_ln: