diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py
index 6fcf50e..0a51d49 100644
--- a/flash_attn/ops/triton/layer_norm.py
+++ b/flash_attn/ops/triton/layer_norm.py
@@ -307,7 +307,14 @@ def _layer_norm_fwd(
         or dropout_p > 0.0
         or rowscale is not None
         or x1 is not None
-    ):
+    ) and (
+        x.requires_grad
+        or weight.requires_grad
+        or (bias is not None and bias.requires_grad)
+        or (residual is not None and residual.requires_grad)
+        or (x1 is not None and x1.requires_grad)
+        or (weight1 is not None and weight1.requires_grad)
+        or (bias1 is not None and bias1.requires_grad)):
         residual_out = torch.empty(
             M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
         )
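
The patch gates the allocation of the residual_out buffer on whether any of the inputs participates in autograd, so inference-only calls (nothing requiring grad) skip the extra allocation. Below is a minimal, self-contained sketch of that gating; the helper name needs_residual_out is illustrative only, the real check lives inline in _layer_norm_fwd, and the first clauses of the original condition are elided here since they are not visible in the hunk.

import torch


def needs_residual_out(x, weight, bias=None, residual=None, x1=None,
                       weight1=None, bias1=None, dropout_p=0.0, rowscale=None):
    # Pre-patch condition (trimmed to the clauses shown in the hunk):
    # would the kernel otherwise need to write residual_out at all?
    writes_residual = (
        dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    )
    # Condition added by this patch: does anything actually require grad?
    # residual_out is only consumed by the backward pass, so it is skipped
    # when no input needs gradients.
    any_requires_grad = (
        x.requires_grad
        or weight.requires_grad
        or (bias is not None and bias.requires_grad)
        or (residual is not None and residual.requires_grad)
        or (x1 is not None and x1.requires_grad)
        or (weight1 is not None and weight1.requires_grad)
        or (bias1 is not None and bias1.requires_grad)
    )
    return writes_residual and any_requires_grad


# Inference-style call (dropout on, nothing requires grad): buffer skipped.
x, w = torch.randn(4, 8), torch.ones(8)
print(needs_residual_out(x, w, dropout_p=0.1))                   # False after this patch
# Training-style call: x requires grad, so the buffer is still allocated.
print(needs_residual_out(x.requires_grad_(), w, dropout_p=0.1))  # True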