Fix a bug

This commit is contained in:
ljss 2023-06-02 13:46:19 +08:00 committed by GitHub
parent 85b51d61ee
commit 8e44c0eefb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -119,7 +119,7 @@ class Block(nn.Module):
before applying the query projection. Useful for e.g., ViT where we only care before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer. about the CLS token in the last layer.
""" """
fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.norm1, RMSNorm) fused_add_norm_fn = (dropout_add_rms_norm if RMSNorm and isinstance(self.norm1, RMSNorm)
else dropout_add_layer_norm) else dropout_add_layer_norm)
if self.prenorm: if self.prenorm:
if not self.fused_dropout_add_ln: if not self.fused_dropout_add_ln: