[LayerNorm] Add option to write result to out and residual_out

Tri Dao 2024-08-15 14:43:47 -07:00
parent bd82d6c6eb
commit bcd918f275

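In short, the change threads two optional preallocated buffers, out and residual_out, through _layer_norm_fwd, LayerNormFn, layer_norm_fn, and rms_norm_fn, so the Triton kernel can write its results into caller-provided tensors instead of allocating new ones. A minimal usage sketch (illustrative only; the import path is an assumption and is not part of the diff):

import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn  # module path assumed

x = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)
weight = torch.ones(1024, device="cuda", dtype=torch.bfloat16)
bias = torch.zeros(1024, device="cuda", dtype=torch.bfloat16)

# Caller-provided output buffer; per the asserts in the diff below it must
# match x.shape and have a contiguous last dimension.
y_buf = torch.empty_like(x)

y = layer_norm_fn(x, weight, bias, out=y_buf)  # y is expected to share storage with y_buf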

@@ -267,6 +267,8 @@ def _layer_norm_fwd(
     residual_dtype=None,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     if residual is not None:
         residual_dtype = residual.dtype
@@ -294,10 +296,13 @@ def _layer_norm_fwd(
         assert rowscale.is_contiguous()
         assert rowscale.shape == (M,)
     # allocate output
-    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
-    assert y.stride(-1) == 1
+    if out is None:
+        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+    else:
+        assert out.shape == x.shape
+    assert out.stride(-1) == 1
     if weight1 is not None:
-        y1 = torch.empty_like(y)
+        y1 = torch.empty_like(out)
         assert y1.stride(-1) == 1
     else:
         y1 = None
@@ -308,9 +313,12 @@ def _layer_norm_fwd(
         or rowscale is not None
         or x1 is not None
     ):
-        residual_out = torch.empty(
-            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
-        )
+        if residual_out is None:
+            residual_out = torch.empty(
+                M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
+            )
+        else:
+            assert residual_out.shape == x.shape
+            assert residual_out.stride(-1) == 1
     else:
         residual_out = None
@@ -334,7 +342,7 @@ def _layer_norm_fwd(
     with torch.cuda.device(x.device.index):
         _layer_norm_fwd_1pass_kernel[(M,)](
             x,
-            y,
+            out,
             weight,
             bias,
             residual,
@@ -349,7 +357,7 @@ def _layer_norm_fwd(
             mean,
             rstd,
             x.stride(0),
-            y.stride(0),
+            out.stride(0),
             residual.stride(0) if residual is not None else 0,
             residual_out.stride(0) if residual_out is not None else 0,
             x1.stride(0) if x1 is not None else 0,
@@ -373,7 +381,7 @@ def _layer_norm_fwd(
     else:
         dropout_mask1 = None
     return (
-        y,
+        out,
         y1,
         mean,
         rstd,
@@ -714,6 +722,8 @@ class LayerNormFn(torch.autograd.Function):
         residual_in_fp32=False,
         is_rms_norm=False,
         return_dropout_mask=False,
+        out=None,
+        residual_out=None
     ):
         x_shape_og = x.shape
         # reshape input data into 2D tensor
@@ -745,6 +755,10 @@ class LayerNormFn(torch.autograd.Function):
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
+        if out is not None:
+            out = out.reshape(-1, out.shape[-1])
+        if residual_out is not None:
+            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
         y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
             x,
             weight,
@@ -759,6 +773,8 @@ class LayerNormFn(torch.autograd.Function):
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
             return_dropout_mask=return_dropout_mask,
+            out=out,
+            residual_out=residual_out
         )
         ctx.save_for_backward(
             residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
@@ -853,6 +869,8 @@ class LayerNormFn(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
         )
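The two extra None entries above follow the torch.autograd.Function contract: backward returns one value per argument of forward, and the new non-differentiable out and residual_out buffers simply get None. A toy illustration of that convention (not the library code):

import torch

class ScaleFn(torch.autograd.Function):
    """Toy autograd.Function with a non-differentiable buffer argument."""

    @staticmethod
    def forward(ctx, x, scale: float, out=None):
        ctx.scale = scale
        y = x * scale
        if out is not None:
            out.copy_(y)  # side effect: also fill the caller-provided buffer
        return y

    @staticmethod
    def backward(ctx, grad_y):
        # One slot per forward argument (x, scale, out); only x gets a gradient,
        # the float and the buffer get None -- the same pattern as the two new
        # None entries added to LayerNormFn.backward above.
        return grad_y * ctx.scale, None, None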
@@ -871,6 +889,8 @@ def layer_norm_fn(
     residual_in_fp32=False,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -887,6 +907,8 @@ def layer_norm_fn(
         residual_in_fp32,
         is_rms_norm,
         return_dropout_mask,
+        out,
+        residual_out
     )
@@ -904,6 +926,8 @@ def rms_norm_fn(
     prenorm=False,
     residual_in_fp32=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -920,6 +944,8 @@ def rms_norm_fn(
         residual_in_fp32,
         True,
         return_dropout_mask,
+        out,
+        residual_out
     )
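rms_norm_fn forwards the same two buffers to LayerNormFn.apply with is_rms_norm fixed to True, so a prenorm residual stream can likewise be written into caller-owned storage. A hedged sketch, assuming the residual/prenorm keywords that appear elsewhere in this file (they are not part of this diff), with hypothetical stand-in tensors for a decoder block's state:

import torch
# rms_norm_fn is the function patched above; import it from the same module.
hidden_states = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)
residual = torch.zeros(4096, 1024, device="cuda", dtype=torch.float32)
norm_weight = torch.ones(1024, device="cuda", dtype=torch.bfloat16)

# Buffers reused across calls instead of fresh allocations inside the op.
out_buf = torch.empty_like(hidden_states)
residual_buf = torch.empty_like(residual)

hidden_states, residual = rms_norm_fn(
    hidden_states,
    norm_weight,
    None,                       # RMSNorm typically carries no bias
    residual=residual,
    prenorm=True,               # also return the updated residual stream
    out=out_buf,                # normalized output written here
    residual_out=residual_buf,  # updated residual written here
)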