Fix a bug

This commit is contained in:
ljss 2023-06-02 13:46:19 +08:00 committed by GitHub
parent 85b51d61ee
commit 8e44c0eefb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -119,7 +119,7 @@ class Block(nn.Module):
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
"""
fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.norm1, RMSNorm)
fused_add_norm_fn = (dropout_add_rms_norm if RMSNorm and isinstance(self.norm1, RMSNorm)
else dropout_add_layer_norm)
if self.prenorm:
if not self.fused_dropout_add_ln: