[LayerNorm] Update layer_norm_linear

Tri Dao 2024-03-18 23:15:33 -07:00
parent 6bbc532388
commit 36587c01cb


@@ -987,7 +987,7 @@ class LayerNormLinearFn(torch.autograd.Function):
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
-        y, mean, rstd, residual_out = _layer_norm_fwd(
+        y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
             x,
             norm_weight,
             norm_bias,
@@ -1031,7 +1031,7 @@ class LayerNormLinearFn(torch.autograd.Function):
             assert dresidual.shape == x.shape
         else:
             dresidual = None
-        dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
+        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
             dy,
             x,
             norm_weight,
@@ -1039,9 +1039,9 @@ class LayerNormLinearFn(torch.autograd.Function):
             ctx.eps,
             mean,
             rstd,
-            dresidual,
-            ctx.has_residual,
-            ctx.is_rms_norm,
+            dresidual=dresidual,
+            has_residual=ctx.has_residual,
+            is_rms_norm=ctx.is_rms_norm,
             x_dtype=ctx.x_dtype,
             recompute_output=True,
         )
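
Taken together, the hunks widen the tuples unpacked from _layer_norm_fwd and _layer_norm_bwd (the linear path discards the extra outputs via `_` and `*rest`) and pass the optional backward arguments by keyword. Below is a minimal sketch of the forward-side unpacking pattern; the stand-in kernel, its extra outputs, and their names are assumptions for illustration, not the real flash_attn return values.

import torch

# Hypothetical stand-in for _layer_norm_fwd. The real Triton wrapper in
# flash_attn returns more values than the linear path consumes; which extras
# it returns is an assumption here and does not matter for the pattern.
def _layer_norm_fwd_stub(x, weight, bias, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)
    rstd = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
    y = (x - mean) * rstd * weight + bias
    residual_out = x                  # stand-in for the optional residual output
    unused_a, unused_b = None, None   # stand-ins for outputs the linear path ignores
    return y, None, mean, rstd, residual_out, unused_a, unused_b

x = torch.randn(4, 8)
w, b = torch.ones(8), torch.zeros(8)

# Same unpacking shape as the updated forward: keep what the linear wrapper
# needs, drop the second return with `_`, and absorb any trailing returns
# into *rest so a longer tuple from the kernel does not break the call site.
y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd_stub(x, w, b)
print(y.shape, len(rest))  # torch.Size([4, 8]) 2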