diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py index b1dcdd2..a4ff5a2 100644 --- a/flash_attn/modules/block.py +++ b/flash_attn/modules/block.py @@ -119,7 +119,7 @@ class Block(nn.Module): before applying the query projection. Useful for e.g., ViT where we only care about the CLS token in the last layer. """ - fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.norm1, RMSNorm) + fused_add_norm_fn = (dropout_add_rms_norm if RMSNorm and isinstance(self.norm1, RMSNorm) else dropout_add_layer_norm) if self.prenorm: if not self.fused_dropout_add_ln: