[LayerNorm] Add option to write result to out and residual_out
parent bd82d6c6eb
commit bcd918f275
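The new `out` and `residual_out` keyword arguments let callers pass pre-allocated buffers that the Triton kernel writes into, instead of allocating fresh tensors on every call. A minimal usage sketch, assuming the `flash_attn.ops.triton.layer_norm` import path and a CUDA device (the buffer names here are illustrative):

```python
import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn  # import path assumed

x = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)
weight = torch.ones(1024, device="cuda", dtype=torch.float16)
bias = torch.zeros(1024, device="cuda", dtype=torch.float16)

# Pre-allocated output buffer, e.g. reused across decoding steps.
out = torch.empty_like(x)
y = layer_norm_fn(x, weight, bias, eps=1e-6, out=out)
# The result is written into `out`; the returned tensor aliases its storage.
```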
@@ -267,6 +267,8 @@ def _layer_norm_fwd(
     residual_dtype=None,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     if residual is not None:
         residual_dtype = residual.dtype
@@ -294,10 +296,13 @@ def _layer_norm_fwd(
         assert rowscale.is_contiguous()
         assert rowscale.shape == (M,)
     # allocate output
-    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
-    assert y.stride(-1) == 1
+    if out is None:
+        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+    else:
+        assert out.shape == x.shape
+    assert out.stride(-1) == 1
     if weight1 is not None:
-        y1 = torch.empty_like(y)
+        y1 = torch.empty_like(out)
         assert y1.stride(-1) == 1
     else:
         y1 = None
@@ -308,9 +313,12 @@ def _layer_norm_fwd(
         or rowscale is not None
         or x1 is not None
     ):
-        residual_out = torch.empty(
-            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
-        )
+        if residual_out is None:
+            residual_out = torch.empty(
+                M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
+            )
+        else:
+            assert residual_out.shape == x.shape
         assert residual_out.stride(-1) == 1
     else:
         residual_out = None
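The two hunks above use the same allocate-or-validate pattern for both buffers: a caller-supplied tensor is checked for the expected shape and a contiguous last dimension, otherwise a new tensor is allocated as before. A generic sketch of that pattern (illustrative only, not a helper that exists in this file):

```python
import torch

def _use_or_allocate(buf, x, dtype=None):
    # Illustrative helper: reuse `buf` if given, otherwise allocate a new tensor.
    if buf is None:
        buf = torch.empty_like(x, dtype=dtype if dtype is not None else x.dtype)
    else:
        assert buf.shape == x.shape
    assert buf.stride(-1) == 1  # the Triton kernel assumes a contiguous last dim
    return buf
```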
@@ -334,7 +342,7 @@ def _layer_norm_fwd(
     with torch.cuda.device(x.device.index):
         _layer_norm_fwd_1pass_kernel[(M,)](
             x,
-            y,
+            out,
             weight,
             bias,
             residual,
@@ -349,7 +357,7 @@ def _layer_norm_fwd(
             mean,
             rstd,
             x.stride(0),
-            y.stride(0),
+            out.stride(0),
             residual.stride(0) if residual is not None else 0,
             residual_out.stride(0) if residual_out is not None else 0,
             x1.stride(0) if x1 is not None else 0,
@@ -373,7 +381,7 @@ def _layer_norm_fwd(
     else:
         dropout_mask1 = None
     return (
-        y,
+        out,
         y1,
         mean,
         rstd,
@@ -714,6 +722,8 @@ class LayerNormFn(torch.autograd.Function):
         residual_in_fp32=False,
         is_rms_norm=False,
         return_dropout_mask=False,
+        out=None,
+        residual_out=None
     ):
         x_shape_og = x.shape
         # reshape input data into 2D tensor
@@ -745,6 +755,10 @@ class LayerNormFn(torch.autograd.Function):
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
+        if out is not None:
+            out = out.reshape(-1, out.shape[-1])
+        if residual_out is not None:
+            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
         y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
             x,
             weight,
@@ -759,6 +773,8 @@ class LayerNormFn(torch.autograd.Function):
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
             return_dropout_mask=return_dropout_mask,
+            out=out,
+            residual_out=residual_out
         )
         ctx.save_for_backward(
             residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
@@ -853,6 +869,8 @@ class LayerNormFn(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
         )
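The two extra `None` entries are needed because `torch.autograd.Function.backward` must return one gradient (or `None`) per argument of `forward`, and the new `out` / `residual_out` arguments are not differentiated through. A small standalone sketch of that rule (the `Scale` function here is hypothetical, not part of the library):

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, out=None):
        ctx.factor = factor
        # Write into the provided buffer when given, otherwise allocate.
        return x * factor if out is None else torch.mul(x, factor, out=out)

    @staticmethod
    def backward(ctx, grad):
        # One entry per forward argument (x, factor, out); only x gets a gradient.
        return grad * ctx.factor, None, None
```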
@@ -871,6 +889,8 @@ def layer_norm_fn(
     residual_in_fp32=False,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -887,6 +907,8 @@ def layer_norm_fn(
         residual_in_fp32,
         is_rms_norm,
         return_dropout_mask,
+        out,
+        residual_out
     )
@@ -904,6 +926,8 @@ def rms_norm_fn(
     prenorm=False,
     residual_in_fp32=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -920,6 +944,8 @@ def rms_norm_fn(
         residual_in_fp32,
         True,
         return_dropout_mask,
+        out,
+        residual_out
     )
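At the `rms_norm_fn` level, the same buffers can be supplied on the pre-norm residual path. A sketch under the same assumptions as above (with `prenorm=True` the function returns the normalized output together with the updated residual):

```python
import torch
from flash_attn.ops.triton.layer_norm import rms_norm_fn  # import path assumed

hidden = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)
residual = torch.randn(4096, 1024, device="cuda", dtype=torch.float32)
weight = torch.ones(1024, device="cuda", dtype=torch.bfloat16)

out = torch.empty_like(hidden)             # normalized output buffer
residual_out = torch.empty_like(residual)  # updated residual buffer

y, new_residual = rms_norm_fn(
    hidden,
    weight,
    None,          # RMSNorm has no bias
    residual=residual,
    prenorm=True,
    out=out,
    residual_out=residual_out,
)
```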