diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index 0293623..111641a 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -46,7 +46,7 @@ except ImportError: try: from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm except ImportError: - RMSNorm, dropout_add_rms_norm = None + RMSNorm, dropout_add_rms_norm = None, None try: from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual