From bd82d6c6eb04f2a8faf34423cd15d538602d59f0 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 15 Aug 2024 12:02:39 -0700 Subject: [PATCH] Revert "[LayerNorm] Don't store x + residual if we don't need gradients" This reverts commit 800401847e1b54c9346c80766fa9f31a71b52c7e. --- flash_attn/ops/triton/layer_norm.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py index 0a51d49..6fcf50e 100644 --- a/flash_attn/ops/triton/layer_norm.py +++ b/flash_attn/ops/triton/layer_norm.py @@ -307,14 +307,7 @@ def _layer_norm_fwd( or dropout_p > 0.0 or rowscale is not None or x1 is not None - ) and ( - x.requires_grad - or weight.requires_grad - or (bias is not None and bias.requires_grad) - or (residual is not None and residual.requires_grad) - or (x1 is not None and x1.requires_grad) - or (weight1 is not None and weight1.requires_grad) - or (bias1 is not None and bias1.requires_grad)): + ): residual_out = torch.empty( M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype )