[LayerNorm] Rename x1 -> residual

Tri Dao 2023-01-19 13:07:27 -08:00
parent f68d41ec77
commit eb33e587e9
7 changed files with 89 additions and 88 deletions

View File

@@ -59,7 +59,7 @@ struct ParamsBase {
// Common data pointers.
void *x0;
-void *x1;
+void *residual;
void *x;
void *dmask;
void *mu;
@@ -117,7 +117,7 @@ struct BwdParams : public ParamsBase {
, dgamma_part(nullptr)
, dcolscale_part(nullptr)
, dx0(nullptr)
-, dx1(nullptr)
+, dresidual(nullptr)
, dbeta(nullptr)
, dgamma(nullptr)
, dcolscale(nullptr)
@@ -136,7 +136,7 @@ struct BwdParams : public ParamsBase {
// Output: Dgrad.
void *dx0;
-void *dx1;
+void *dresidual;
// Output: Wgrad.
void *dbeta;
void *dgamma;

View File

@@ -81,7 +81,7 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size
-c10::optional<const at::Tensor> &x1_, // Residual: BxSxhidden_size
+c10::optional<const at::Tensor> &residual_, // Residual: BxSxhidden_size
const at::Tensor &gamma, // hidden_size
c10::optional<const at::Tensor> &beta_, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
@@ -97,8 +97,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
bool is_rms_norm=false
) {
auto itype = x0.scalar_type();
-auto rtype = x1_.has_value()
-? x1_.value().scalar_type()
+auto rtype = residual_.has_value()
+? residual_.value().scalar_type()
: (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
auto wtype = gamma.scalar_type();
auto otype = itype;
@@ -129,11 +129,11 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
TORCH_CHECK(gamma.sizes() == beta.sizes());
}
-if (x1_.has_value()) {
-auto x1 = x1_.value();
-TORCH_CHECK(x1.is_cuda())
-TORCH_CHECK(x1.is_contiguous());
-TORCH_CHECK(x1.sizes() == sizes);
+if (residual_.has_value()) {
+auto residual = residual_.value();
+TORCH_CHECK(residual.is_cuda())
+TORCH_CHECK(residual.is_contiguous());
+TORCH_CHECK(residual.sizes() == sizes);
}
if (rowscale_.has_value()) {
@@ -178,7 +178,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
auto opts = x0.options();
-bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
+bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
at::Tensor x;
if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
at::Tensor dmask;
@@ -194,7 +194,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(dropout_p < 1.f);
launch_params.params.dropout_keep_p = 1.f - dropout_p;
-launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
+launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
@@ -383,8 +383,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
auto opts = x.options();
auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
-at::Tensor dx1;
-if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
+at::Tensor dresidual;
+if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
auto dgamma = torch::empty_like(gamma);
auto dbeta = torch::empty_like(gamma);
at::Tensor dcolscale;
@@ -397,7 +397,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
launch_params.props = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dropout_p < 1.f);
launch_params.params.dropout_keep_p = 1.f - dropout_p;
-launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
+launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
@@ -450,7 +450,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
launcher(launch_params, false);
-std::vector<at::Tensor> result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
+std::vector<at::Tensor> result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };
if (colscale_.has_value()) {
result.push_back(dcolscale);
result.push_back(dcolscale_part);
@@ -462,7 +462,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "CUDA DropoutAddLayerNorm";
m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
-py::arg("x0"), py::arg("x1"), py::arg("gamma"), py::arg("beta"),
+py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta"),
py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
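
Because the binding now exposes py::arg("residual") instead of py::arg("x1"), any caller that passes the residual by keyword has to be updated. Below is a minimal, illustrative sketch of a direct call to the renamed binding; it assumes the extension has been built and is importable as dropout_layer_norm, and the shapes, dtypes and hyperparameters are made up for the example.

    import torch
    import dropout_layer_norm  # assumed: the compiled extension from this directory

    hidden_size = 1024
    # Inputs are flattened to (rows, hidden_size) contiguous CUDA tensors.
    x0 = torch.randn(8 * 512, hidden_size, device='cuda', dtype=torch.float16)
    residual = torch.randn(8 * 512, hidden_size, device='cuda', dtype=torch.float32)
    gamma = torch.ones(hidden_size, device='cuda', dtype=torch.float16)
    beta = torch.zeros(hidden_size, device='cuda', dtype=torch.float16)

    # Positional layout mirrors the pybind signature above; rowscale_, colscale_,
    # x0_subset_, z_subset_ and gen_ are all left as None here.
    z, x, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0, residual, gamma, beta, None, None, None, None,
        0.1, 1e-5, 1.0, 0, None, False, False,
    )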

View File

@@ -37,7 +37,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
extern __shared__ char smem_[];
-const bool has_residual = params.dx1 != nullptr;
+const bool has_residual = params.dresidual != nullptr;
const bool prenorm = params.dx != nullptr;
const index_t tidx = threadIdx.x;
@@ -164,7 +164,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Ivec dx0;
-Rvec dx1;
+Rvec dresidual;
Ivec x0;
if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
#pragma unroll
@@ -178,7 +178,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
} else {
dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
}
-if (has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
+if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
if (save_dx0) {
compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
if (Is_dropout) {
@@ -199,7 +199,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
}
}
}
-if (has_residual) { dx1.store_to(params.dx1, idx_x); }
+if (has_residual) { dresidual.store_to(params.dresidual, idx_x); }
if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
idx_x += Ktraits::VEC_COLS_PER_LDG;
idx_x0 += Ktraits::VEC_COLS_PER_LDG;
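
The backward hunks above are a pure rename of the output buffer; the gradient written to it is unchanged. As a rough plain-Python restatement of what the kernel stores in dresidual versus dx0 (the helper and names below are illustrative, not part of the source; dx stands for the gradient w.r.t. the pre-LayerNorm sum that the LayerNorm backward has already produced):

    def residual_branch_backward(dx, dmask, rowscale, colscale, dropout_p, has_residual):
        # The residual add is an identity, so its gradient is dx itself.
        dresidual = dx if has_residual else None
        # dx0 additionally undoes colscale, the per-row rowscale, and the kept/rescaled dropout.
        dx0 = dx
        if colscale is not None:
            dx0 = dx0 * colscale
        if rowscale is not None:
            dx0 = dx0 * rowscale.unsqueeze(-1)
        if dropout_p > 0.0:
            dx0 = dx0 * dmask / (1.0 - dropout_p)
        return dx0, dresidual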

View File

@@ -46,7 +46,7 @@ void ln_fwd_kernel(FwdParams params) {
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
-const bool has_residual = params.x1 != nullptr;
+const bool has_residual = params.residual != nullptr;
const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);
extern __shared__ char smem_[];
@@ -111,11 +111,11 @@ void ln_fwd_kernel(FwdParams params) {
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Ivec x0;
-Rvec x1;
+Rvec residual;
Rvec x;
Mvec dmask;
if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
-if (has_residual) { x1.load_from(params.x1, idx_x); }
+if (has_residual) { residual.load_from(params.residual, idx_x); }
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
// TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
@@ -127,9 +127,9 @@ void ln_fwd_kernel(FwdParams params) {
compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
-x_ij = has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
+x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
} else {
-x_ij = has_residual ? compute_t(x1.data.elt[jt]) : 0.f;
+x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f;
}
if (save_x) { x.data.elt[jt] = x_ij; }
xf[it * NUM_ELTS + jt] = x_ij;
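
As with the backward pass, the forward hunks only rename the loaded vector; the fused computation is the same. A plain-PyTorch sketch of that computation, for reference only (it ignores the x0_subset path shown above, and all names are illustrative):

    import torch
    import torch.nn.functional as F

    def dropout_add_ln_reference(x0, residual, gamma, beta, rowscale, colscale,
                                 dropout_p, epsilon):
        x = x0.float()
        if rowscale is not None:
            x = x * rowscale.unsqueeze(-1)
        if dropout_p > 0.0:
            x = F.dropout(x, dropout_p)      # keeps and rescales by 1 / (1 - dropout_p)
        if colscale is not None:
            x = x * colscale
        if residual is not None:
            x = x + residual.float()         # the sum that gets saved as params.x
        z = F.layer_norm(x, (x.shape[-1],), gamma.float(), beta.float(), epsilon)
        return z.to(x0.dtype), x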

View File

@@ -292,7 +292,7 @@ class GPTModel(GPTPreTrainedModel):
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
else:
-# Set prenorm=False here since we don't need to the residual
+# Set prenorm=False here since we don't need the residual
hidden_states = dropout_add_layer_norm(
hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
@@ -359,7 +359,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
# Previous: Attn / MLP -> Dropout -> Add -> LN
# Current: Dropout -> Add -> LN -> Attn / MLP
if 'transformer.ln_0.weight' in state_dict:
-n_layers = self.config.num_hidden_layers
+n_layers = len(self.transformer.layers)
ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
state_dict['transformer.ln_f.weight'] = ln_weight

View File

@@ -7,20 +7,20 @@ from torch.nn import init
import dropout_layer_norm
-def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
-residual_in_fp32=False, is_rms_norm=False):
+def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
+epsilon, residual_in_fp32=False, is_rms_norm=False):
""" Assume that arguments are contiguous
"""
hidden_size = gamma.numel()
x0mat = x0.view((-1, hidden_size))
-x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+residualmat = residual.view((-1, hidden_size)) if residual is not None else None
rowscale = rowscale.view(-1) if rowscale is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
+x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
1.0, 0, None, residual_in_fp32, is_rms_norm
)
# dmask is None if dropout_p == 0.0
-# xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
@@ -28,7 +28,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
dropout_p, has_residual, is_rms_norm=False):
""" Assume that arguments are contiguous
dx == None means that it was a post-norm architecture
-(x = drop(x0) + x1 was not returned in the fwd).
+(x = drop(x0) + residual was not returned in the fwd).
x0 must not be None if we have colscale.
"""
hidden_size = gamma.numel()
@@ -39,34 +39,34 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
rowscale = rowscale.view(-1) if rowscale is not None else None
if colscale is not None:
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
-dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
dropout_p, 1.0, 0, has_residual, is_rms_norm
)
-# dx1mat is None if not has_residual
+# dresidualmat is None if not has_residual
if colscale is None:
-return dx0mat, dx1mat, dgamma, dbeta
+return dx0mat, dresidualmat, dgamma, dbeta
else:
dcolscale = rest[0]
-return dx0mat, dx1mat, dgamma, dbeta, dcolscale
+return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
-def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset,
-dropout_p, epsilon, rowscale_const, out_numrows,
-residual_in_fp32=False, is_rms_norm=False):
+def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
+out_subset, dropout_p, epsilon, rowscale_const,
+out_numrows, residual_in_fp32=False, is_rms_norm=False):
""" Assume that arguments are contiguous
"""
hidden_size = gamma.numel()
x0mat = x0.view((-1, hidden_size))
-x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+residualmat = residual.view((-1, hidden_size)) if residual is not None else None
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
out_subset = out_subset.view(-1) if out_subset is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
-x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
+x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
)
# dmask is None if dropout_p == 0.0
-# xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
+# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
@@ -75,7 +75,7 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
x0_numrows, has_residual, is_rms_norm=False):
""" Assume that arguments are contiguous
dx == None means that it was a post-norm architecture
-(x = drop(x0) + x1 was not returned in the fwd).
+(x = drop(x0) + residual was not returned in the fwd).
x0 must not be None if we have colscale.
"""
hidden_size = gamma.numel()
@@ -87,30 +87,30 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
out_subset = out_subset.view(-1) if out_subset is not None else None
if colscale is not None:
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
-dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
)
-# dx1mat is None if not has_residual
+# dresidualmat is None if not has_residual
if colscale is None:
-return dx0mat, dx1mat, dgamma, dbeta
+return dx0mat, dresidualmat, dgamma, dbeta
else:
dcolscale = rest[0]
-return dx0mat, dx1mat, dgamma, dbeta, dcolscale
+return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
class DropoutAddLayerNormFn(torch.autograd.Function):
@staticmethod
-def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
x0 = x0.contiguous()
-x1 = x1.contiguous() if x1 is not None else None
+residual = residual.contiguous() if residual is not None else None
gamma = gamma.contiguous()
beta = beta.contiguous() if beta is not None else None
rowscale = rowscale.contiguous() if rowscale is not None else None
colscale = colscale.contiguous() if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
-x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
+x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
residual_in_fp32, is_rms_norm
)
# Only need to save x0 if we need to compute gradient wrt colscale
@@ -118,7 +118,7 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
-ctx.has_residual = x1 is not None
+ctx.has_residual = residual is not None
ctx.is_rms_norm = is_rms_norm
ctx.has_beta = beta is not None
if not return_dmask:
@@ -140,29 +140,29 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
# x0 is None if colscale is None
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
-dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
+dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
ctx.is_rms_norm
)
dx0 = dx0mat.view(x.shape)
-dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
dcolscale = rest[0] if colscale is not None else None
-return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None,
-None, None, None, None)
+return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None,
+None, None, None, None, None)
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
@staticmethod
-def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32=False,
prenorm=False, is_rms_norm=False, return_dmask=False):
x0 = x0.contiguous()
-x1 = x1.contiguous() if x1 is not None else None
+residual = residual.contiguous() if residual is not None else None
gamma = gamma.contiguous()
beta = beta.contiguous() if beta is not None else None
colscale = colscale.contiguous() if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
-x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
+x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
)
# Only need to save x0 if we need to compute gradient wrt colscale
@@ -174,7 +174,7 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
ctx.dropout_p = dropout_p
ctx.rowscale_const = rowscale_const
ctx.x0_numrows = x0.shape[:-1].numel()
-ctx.has_residual = x1 is not None
+ctx.has_residual = residual is not None
ctx.is_rms_norm = is_rms_norm
ctx.has_beta = beta is not None
z_shape = (-1, *x0.shape[1:])
@@ -197,42 +197,42 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
# x0 is None if colscale is None
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
-dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
+dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
)
dx0 = dx0mat.view(-1, *x.shape[1:])
-dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
dcolscale = rest[0] if colscale is not None else None
-return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None,
-None, None, None, None, None, None, None)
+return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
+None, None, None, None, None, None, None, None)
def layer_norm(x, weight, bias, epsilon):
return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
-def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
-prenorm=False, residual_in_fp32=False,
+def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
+layerscale=None, prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
-"""residual_in_fp32 only has an effect if x1 is None.
-Otherwise residual dtype is x1.dtype.
+"""residual_in_fp32 only has an effect if residual is None.
+Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormFn.apply(
-x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
+x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
False, return_dropout_mask
)
-def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
+def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
x0_subset=None, out_subset=None, rowscale_const=1.0,
out_numrows=0, prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
-"""residual_in_fp32 only has an effect if x1 is None.
-Otherwise residual dtype is x1.dtype.
+"""residual_in_fp32 only has an effect if residual is None.
+Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormSubsetFn.apply(
-x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
+x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
)
@@ -254,7 +254,7 @@ class DropoutAddLayerNorm(torch.nn.Module):
init.ones_(self.weight)
init.zeros_(self.bias)
-def forward(self, x0, x1=None):
-return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
+def forward(self, x0, residual=None):
+return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
self.p if self.training else 0.0, self.epsilon,
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
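
The net effect for Python users of this file is just the new argument name. A hedged usage sketch of the functional wrapper (the import path is assumed to be flash_attn.ops.layer_norm, and the shapes and dtypes are illustrative; with prenorm=True the call is expected to also return the updated residual stream):

    import torch
    from flash_attn.ops.layer_norm import dropout_add_layer_norm

    batch, seqlen, hidden_size = 2, 512, 1024
    x0 = torch.randn(batch, seqlen, hidden_size, device='cuda', dtype=torch.float16)
    residual = torch.randn(batch, seqlen, hidden_size, device='cuda', dtype=torch.float32)
    weight = torch.ones(hidden_size, device='cuda', dtype=torch.float16)
    bias = torch.zeros(hidden_size, device='cuda', dtype=torch.float16)

    # The second argument, previously called x1, is now residual.
    out, new_residual = dropout_add_layer_norm(
        x0, residual, weight, bias, 0.1, 1e-5, prenorm=True,
    )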

View File

@@ -12,26 +12,27 @@ def rms_norm(x, weight, epsilon):
False, True)
-def dropout_add_rms_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
-prenorm=False, residual_in_fp32=False, return_dropout_mask=False):
-"""residual_in_fp32 only has an effect if x1 is None.
-Otherwise residual dtype is x1.dtype.
+def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
+layerscale=None, prenorm=False, residual_in_fp32=False,
+return_dropout_mask=False):
+"""residual_in_fp32 only has an effect if residual is None.
+Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormFn.apply(
-x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
+x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
True, return_dropout_mask
)
-def dropout_add_rms_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
+def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
x0_subset=None, out_subset=None, rowscale_const=1.0,
out_numrows=0, prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
-"""residual_in_fp32 only has an effect if x1 is None.
-Otherwise residual dtype is x1.dtype.
+"""residual_in_fp32 only has an effect if residual is None.
+Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormSubsetFn.apply(
-x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
+x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
)
@@ -52,7 +53,7 @@ class DropoutAddRMSNorm(torch.nn.Module):
def reset_parameters(self):
init.ones_(self.weight)
-def forward(self, x0, x1=None):
-return dropout_add_rms_norm(x0, x1, self.weight, None,
+def forward(self, x0, residual=None):
+return dropout_add_rms_norm(x0, residual, self.weight, None,
self.p if self.training else 0.0, self.epsilon,
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
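
The RMSNorm wrappers mirror the LayerNorm ones, with the bias passed as None. A brief sketch under the same assumptions (import path flash_attn.ops.rms_norm assumed, shapes illustrative):

    import torch
    from flash_attn.ops.rms_norm import dropout_add_rms_norm

    hidden_size = 1024
    x0 = torch.randn(2, 512, hidden_size, device='cuda', dtype=torch.float16)
    residual = torch.randn(2, 512, hidden_size, device='cuda', dtype=torch.float32)
    weight = torch.ones(hidden_size, device='cuda', dtype=torch.float16)

    # RMSNorm has no bias, so bias is passed as None (as DropoutAddRMSNorm.forward does above).
    out = dropout_add_rms_norm(x0, residual, weight, None, 0.1, 1e-5)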