From 36587c01cb4390de0a590b2131e3fcc4859ba09c Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 18 Mar 2024 23:15:33 -0700
Subject: [PATCH] [LayerNorm] Update layer_norm_linear

---
 flash_attn/ops/triton/layer_norm.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py
index c922906..6fcf50e 100644
--- a/flash_attn/ops/triton/layer_norm.py
+++ b/flash_attn/ops/triton/layer_norm.py
@@ -987,7 +987,7 @@ class LayerNormLinearFn(torch.autograd.Function):
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
-        y, mean, rstd, residual_out = _layer_norm_fwd(
+        y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
             x,
             norm_weight,
             norm_bias,
@@ -1031,7 +1031,7 @@ class LayerNormLinearFn(torch.autograd.Function):
             assert dresidual.shape == x.shape
         else:
             dresidual = None
-        dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
+        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
            dy,
            x,
            norm_weight,
@@ -1039,9 +1039,9 @@
             ctx.eps,
             mean,
             rstd,
-            dresidual,
-            ctx.has_residual,
-            ctx.is_rms_norm,
+            dresidual=dresidual,
+            has_residual=ctx.has_residual,
+            is_rms_norm=ctx.is_rms_norm,
             x_dtype=ctx.x_dtype,
             recompute_output=True,
         )
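
Side note (not part of the patch): the two call sites above only change how the return tuples of _layer_norm_fwd and _layer_norm_bwd are unpacked, and switch three positional arguments to keyword arguments. Below is a minimal Python sketch of the unpacking idiom, using a hypothetical stand-in function rather than the real Triton kernels; the six-element return layout (y, second output, mean, rstd, residual_out, extras) is an assumption for illustration only.

import torch

def fwd_stub(x, weight, bias, eps=1e-6):
    # Hypothetical stand-in mimicking an assumed widened return layout
    # (y, y1, mean, rstd, residual_out, extras). The real _layer_norm_fwd
    # in flash_attn computes these with a Triton kernel instead.
    mean = x.mean(dim=-1, keepdim=True)
    rstd = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
    y = (x - mean) * rstd * weight + bias
    return y, None, mean.squeeze(-1), rstd.squeeze(-1), x, None

x, w, b = torch.randn(4, 8), torch.ones(8), torch.zeros(8)
# Same pattern as the patched forward call site: keep what the autograd
# Function needs, ignore the second output slot, collect extras in *rest.
y, _, mean, rstd, residual_out, *rest = fwd_stub(x, w, b)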