diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py
index c922906..6fcf50e 100644
--- a/flash_attn/ops/triton/layer_norm.py
+++ b/flash_attn/ops/triton/layer_norm.py
@@ -987,7 +987,7 @@ class LayerNormLinearFn(torch.autograd.Function):
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
-        y, mean, rstd, residual_out = _layer_norm_fwd(
+        y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
             x,
             norm_weight,
             norm_bias,
@@ -1031,7 +1031,7 @@ class LayerNormLinearFn(torch.autograd.Function):
             assert dresidual.shape == x.shape
         else:
             dresidual = None
-        dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
+        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
             dy,
             x,
             norm_weight,
@@ -1039,9 +1039,9 @@
             ctx.eps,
             mean,
             rstd,
-            dresidual,
-            ctx.has_residual,
-            ctx.is_rms_norm,
+            dresidual=dresidual,
+            has_residual=ctx.has_residual,
+            is_rms_norm=ctx.is_rms_norm,
             x_dtype=ctx.x_dtype,
             recompute_output=True,
         )
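
The patched call sites rely only on tuple unpacking, so a minimal sketch of the pattern follows. The placeholder tuples stand in for the real return values of _layer_norm_fwd and _layer_norm_bwd; their exact layouts (a second output stream plus extra bookkeeping slots) are assumptions here, not taken from the kernels themselves.

    # Sketch only: placeholder tuples mimic the assumed return layouts of the
    # Triton kernels; the real functions return tensors.
    fwd_out = ("y", None, "mean", "rstd", "residual_out", None, None)
    # Keep the values this autograd.Function needs; *rest absorbs any trailing extras.
    y, _, mean, rstd, residual_out, *rest = fwd_out

    bwd_out = ("dx", "dw", "db", "dresidual_in", None, None, None, "y")
    # Same idea for the backward pass: the three unused slots are discarded with _.
    dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = bwd_out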