diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py index 7ea2c2f..b1dcdd2 100644 --- a/flash_attn/modules/block.py +++ b/flash_attn/modules/block.py @@ -26,7 +26,7 @@ except ImportError: try: from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm except ImportError: - RMSNorm, dropout_add_rms_norm = None + RMSNorm, dropout_add_rms_norm = None, None try: from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual