diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py
index 0a51d49..6fcf50e 100644
--- a/flash_attn/ops/triton/layer_norm.py
+++ b/flash_attn/ops/triton/layer_norm.py
@@ -307,14 +307,7 @@ def _layer_norm_fwd(
         or dropout_p > 0.0
         or rowscale is not None
         or x1 is not None
-    ) and (
-        x.requires_grad
-        or weight.requires_grad
-        or (bias is not None and bias.requires_grad)
-        or (residual is not None and residual.requires_grad)
-        or (x1 is not None and x1.requires_grad)
-        or (weight1 is not None and weight1.requires_grad)
-        or (bias1 is not None and bias1.requires_grad)):
+    ):
        residual_out = torch.empty(
            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
        )
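For readability, here is a minimal sketch of what the allocation guard in `_layer_norm_fwd` looks like after this hunk is applied. It is reconstructed only from the context lines above; the conditions preceding the hunk are elided, and the snippet is illustrative rather than the full upstream function.

```python
# Sketch of the guard after the patch: residual_out is allocated based solely on
# the forward-pass conditions; the removed requires_grad checks (x, weight, bias,
# residual, x1, weight1, bias1) no longer gate the allocation.
if (
    # ... earlier conditions elided (not shown in this hunk) ...
    dropout_p > 0.0
    or rowscale is not None
    or x1 is not None
):
    residual_out = torch.empty(
        M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
    )
```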