801 lines
22 KiB
Python
801 lines
22 KiB
Python
# Copyright (c) 2022, Tri Dao.
|
|
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
|
|
|
|
import dropout_layer_norm
|
|
import torch
|
|
from torch.nn import init
|
|
|
|
|
|
def maybe_align(x, alignment_in_bytes=16):
|
|
"""Assume that x already has last dim divisible by alignment_in_bytes"""
|
|
# TD [2023-07-04] I'm not 100% sure that clone will align the memory
|
|
# https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
|
|
return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
|
|
|
|
|
|
def _dropout_add_layer_norm_forward(
|
|
x0,
|
|
residual,
|
|
gamma,
|
|
beta,
|
|
rowscale,
|
|
colscale,
|
|
dropout_p,
|
|
epsilon,
|
|
residual_in_fp32=False,
|
|
is_rms_norm=False,
|
|
):
|
|
"""Assume that arguments are contiguous and aligned to 16 bytes"""
|
|
hidden_size = gamma.numel()
|
|
x0mat = x0.view((-1, hidden_size))
|
|
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
|
|
rowscale = rowscale.view(-1) if rowscale is not None else None
|
|
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
|
|
x0mat,
|
|
residualmat,
|
|
gamma,
|
|
beta,
|
|
rowscale,
|
|
colscale,
|
|
None,
|
|
None,
|
|
dropout_p,
|
|
epsilon,
|
|
1.0,
|
|
0,
|
|
None,
|
|
residual_in_fp32,
|
|
is_rms_norm,
|
|
)
|
|
# dmask is None if dropout_p == 0.0
|
|
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
|
|
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
|
|
|
|
|
|
def _dropout_add_layer_norm_backward(
|
|
dz,
|
|
dx,
|
|
x,
|
|
x0,
|
|
dmask,
|
|
mu,
|
|
rsigma,
|
|
gamma,
|
|
rowscale,
|
|
colscale,
|
|
dropout_p,
|
|
has_residual,
|
|
is_rms_norm=False,
|
|
):
|
|
"""Assume that arguments are contiguous and aligned to 16 bytes
|
|
dx == None means that it was a post-norm architecture
|
|
(x = drop(x0) + residual was not returned in the fwd).
|
|
x0 must not be None if we have colscale.
|
|
"""
|
|
hidden_size = gamma.numel()
|
|
xmat = x.view((-1, hidden_size))
|
|
dzmat = dz.view(xmat.shape)
|
|
dxmat = dx.view(xmat.shape) if dx is not None else None
|
|
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
|
|
rowscale = rowscale.view(-1) if rowscale is not None else None
|
|
if colscale is not None:
|
|
assert x0 is not None, "x0 is required to compute the gradient of colscale"
|
|
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
|
|
dzmat,
|
|
dxmat,
|
|
xmat,
|
|
x0mat,
|
|
dmask,
|
|
mu,
|
|
rsigma,
|
|
gamma,
|
|
rowscale,
|
|
colscale,
|
|
None,
|
|
None,
|
|
dropout_p,
|
|
1.0,
|
|
0,
|
|
has_residual,
|
|
is_rms_norm,
|
|
)
|
|
# dresidualmat is None if not has_residual
|
|
if colscale is None:
|
|
return dx0mat, dresidualmat, dgamma, dbeta
|
|
else:
|
|
dcolscale = rest[0]
|
|
return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
|
|
|
|
|
|
def _dropout_add_layer_norm_subset_forward(
|
|
x0,
|
|
residual,
|
|
gamma,
|
|
beta,
|
|
colscale,
|
|
x0_subset,
|
|
out_subset,
|
|
dropout_p,
|
|
epsilon,
|
|
rowscale_const,
|
|
out_numrows,
|
|
residual_in_fp32=False,
|
|
is_rms_norm=False,
|
|
):
|
|
"""Assume that arguments are contiguous and aligned to 16 bytes"""
|
|
hidden_size = gamma.numel()
|
|
x0mat = x0.view((-1, hidden_size))
|
|
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
|
|
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
|
|
out_subset = out_subset.view(-1) if out_subset is not None else None
|
|
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
|
|
x0mat,
|
|
residualmat,
|
|
gamma,
|
|
beta,
|
|
None,
|
|
colscale,
|
|
x0_subset,
|
|
out_subset,
|
|
dropout_p,
|
|
epsilon,
|
|
rowscale_const,
|
|
out_numrows,
|
|
None,
|
|
residual_in_fp32,
|
|
is_rms_norm,
|
|
)
|
|
# dmask is None if dropout_p == 0.0
|
|
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
|
|
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
|
|
|
|
|
|
def _dropout_add_layer_norm_subset_backward(
|
|
dz,
|
|
dx,
|
|
x,
|
|
x0,
|
|
dmask,
|
|
mu,
|
|
rsigma,
|
|
gamma,
|
|
colscale,
|
|
x0_subset,
|
|
out_subset,
|
|
dropout_p,
|
|
rowscale_const,
|
|
x0_numrows,
|
|
has_residual,
|
|
is_rms_norm=False,
|
|
):
|
|
"""Assume that arguments are contiguous and aligned to 16 bytes
|
|
dx == None means that it was a post-norm architecture
|
|
(x = drop(x0) + residual was not returned in the fwd).
|
|
x0 must not be None if we have colscale.
|
|
"""
|
|
hidden_size = gamma.numel()
|
|
xmat = x.view((-1, hidden_size))
|
|
dzmat = dz.view(-1, hidden_size)
|
|
dxmat = dx.view(xmat.shape) if dx is not None else None
|
|
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
|
|
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
|
|
out_subset = out_subset.view(-1) if out_subset is not None else None
|
|
if colscale is not None:
|
|
assert x0 is not None, "x0 is required to compute the gradient of colscale"
|
|
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
|
|
dzmat,
|
|
dxmat,
|
|
xmat,
|
|
x0mat,
|
|
dmask,
|
|
mu,
|
|
rsigma,
|
|
gamma,
|
|
None,
|
|
colscale,
|
|
x0_subset,
|
|
out_subset,
|
|
dropout_p,
|
|
rowscale_const,
|
|
x0_numrows,
|
|
has_residual,
|
|
is_rms_norm,
|
|
)
|
|
# dresidualmat is None if not has_residual
|
|
if colscale is None:
|
|
return dx0mat, dresidualmat, dgamma, dbeta
|
|
else:
|
|
dcolscale = rest[0]
|
|
return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
|
|
|
|
|
|
def _dropout_add_layer_norm_parallel_residual_forward(
|
|
x0,
|
|
x1,
|
|
residual,
|
|
gamma0,
|
|
beta0,
|
|
gamma1,
|
|
beta1,
|
|
dropout_p,
|
|
epsilon,
|
|
residual_in_fp32=False,
|
|
is_rms_norm=False,
|
|
):
|
|
"""Assume that arguments are contiguous and aligned to 16 bytes"""
|
|
hidden_size = gamma0.numel()
|
|
x0mat = x0.view((-1, hidden_size))
|
|
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
|
|
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
|
|
(
|
|
z0mat,
|
|
z1mat,
|
|
xmat,
|
|
dmask0,
|
|
dmask1,
|
|
mu,
|
|
rsigma,
|
|
) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
|
|
x0mat,
|
|
x1mat,
|
|
residualmat,
|
|
gamma0,
|
|
beta0,
|
|
gamma1,
|
|
beta1,
|
|
dropout_p,
|
|
epsilon,
|
|
None,
|
|
residual_in_fp32,
|
|
is_rms_norm,
|
|
)
|
|
# dmask0 and dmask1 are None if dropout_p == 0.0
|
|
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
|
|
return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
|
|
|
|
|
|
def _dropout_add_layer_norm_parallel_residual_backward(
|
|
dz0,
|
|
dz1,
|
|
dx,
|
|
x,
|
|
dmask0,
|
|
dmask1,
|
|
mu,
|
|
rsigma,
|
|
gamma0,
|
|
gamma1,
|
|
dropout_p,
|
|
has_x1,
|
|
has_residual,
|
|
is_rms_norm=False,
|
|
):
|
|
"""Assume that arguments are contiguous and aligned to 16 bytes
|
|
dx == None means that it was a post-norm architecture
|
|
(x = drop(x0) + residual was not returned in the fwd).
|
|
"""
|
|
hidden_size = gamma0.numel()
|
|
xmat = x.view((-1, hidden_size))
|
|
dz0mat = dz0.view(xmat.shape)
|
|
dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
|
|
dxmat = dx.view(xmat.shape) if dx is not None else None
|
|
(
|
|
dx0mat,
|
|
dx1mat,
|
|
dresidualmat,
|
|
dgamma0,
|
|
dbeta0,
|
|
dgamma1,
|
|
dbeta1,
|
|
*rest,
|
|
) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
|
|
dz0mat,
|
|
dz1mat,
|
|
dxmat,
|
|
xmat,
|
|
dmask0,
|
|
dmask1,
|
|
mu,
|
|
rsigma,
|
|
gamma0,
|
|
gamma1,
|
|
dropout_p,
|
|
has_x1,
|
|
has_residual,
|
|
is_rms_norm,
|
|
)
|
|
# dresidualmat is None if not has_residual
|
|
return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
|
|
|
|
|
|
class DropoutAddLayerNormFn(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(
|
|
ctx,
|
|
x0,
|
|
residual,
|
|
gamma,
|
|
beta,
|
|
rowscale,
|
|
colscale,
|
|
dropout_p,
|
|
epsilon,
|
|
residual_in_fp32=False,
|
|
prenorm=False,
|
|
is_rms_norm=False,
|
|
return_dmask=False,
|
|
):
|
|
x0 = maybe_align(x0.contiguous(), 16)
|
|
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
|
|
gamma = maybe_align(gamma.contiguous(), 16)
|
|
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
|
|
rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
|
|
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
|
|
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
|
|
x0,
|
|
residual,
|
|
gamma,
|
|
beta,
|
|
rowscale,
|
|
colscale,
|
|
dropout_p,
|
|
epsilon,
|
|
residual_in_fp32,
|
|
is_rms_norm,
|
|
)
|
|
# Only need to save x0 if we need to compute gradient wrt colscale
|
|
x0_saved = x0 if colscale is not None else None
|
|
ctx.save_for_backward(
|
|
xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
|
|
)
|
|
ctx.prenorm = prenorm
|
|
ctx.dropout_p = dropout_p
|
|
ctx.has_residual = residual is not None
|
|
ctx.is_rms_norm = is_rms_norm
|
|
ctx.has_beta = beta is not None
|
|
if not return_dmask:
|
|
return (
|
|
zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
|
|
)
|
|
else:
|
|
dmask = (
|
|
dmask.view(x0.shape)
|
|
if dropout_p > 0.0
|
|
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
|
|
)
|
|
ctx.mark_non_differentiable(dmask)
|
|
return (
|
|
(zmat.view(x0.shape), dmask)
|
|
if not prenorm
|
|
else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
|
|
)
|
|
|
|
@staticmethod
|
|
def backward(ctx, dz, *args):
|
|
# assert dz.is_contiguous()
|
|
dz = maybe_align(dz.contiguous(), 16) # this happens!
|
|
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
|
|
x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
|
|
# x0 is None if colscale is None
|
|
dropout_p = ctx.dropout_p
|
|
has_residual = ctx.has_residual
|
|
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
|
|
dz,
|
|
dx,
|
|
x,
|
|
x0,
|
|
dmask,
|
|
mu,
|
|
rsigma,
|
|
gamma,
|
|
rowscale,
|
|
colscale,
|
|
dropout_p,
|
|
has_residual,
|
|
ctx.is_rms_norm,
|
|
)
|
|
dx0 = dx0mat.view(x.shape)
|
|
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
|
|
dcolscale = rest[0] if colscale is not None else None
|
|
return (
|
|
dx0,
|
|
dresidual,
|
|
dgamma,
|
|
dbeta if ctx.has_beta else None,
|
|
None,
|
|
dcolscale,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
)
|
|
|
|
|
|
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(
|
|
ctx,
|
|
x0,
|
|
residual,
|
|
gamma,
|
|
beta,
|
|
colscale,
|
|
x0_subset,
|
|
out_subset,
|
|
dropout_p,
|
|
epsilon,
|
|
rowscale_const,
|
|
out_numrows,
|
|
residual_in_fp32=False,
|
|
prenorm=False,
|
|
is_rms_norm=False,
|
|
return_dmask=False,
|
|
):
|
|
x0 = maybe_align(x0.contiguous(), 16)
|
|
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
|
|
gamma = maybe_align(gamma.contiguous(), 16)
|
|
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
|
|
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
|
|
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
|
|
x0,
|
|
residual,
|
|
gamma,
|
|
beta,
|
|
colscale,
|
|
x0_subset,
|
|
out_subset,
|
|
dropout_p,
|
|
epsilon,
|
|
rowscale_const,
|
|
out_numrows,
|
|
residual_in_fp32,
|
|
is_rms_norm,
|
|
)
|
|
# Only need to save x0 if we need to compute gradient wrt colscale
|
|
x0_saved = x0 if colscale is not None else None
|
|
x_shape = (-1, *x0.shape[1:])
|
|
ctx.save_for_backward(
|
|
xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
|
|
)
|
|
ctx.prenorm = prenorm
|
|
ctx.dropout_p = dropout_p
|
|
ctx.rowscale_const = rowscale_const
|
|
ctx.x0_numrows = x0.shape[:-1].numel()
|
|
ctx.has_residual = residual is not None
|
|
ctx.is_rms_norm = is_rms_norm
|
|
ctx.has_beta = beta is not None
|
|
z_shape = (-1, *x0.shape[1:])
|
|
if not return_dmask:
|
|
return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
|
|
else:
|
|
z = zmat.view(z_shape)
|
|
dmask = (
|
|
dmask.view(x0.shape)
|
|
if dropout_p > 0.0
|
|
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
|
|
)
|
|
ctx.mark_non_differentiable(dmask)
|
|
return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)
|
|
|
|
@staticmethod
|
|
def backward(ctx, dz, *args):
|
|
# assert dz.is_contiguous()
|
|
dz = maybe_align(dz.contiguous(), 16) # this happens!
|
|
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
|
|
x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
|
|
# x0 is None if colscale is None
|
|
dropout_p = ctx.dropout_p
|
|
has_residual = ctx.has_residual
|
|
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
|
|
dz,
|
|
dx,
|
|
x,
|
|
x0,
|
|
dmask,
|
|
mu,
|
|
rsigma,
|
|
gamma,
|
|
colscale,
|
|
x0_subset,
|
|
out_subset,
|
|
dropout_p,
|
|
ctx.rowscale_const,
|
|
ctx.x0_numrows,
|
|
has_residual,
|
|
ctx.is_rms_norm,
|
|
)
|
|
dx0 = dx0mat.view(-1, *x.shape[1:])
|
|
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
|
|
dcolscale = rest[0] if colscale is not None else None
|
|
return (
|
|
dx0,
|
|
dresidual,
|
|
dgamma,
|
|
dbeta if ctx.has_beta else None,
|
|
dcolscale,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
)
|
|
|
|
|
|
class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(
|
|
ctx,
|
|
x0,
|
|
x1,
|
|
residual,
|
|
gamma0,
|
|
beta0,
|
|
gamma1,
|
|
beta1,
|
|
dropout_p,
|
|
epsilon,
|
|
residual_in_fp32=False,
|
|
prenorm=False,
|
|
is_rms_norm=False,
|
|
return_dmask=False,
|
|
):
|
|
x0 = maybe_align(x0.contiguous(), 16)
|
|
x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
|
|
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
|
|
gamma0 = maybe_align(gamma0.contiguous(), 16)
|
|
beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
|
|
gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
|
|
beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
|
|
(
|
|
z0mat,
|
|
z1mat,
|
|
xmat,
|
|
dmask0,
|
|
dmask1,
|
|
mu,
|
|
rsigma,
|
|
) = _dropout_add_layer_norm_parallel_residual_forward(
|
|
x0,
|
|
x1,
|
|
residual,
|
|
gamma0,
|
|
beta0,
|
|
gamma1,
|
|
beta1,
|
|
dropout_p,
|
|
epsilon,
|
|
residual_in_fp32,
|
|
is_rms_norm,
|
|
)
|
|
ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
|
|
ctx.prenorm = prenorm
|
|
ctx.dropout_p = dropout_p
|
|
ctx.has_x1 = x1 is not None
|
|
ctx.has_residual = residual is not None
|
|
ctx.is_rms_norm = is_rms_norm
|
|
ctx.has_beta = beta0 is not None
|
|
z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
|
|
if not return_dmask:
|
|
return z if not prenorm else (*z, xmat.view(x0.shape))
|
|
else:
|
|
dmask0 = (
|
|
dmask0.view(x0.shape)
|
|
if dropout_p > 0.0
|
|
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
|
|
)
|
|
dmask1 = (
|
|
dmask1.view(x0.shape)
|
|
if dropout_p > 0.0 and x1 is not None
|
|
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
|
|
)
|
|
ctx.mark_non_differentiable(dmask0)
|
|
ctx.mark_non_differentiable(dmask1)
|
|
return (
|
|
(*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
|
|
)
|
|
|
|
@staticmethod
|
|
def backward(ctx, dz0, dz1, *args):
|
|
dz0 = maybe_align(dz0.contiguous(), 16) # this happens!
|
|
dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
|
|
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
|
|
x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
|
|
dropout_p = ctx.dropout_p
|
|
has_x1 = ctx.has_x1
|
|
has_residual = ctx.has_residual
|
|
(
|
|
dx0mat,
|
|
dx1mat,
|
|
dresidualmat,
|
|
dgamma0,
|
|
dbeta0,
|
|
dgamma1,
|
|
dbeta1,
|
|
) = _dropout_add_layer_norm_parallel_residual_backward(
|
|
dz0,
|
|
dz1,
|
|
dx,
|
|
x,
|
|
dmask0,
|
|
dmask1,
|
|
mu,
|
|
rsigma,
|
|
gamma0,
|
|
gamma1,
|
|
dropout_p,
|
|
has_x1,
|
|
has_residual,
|
|
ctx.is_rms_norm,
|
|
)
|
|
dx0 = dx0mat.view(x.shape)
|
|
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
|
|
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
|
|
return (
|
|
dx0,
|
|
dx1,
|
|
dresidual,
|
|
dgamma0,
|
|
dbeta0 if ctx.has_beta else None,
|
|
dgamma1,
|
|
dbeta1 if ctx.has_beta else None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
)
|
|
|
|
|
|
def layer_norm(x, weight, bias, epsilon):
|
|
return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
|
|
|
|
|
|
def dropout_add_layer_norm(
|
|
x0,
|
|
residual,
|
|
weight,
|
|
bias,
|
|
dropout_p,
|
|
epsilon,
|
|
rowscale=None,
|
|
layerscale=None,
|
|
prenorm=False,
|
|
residual_in_fp32=False,
|
|
return_dropout_mask=False,
|
|
):
|
|
"""residual_in_fp32 only has an effect if residual is None.
|
|
Otherwise residual dtype is residual.dtype.
|
|
"""
|
|
return DropoutAddLayerNormFn.apply(
|
|
x0,
|
|
residual,
|
|
weight,
|
|
bias,
|
|
rowscale,
|
|
layerscale,
|
|
dropout_p,
|
|
epsilon,
|
|
residual_in_fp32,
|
|
prenorm,
|
|
False,
|
|
return_dropout_mask,
|
|
)
|
|
|
|
|
|
def dropout_add_layer_norm_subset(
|
|
x0,
|
|
residual,
|
|
weight,
|
|
bias,
|
|
dropout_p,
|
|
epsilon,
|
|
layerscale=None,
|
|
x0_subset=None,
|
|
out_subset=None,
|
|
rowscale_const=1.0,
|
|
out_numrows=0,
|
|
prenorm=False,
|
|
residual_in_fp32=False,
|
|
return_dropout_mask=False,
|
|
):
|
|
"""residual_in_fp32 only has an effect if residual is None.
|
|
Otherwise residual dtype is residual.dtype.
|
|
"""
|
|
return DropoutAddLayerNormSubsetFn.apply(
|
|
x0,
|
|
residual,
|
|
weight,
|
|
bias,
|
|
layerscale,
|
|
x0_subset,
|
|
out_subset,
|
|
dropout_p,
|
|
epsilon,
|
|
rowscale_const,
|
|
out_numrows,
|
|
residual_in_fp32,
|
|
prenorm,
|
|
False,
|
|
return_dropout_mask,
|
|
)
|
|
|
|
|
|
def dropout_add_layer_norm_parallel_residual(
|
|
x0,
|
|
x1,
|
|
residual,
|
|
weight0,
|
|
bias0,
|
|
weight1,
|
|
bias1,
|
|
dropout_p,
|
|
epsilon,
|
|
prenorm=False,
|
|
residual_in_fp32=False,
|
|
return_dropout_mask=False,
|
|
):
|
|
"""residual_in_fp32 only has an effect if residual is None.
|
|
Otherwise residual dtype is residual.dtype.
|
|
"""
|
|
return DropoutAddLayerNormParallelResidualFn.apply(
|
|
x0,
|
|
x1,
|
|
residual,
|
|
weight0,
|
|
bias0,
|
|
weight1,
|
|
bias1,
|
|
dropout_p,
|
|
epsilon,
|
|
residual_in_fp32,
|
|
prenorm,
|
|
False,
|
|
return_dropout_mask,
|
|
)
|
|
|
|
|
|
class DropoutAddLayerNorm(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
hidden_size,
|
|
prenorm=False,
|
|
p=0.0,
|
|
eps=1e-5,
|
|
residual_in_fp32=False,
|
|
device=None,
|
|
dtype=None,
|
|
):
|
|
factory_kwargs = {"device": device, "dtype": dtype}
|
|
super().__init__()
|
|
self.prenorm = prenorm
|
|
self.p = p
|
|
self.eps = eps
|
|
self.residual_in_fp32 = residual_in_fp32
|
|
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
|
self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
|
self.reset_parameters()
|
|
|
|
def reset_parameters(self):
|
|
init.ones_(self.weight)
|
|
init.zeros_(self.bias)
|
|
|
|
def forward(self, x0, residual=None):
|
|
return dropout_add_layer_norm(
|
|
x0,
|
|
residual,
|
|
self.weight,
|
|
self.bias,
|
|
self.p if self.training else 0.0,
|
|
self.eps,
|
|
prenorm=self.prenorm,
|
|
residual_in_fp32=self.residual_in_fp32,
|
|
)
|