[LayerNorm] Add option to write result to out and residual_out

Tri Dao 2024-08-15 14:43:47 -07:00
parent bd82d6c6eb
commit bcd918f275


@@ -267,6 +267,8 @@ def _layer_norm_fwd(
     residual_dtype=None,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     if residual is not None:
         residual_dtype = residual.dtype
@@ -294,10 +296,13 @@ def _layer_norm_fwd(
         assert rowscale.is_contiguous()
         assert rowscale.shape == (M,)
     # allocate output
-    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
-    assert y.stride(-1) == 1
+    if out is None:
+        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+    else:
+        assert out.shape == x.shape
+    assert out.stride(-1) == 1
     if weight1 is not None:
-        y1 = torch.empty_like(y)
+        y1 = torch.empty_like(out)
         assert y1.stride(-1) == 1
     else:
         y1 = None
@@ -308,9 +313,12 @@ def _layer_norm_fwd(
         or rowscale is not None
         or x1 is not None
     ):
-        residual_out = torch.empty(
-            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
-        )
+        if residual_out is None:
+            residual_out = torch.empty(
+                M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
+            )
+        else:
+            assert residual_out.shape == x.shape
         assert residual_out.stride(-1) == 1
     else:
         residual_out = None
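
The two hunks above apply the same allocate-or-validate pattern to out and residual_out: reuse the caller-provided buffer when one is passed, otherwise allocate a fresh tensor, and enforce a matching shape plus a contiguous last dimension either way. A minimal standalone sketch of that pattern (the helper name _get_out_buffer is illustrative, not part of this commit):

import torch

def _get_out_buffer(x, out=None, dtype=None):
    # Reuse the caller's buffer when one is provided, otherwise allocate.
    if out is None:
        out = torch.empty_like(x, dtype=x.dtype if dtype is None else dtype)
    else:
        assert out.shape == x.shape
    # The Triton kernel addresses rows via stride(0), so the last dim must be contiguous.
    assert out.stride(-1) == 1
    return out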
@@ -334,7 +342,7 @@ def _layer_norm_fwd(
     with torch.cuda.device(x.device.index):
         _layer_norm_fwd_1pass_kernel[(M,)](
             x,
-            y,
+            out,
             weight,
             bias,
             residual,
@@ -349,7 +357,7 @@ def _layer_norm_fwd(
             mean,
             rstd,
             x.stride(0),
-            y.stride(0),
+            out.stride(0),
             residual.stride(0) if residual is not None else 0,
             residual_out.stride(0) if residual_out is not None else 0,
             x1.stride(0) if x1 is not None else 0,
@@ -373,7 +381,7 @@ def _layer_norm_fwd(
     else:
         dropout_mask1 = None
     return (
-        y,
+        out,
         y1,
         mean,
         rstd,
@@ -714,6 +722,8 @@ class LayerNormFn(torch.autograd.Function):
         residual_in_fp32=False,
         is_rms_norm=False,
         return_dropout_mask=False,
+        out=None,
+        residual_out=None
     ):
         x_shape_og = x.shape
         # reshape input data into 2D tensor
@@ -745,6 +755,10 @@ class LayerNormFn(torch.autograd.Function):
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
+        if out is not None:
+            out = out.reshape(-1, out.shape[-1])
+        if residual_out is not None:
+            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
         y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
             x,
             weight,
@@ -759,6 +773,8 @@ class LayerNormFn(torch.autograd.Function):
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
             return_dropout_mask=return_dropout_mask,
+            out=out,
+            residual_out=residual_out
         )
         ctx.save_for_backward(
             residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
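
Before the call above, forward reshapes any caller-supplied out / residual_out to 2D (see the @@ -745 hunk). For a contiguous tensor, reshape returns a view over the same storage, so the kernel's writes through the 2D view land in the caller's original buffer. A small illustration of that property (not part of this commit):

import torch

out = torch.empty(2, 128, 1024)
out_2d = out.reshape(-1, out.shape[-1])   # view: shares storage with `out`
assert out_2d.data_ptr() == out.data_ptr()
out_2d.fill_(1.0)                         # in-place write through the view
assert bool((out == 1.0).all())           # visible through the original tensor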
@@ -853,6 +869,8 @@ class LayerNormFn(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
         )
@@ -871,6 +889,8 @@ def layer_norm_fn(
    residual_in_fp32=False,
    is_rms_norm=False,
    return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -887,6 +907,8 @@ def layer_norm_fn(
         residual_in_fp32,
         is_rms_norm,
         return_dropout_mask,
+        out,
+        residual_out
     )
@@ -904,6 +926,8 @@ def rms_norm_fn(
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -920,6 +944,8 @@ def rms_norm_fn(
         residual_in_fp32,
         True,
         return_dropout_mask,
+        out,
+        residual_out
     )
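
With these changes, callers can hand preallocated buffers to the public entry points. A hypothetical usage sketch, assuming the usual rms_norm_fn signature from this module (x, weight, bias, residual=None, ..., prenorm=False, ...), CUDA tensors, and an import path that is not shown in this diff:

import torch
from flash_attn.ops.triton.layer_norm import rms_norm_fn  # assumed path; not shown in the diff

batch, seqlen, dim = 2, 128, 1024
x = torch.randn(batch, seqlen, dim, device="cuda", dtype=torch.bfloat16)
residual = torch.randn(batch, seqlen, dim, device="cuda", dtype=torch.float32)
weight = torch.ones(dim, device="cuda", dtype=torch.bfloat16)

# Caller-owned buffers; shapes must match x (checked inside _layer_norm_fwd).
out = torch.empty_like(x)
residual_out = torch.empty_like(residual)

y, new_residual = rms_norm_fn(
    x,
    weight,
    None,                 # bias
    residual=residual,
    prenorm=True,         # also return the updated residual
    out=out,
    residual_out=residual_out,
)
# The returned tensors are views over the buffers supplied above.
assert y.data_ptr() == out.data_ptr()
assert new_residual.data_ptr() == residual_out.data_ptr()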