[LayerNorm] Add option to write result to out and residual_out
commit bcd918f275
parent bd82d6c6eb
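
The new out and residual_out keyword arguments are threaded from layer_norm_fn / rms_norm_fn through LayerNormFn.forward into _layer_norm_fwd: when a buffer is passed in, the Triton kernel writes into it instead of allocating a fresh tensor. Both buffers must match x.shape and have a contiguous last dimension. A minimal usage sketch follows; the shapes, dtypes, and import path are illustrative assumptions, not part of this diff:

    import torch
    # Import path is an assumption; import layer_norm_fn from wherever
    # this Triton layer-norm module lives in your tree.
    from flash_attn.ops.triton.layer_norm import layer_norm_fn

    batch, seqlen, hidden = 4, 1024, 2048
    x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x)
    weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
    bias = torch.zeros(hidden, device="cuda", dtype=torch.float16)

    # Preallocated buffers the kernel writes into; they must match x.shape
    # and be contiguous in the last dimension, per the new asserts below.
    out = torch.empty_like(x)
    residual_out = torch.empty_like(x)

    y, res = layer_norm_fn(
        x, weight, bias,
        residual=residual,   # summed into x before normalization
        prenorm=True,        # also return the pre-norm sum
        out=out,
        residual_out=residual_out,
    )
    # The returned tensors now alias the caller-provided storage.
    assert y.data_ptr() == out.data_ptr()
    assert res.data_ptr() == residual_out.data_ptr()
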
@@ -267,6 +267,8 @@ def _layer_norm_fwd(
     residual_dtype=None,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     if residual is not None:
         residual_dtype = residual.dtype
@@ -294,10 +296,13 @@ def _layer_norm_fwd(
         assert rowscale.is_contiguous()
         assert rowscale.shape == (M,)
     # allocate output
-    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
-    assert y.stride(-1) == 1
+    if out is None:
+        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+    else:
+        assert out.shape == x.shape
+    assert out.stride(-1) == 1
     if weight1 is not None:
-        y1 = torch.empty_like(y)
+        y1 = torch.empty_like(out)
         assert y1.stride(-1) == 1
     else:
         y1 = None
@@ -308,9 +313,12 @@ def _layer_norm_fwd(
         or rowscale is not None
         or x1 is not None
     ):
-        residual_out = torch.empty(
-            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
-        )
+        if residual_out is None:
+            residual_out = torch.empty(
+                M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
+            )
+        else:
+            assert residual_out.shape == x.shape
         assert residual_out.stride(-1) == 1
     else:
         residual_out = None
@@ -334,7 +342,7 @@ def _layer_norm_fwd(
     with torch.cuda.device(x.device.index):
         _layer_norm_fwd_1pass_kernel[(M,)](
             x,
-            y,
+            out,
             weight,
             bias,
             residual,
@@ -349,7 +357,7 @@ def _layer_norm_fwd(
             mean,
             rstd,
             x.stride(0),
-            y.stride(0),
+            out.stride(0),
             residual.stride(0) if residual is not None else 0,
             residual_out.stride(0) if residual_out is not None else 0,
             x1.stride(0) if x1 is not None else 0,
@@ -373,7 +381,7 @@ def _layer_norm_fwd(
     else:
         dropout_mask1 = None
     return (
-        y,
+        out,
         y1,
         mean,
         rstd,
@@ -714,6 +722,8 @@ class LayerNormFn(torch.autograd.Function):
         residual_in_fp32=False,
         is_rms_norm=False,
         return_dropout_mask=False,
+        out=None,
+        residual_out=None
     ):
         x_shape_og = x.shape
         # reshape input data into 2D tensor
@@ -745,6 +755,10 @@ class LayerNormFn(torch.autograd.Function):
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
+        if out is not None:
+            out = out.reshape(-1, out.shape[-1])
+        if residual_out is not None:
+            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
         y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
             x,
             weight,
@@ -759,6 +773,8 @@ class LayerNormFn(torch.autograd.Function):
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
             return_dropout_mask=return_dropout_mask,
+            out=out,
+            residual_out=residual_out
         )
         ctx.save_for_backward(
             residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
@@ -853,6 +869,8 @@ class LayerNormFn(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
         )
 
 
@@ -871,6 +889,8 @@ def layer_norm_fn(
     residual_in_fp32=False,
     is_rms_norm=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -887,6 +907,8 @@ def layer_norm_fn(
         residual_in_fp32,
         is_rms_norm,
         return_dropout_mask,
+        out,
+        residual_out
     )
 
 
@@ -904,6 +926,8 @@ def rms_norm_fn(
     prenorm=False,
     residual_in_fp32=False,
     return_dropout_mask=False,
+    out=None,
+    residual_out=None
 ):
     return LayerNormFn.apply(
         x,
@@ -920,6 +944,8 @@ def rms_norm_fn(
         residual_in_fp32,
         True,
         return_dropout_mask,
+        out,
+        residual_out
     )
 
 
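rms_norm_fn forwards the same two keywords, with is_rms_norm pinned to True (the bare True argument above). One subtlety: the new asserts check only shape and last-dim stride, so keeping residual_out's dtype consistent with residual_in_fp32 is left to the caller. A sketch, reusing x and weight from the example above; the call pattern is an assumption based on the signatures in this diff:

    # Import path is an assumption, as before.
    from flash_attn.ops.triton.layer_norm import rms_norm_fn

    out = torch.empty_like(x)
    # With residual_in_fp32=True the pre-norm sum is kept in fp32, so the
    # caller-provided buffer should be fp32 as well (not dtype-checked by
    # the new asserts).
    residual_out = torch.empty(x.shape, device=x.device, dtype=torch.float32)

    y, res = rms_norm_fn(
        x, weight, None,     # RMSNorm typically has no bias
        prenorm=True,
        residual_in_fp32=True,
        out=out,
        residual_out=residual_out,
    )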