[LayerNorm] Rename x1 -> residual
This commit is contained in:
parent
f68d41ec77
commit
eb33e587e9
@ -59,7 +59,7 @@ struct ParamsBase {
|
||||
|
||||
// Common data pointers.
|
||||
void *x0;
|
||||
void *x1;
|
||||
void *residual;
|
||||
void *x;
|
||||
void *dmask;
|
||||
void *mu;
|
||||
@ -117,7 +117,7 @@ struct BwdParams : public ParamsBase {
|
||||
, dgamma_part(nullptr)
|
||||
, dcolscale_part(nullptr)
|
||||
, dx0(nullptr)
|
||||
, dx1(nullptr)
|
||||
, dresidual(nullptr)
|
||||
, dbeta(nullptr)
|
||||
, dgamma(nullptr)
|
||||
, dcolscale(nullptr)
|
||||
@ -136,7 +136,7 @@ struct BwdParams : public ParamsBase {
|
||||
|
||||
// Output: Dgrad.
|
||||
void *dx0;
|
||||
void *dx1;
|
||||
void *dresidual;
|
||||
// Output: Wgrad.
|
||||
void *dbeta;
|
||||
void *dgamma;
|
||||
|
||||
@ -81,7 +81,7 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size
|
||||
c10::optional<const at::Tensor> &x1_, // Residual: BxSxhidden_size
|
||||
c10::optional<const at::Tensor> &residual_, // Residual: BxSxhidden_size
|
||||
const at::Tensor &gamma, // hidden_size
|
||||
c10::optional<const at::Tensor> &beta_, // hidden_size
|
||||
c10::optional<const at::Tensor> &rowscale_, // BxS
|
||||
@ -97,8 +97,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
|
||||
bool is_rms_norm=false
|
||||
) {
|
||||
auto itype = x0.scalar_type();
|
||||
auto rtype = x1_.has_value()
|
||||
? x1_.value().scalar_type()
|
||||
auto rtype = residual_.has_value()
|
||||
? residual_.value().scalar_type()
|
||||
: (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
|
||||
auto wtype = gamma.scalar_type();
|
||||
auto otype = itype;
|
||||
@ -129,11 +129,11 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
|
||||
TORCH_CHECK(gamma.sizes() == beta.sizes());
|
||||
}
|
||||
|
||||
if (x1_.has_value()) {
|
||||
auto x1 = x1_.value();
|
||||
TORCH_CHECK(x1.is_cuda())
|
||||
TORCH_CHECK(x1.is_contiguous());
|
||||
TORCH_CHECK(x1.sizes() == sizes);
|
||||
if (residual_.has_value()) {
|
||||
auto residual = residual_.value();
|
||||
TORCH_CHECK(residual.is_cuda())
|
||||
TORCH_CHECK(residual.is_contiguous());
|
||||
TORCH_CHECK(residual.sizes() == sizes);
|
||||
}
|
||||
|
||||
if (rowscale_.has_value()) {
|
||||
@ -178,7 +178,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
|
||||
|
||||
auto opts = x0.options();
|
||||
|
||||
bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
|
||||
bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
|
||||
at::Tensor x;
|
||||
if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
|
||||
at::Tensor dmask;
|
||||
@ -194,7 +194,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
|
||||
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
TORCH_CHECK(dropout_p < 1.f);
|
||||
launch_params.params.dropout_keep_p = 1.f - dropout_p;
|
||||
launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
|
||||
launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;
|
||||
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
|
||||
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
|
||||
launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
|
||||
@ -383,8 +383,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
|
||||
auto opts = x.options();
|
||||
|
||||
auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
|
||||
at::Tensor dx1;
|
||||
if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
|
||||
at::Tensor dresidual;
|
||||
if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
|
||||
auto dgamma = torch::empty_like(gamma);
|
||||
auto dbeta = torch::empty_like(gamma);
|
||||
at::Tensor dcolscale;
|
||||
@ -397,7 +397,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
|
||||
launch_params.props = at::cuda::getCurrentDeviceProperties();
|
||||
TORCH_CHECK(dropout_p < 1.f);
|
||||
launch_params.params.dropout_keep_p = 1.f - dropout_p;
|
||||
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
|
||||
launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
|
||||
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
|
||||
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
|
||||
launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
|
||||
@ -450,7 +450,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
|
||||
|
||||
launcher(launch_params, false);
|
||||
|
||||
std::vector<at::Tensor> result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
|
||||
std::vector<at::Tensor> result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };
|
||||
if (colscale_.has_value()) {
|
||||
result.push_back(dcolscale);
|
||||
result.push_back(dcolscale_part);
|
||||
@ -462,7 +462,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "CUDA DropoutAddLayerNorm";
|
||||
m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
|
||||
py::arg("x0"), py::arg("x1"), py::arg("gamma"), py::arg("beta"),
|
||||
py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta"),
|
||||
py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
|
||||
py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
|
||||
py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
|
||||
|
||||
@ -37,7 +37,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
|
||||
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
const bool has_residual = params.dx1 != nullptr;
|
||||
const bool has_residual = params.dresidual != nullptr;
|
||||
const bool prenorm = params.dx != nullptr;
|
||||
|
||||
const index_t tidx = threadIdx.x;
|
||||
@ -164,7 +164,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
if (Is_even_cols || (it < num_valid_ldgs)) {
|
||||
Ivec dx0;
|
||||
Rvec dx1;
|
||||
Rvec dresidual;
|
||||
Ivec x0;
|
||||
if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
|
||||
#pragma unroll
|
||||
@ -178,7 +178,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
|
||||
} else {
|
||||
dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
|
||||
}
|
||||
if (has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
|
||||
if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }
|
||||
if (save_dx0) {
|
||||
compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
|
||||
if (Is_dropout) {
|
||||
@ -199,7 +199,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if (has_residual) { dx1.store_to(params.dx1, idx_x); }
|
||||
if (has_residual) { dresidual.store_to(params.dresidual, idx_x); }
|
||||
if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
|
||||
idx_x += Ktraits::VEC_COLS_PER_LDG;
|
||||
idx_x0 += Ktraits::VEC_COLS_PER_LDG;
|
||||
|
||||
@ -46,7 +46,7 @@ void ln_fwd_kernel(FwdParams params) {
|
||||
using Stats = typename Ktraits::Stats;
|
||||
using stats_t = typename Stats::stats_t;
|
||||
|
||||
const bool has_residual = params.x1 != nullptr;
|
||||
const bool has_residual = params.residual != nullptr;
|
||||
const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);
|
||||
|
||||
extern __shared__ char smem_[];
|
||||
@ -111,11 +111,11 @@ void ln_fwd_kernel(FwdParams params) {
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
if (Is_even_cols || (it < num_valid_ldgs)) {
|
||||
Ivec x0;
|
||||
Rvec x1;
|
||||
Rvec residual;
|
||||
Rvec x;
|
||||
Mvec dmask;
|
||||
if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
|
||||
if (has_residual) { x1.load_from(params.x1, idx_x); }
|
||||
if (has_residual) { residual.load_from(params.residual, idx_x); }
|
||||
#pragma unroll
|
||||
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
|
||||
// TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
|
||||
@ -127,9 +127,9 @@ void ln_fwd_kernel(FwdParams params) {
|
||||
compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
|
||||
x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
|
||||
if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
|
||||
x_ij = has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
|
||||
x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
|
||||
} else {
|
||||
x_ij = has_residual ? compute_t(x1.data.elt[jt]) : 0.f;
|
||||
x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f;
|
||||
}
|
||||
if (save_x) { x.data.elt[jt] = x_ij; }
|
||||
xf[it * NUM_ELTS + jt] = x_ij;
|
||||
|
||||
@ -292,7 +292,7 @@ class GPTModel(GPTPreTrainedModel):
|
||||
residual = (dropped + residual) if residual is not None else dropped
|
||||
hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
|
||||
else:
|
||||
# Set prenorm=False here since we don't need to the residual
|
||||
# Set prenorm=False here since we don't need the residual
|
||||
hidden_states = dropout_add_layer_norm(
|
||||
hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
|
||||
self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
|
||||
@ -359,7 +359,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
|
||||
# Previous: Attn / MLP -> Dropout -> Add -> LN
|
||||
# Current: Dropout -> Add -> LN -> Attn / MLP
|
||||
if 'transformer.ln_0.weight' in state_dict:
|
||||
n_layers = self.config.num_hidden_layers
|
||||
n_layers = len(self.transformer.layers)
|
||||
ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
|
||||
ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
|
||||
state_dict['transformer.ln_f.weight'] = ln_weight
|
||||
|
||||
@ -7,20 +7,20 @@ from torch.nn import init
|
||||
import dropout_layer_norm
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
|
||||
residual_in_fp32=False, is_rms_norm=False):
|
||||
def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
|
||||
epsilon, residual_in_fp32=False, is_rms_norm=False):
|
||||
""" Assume that arguments are contiguous
|
||||
"""
|
||||
hidden_size = gamma.numel()
|
||||
x0mat = x0.view((-1, hidden_size))
|
||||
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
|
||||
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
|
||||
rowscale = rowscale.view(-1) if rowscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
|
||||
x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
|
||||
1.0, 0, None, residual_in_fp32, is_rms_norm
|
||||
)
|
||||
# dmask is None if dropout_p == 0.0
|
||||
# xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
|
||||
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
|
||||
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
|
||||
dropout_p, has_residual, is_rms_norm=False):
|
||||
""" Assume that arguments are contiguous
|
||||
dx == None means that it was a post-norm architecture
|
||||
(x = drop(x0) + x1 was not returned in the fwd).
|
||||
(x = drop(x0) + residual was not returned in the fwd).
|
||||
x0 must not be None if we have colscale.
|
||||
"""
|
||||
hidden_size = gamma.numel()
|
||||
@ -39,34 +39,34 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
|
||||
rowscale = rowscale.view(-1) if rowscale is not None else None
|
||||
if colscale is not None:
|
||||
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
|
||||
dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
|
||||
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
|
||||
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
|
||||
dropout_p, 1.0, 0, has_residual, is_rms_norm
|
||||
)
|
||||
# dx1mat is None if not has_residual
|
||||
# dresidualmat is None if not has_residual
|
||||
if colscale is None:
|
||||
return dx0mat, dx1mat, dgamma, dbeta
|
||||
return dx0mat, dresidualmat, dgamma, dbeta
|
||||
else:
|
||||
dcolscale = rest[0]
|
||||
return dx0mat, dx1mat, dgamma, dbeta, dcolscale
|
||||
return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset,
|
||||
dropout_p, epsilon, rowscale_const, out_numrows,
|
||||
residual_in_fp32=False, is_rms_norm=False):
|
||||
def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
|
||||
out_subset, dropout_p, epsilon, rowscale_const,
|
||||
out_numrows, residual_in_fp32=False, is_rms_norm=False):
|
||||
""" Assume that arguments are contiguous
|
||||
"""
|
||||
hidden_size = gamma.numel()
|
||||
x0mat = x0.view((-1, hidden_size))
|
||||
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
|
||||
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
|
||||
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
|
||||
out_subset = out_subset.view(-1) if out_subset is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
|
||||
)
|
||||
# dmask is None if dropout_p == 0.0
|
||||
# xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
|
||||
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
|
||||
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
|
||||
|
||||
|
||||
@ -75,7 +75,7 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
|
||||
x0_numrows, has_residual, is_rms_norm=False):
|
||||
""" Assume that arguments are contiguous
|
||||
dx == None means that it was a post-norm architecture
|
||||
(x = drop(x0) + x1 was not returned in the fwd).
|
||||
(x = drop(x0) + residual was not returned in the fwd).
|
||||
x0 must not be None if we have colscale.
|
||||
"""
|
||||
hidden_size = gamma.numel()
|
||||
@ -87,30 +87,30 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
|
||||
out_subset = out_subset.view(-1) if out_subset is not None else None
|
||||
if colscale is not None:
|
||||
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
|
||||
dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
|
||||
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
|
||||
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
|
||||
dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
|
||||
)
|
||||
# dx1mat is None if not has_residual
|
||||
# dresidualmat is None if not has_residual
|
||||
if colscale is None:
|
||||
return dx0mat, dx1mat, dgamma, dbeta
|
||||
return dx0mat, dresidualmat, dgamma, dbeta
|
||||
else:
|
||||
dcolscale = rest[0]
|
||||
return dx0mat, dx1mat, dgamma, dbeta, dcolscale
|
||||
return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
|
||||
|
||||
|
||||
class DropoutAddLayerNormFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
|
||||
def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
|
||||
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
|
||||
x0 = x0.contiguous()
|
||||
x1 = x1.contiguous() if x1 is not None else None
|
||||
residual = residual.contiguous() if residual is not None else None
|
||||
gamma = gamma.contiguous()
|
||||
beta = beta.contiguous() if beta is not None else None
|
||||
rowscale = rowscale.contiguous() if rowscale is not None else None
|
||||
colscale = colscale.contiguous() if colscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
|
||||
x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
|
||||
x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
|
||||
residual_in_fp32, is_rms_norm
|
||||
)
|
||||
# Only need to save x0 if we need to compute gradient wrt colscale
|
||||
@ -118,7 +118,7 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
|
||||
ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
|
||||
ctx.prenorm = prenorm
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.has_residual = x1 is not None
|
||||
ctx.has_residual = residual is not None
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
ctx.has_beta = beta is not None
|
||||
if not return_dmask:
|
||||
@ -140,29 +140,29 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
|
||||
# x0 is None if colscale is None
|
||||
dropout_p = ctx.dropout_p
|
||||
has_residual = ctx.has_residual
|
||||
dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
|
||||
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
|
||||
dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
|
||||
ctx.is_rms_norm
|
||||
)
|
||||
dx0 = dx0mat.view(x.shape)
|
||||
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
|
||||
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
|
||||
dcolscale = rest[0] if colscale is not None else None
|
||||
return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None,
|
||||
None, None, None, None)
|
||||
return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None,
|
||||
None, None, None, None, None)
|
||||
|
||||
|
||||
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
rowscale_const, out_numrows, residual_in_fp32=False,
|
||||
prenorm=False, is_rms_norm=False, return_dmask=False):
|
||||
x0 = x0.contiguous()
|
||||
x1 = x1.contiguous() if x1 is not None else None
|
||||
residual = residual.contiguous() if residual is not None else None
|
||||
gamma = gamma.contiguous()
|
||||
beta = beta.contiguous() if beta is not None else None
|
||||
colscale = colscale.contiguous() if colscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
|
||||
x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
|
||||
)
|
||||
# Only need to save x0 if we need to compute gradient wrt colscale
|
||||
@ -174,7 +174,7 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.rowscale_const = rowscale_const
|
||||
ctx.x0_numrows = x0.shape[:-1].numel()
|
||||
ctx.has_residual = x1 is not None
|
||||
ctx.has_residual = residual is not None
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
ctx.has_beta = beta is not None
|
||||
z_shape = (-1, *x0.shape[1:])
|
||||
@ -197,42 +197,42 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
|
||||
# x0 is None if colscale is None
|
||||
dropout_p = ctx.dropout_p
|
||||
has_residual = ctx.has_residual
|
||||
dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
|
||||
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
|
||||
dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
|
||||
ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
|
||||
)
|
||||
dx0 = dx0mat.view(-1, *x.shape[1:])
|
||||
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
|
||||
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
|
||||
dcolscale = rest[0] if colscale is not None else None
|
||||
return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None,
|
||||
None, None, None, None, None, None, None)
|
||||
return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
|
||||
None, None, None, None, None, None, None, None)
|
||||
|
||||
|
||||
def layer_norm(x, weight, bias, epsilon):
|
||||
return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
|
||||
|
||||
|
||||
def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
|
||||
prenorm=False, residual_in_fp32=False,
|
||||
def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
|
||||
layerscale=None, prenorm=False, residual_in_fp32=False,
|
||||
return_dropout_mask=False):
|
||||
"""residual_in_fp32 only has an effect if x1 is None.
|
||||
Otherwise residual dtype is x1.dtype.
|
||||
"""residual_in_fp32 only has an effect if residual is None.
|
||||
Otherwise residual dtype is residual.dtype.
|
||||
"""
|
||||
return DropoutAddLayerNormFn.apply(
|
||||
x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
|
||||
x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
|
||||
False, return_dropout_mask
|
||||
)
|
||||
|
||||
|
||||
def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
|
||||
def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
|
||||
x0_subset=None, out_subset=None, rowscale_const=1.0,
|
||||
out_numrows=0, prenorm=False, residual_in_fp32=False,
|
||||
return_dropout_mask=False):
|
||||
"""residual_in_fp32 only has an effect if x1 is None.
|
||||
Otherwise residual dtype is x1.dtype.
|
||||
"""residual_in_fp32 only has an effect if residual is None.
|
||||
Otherwise residual dtype is residual.dtype.
|
||||
"""
|
||||
return DropoutAddLayerNormSubsetFn.apply(
|
||||
x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
|
||||
)
|
||||
|
||||
@ -254,7 +254,7 @@ class DropoutAddLayerNorm(torch.nn.Module):
|
||||
init.ones_(self.weight)
|
||||
init.zeros_(self.bias)
|
||||
|
||||
def forward(self, x0, x1=None):
|
||||
return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
|
||||
def forward(self, x0, residual=None):
|
||||
return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
|
||||
self.p if self.training else 0.0, self.epsilon,
|
||||
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
|
||||
|
||||
@ -12,26 +12,27 @@ def rms_norm(x, weight, epsilon):
|
||||
False, True)
|
||||
|
||||
|
||||
def dropout_add_rms_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
|
||||
prenorm=False, residual_in_fp32=False, return_dropout_mask=False):
|
||||
"""residual_in_fp32 only has an effect if x1 is None.
|
||||
Otherwise residual dtype is x1.dtype.
|
||||
def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
|
||||
layerscale=None, prenorm=False, residual_in_fp32=False,
|
||||
return_dropout_mask=False):
|
||||
"""residual_in_fp32 only has an effect if residual is None.
|
||||
Otherwise residual dtype is residual.dtype.
|
||||
"""
|
||||
return DropoutAddLayerNormFn.apply(
|
||||
x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
|
||||
x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
|
||||
True, return_dropout_mask
|
||||
)
|
||||
|
||||
|
||||
def dropout_add_rms_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
|
||||
def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
|
||||
x0_subset=None, out_subset=None, rowscale_const=1.0,
|
||||
out_numrows=0, prenorm=False, residual_in_fp32=False,
|
||||
return_dropout_mask=False):
|
||||
"""residual_in_fp32 only has an effect if x1 is None.
|
||||
Otherwise residual dtype is x1.dtype.
|
||||
"""residual_in_fp32 only has an effect if residual is None.
|
||||
Otherwise residual dtype is residual.dtype.
|
||||
"""
|
||||
return DropoutAddLayerNormSubsetFn.apply(
|
||||
x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
|
||||
)
|
||||
|
||||
@ -52,7 +53,7 @@ class DropoutAddRMSNorm(torch.nn.Module):
|
||||
def reset_parameters(self):
|
||||
init.ones_(self.weight)
|
||||
|
||||
def forward(self, x0, x1=None):
|
||||
return dropout_add_rms_norm(x0, x1, self.weight, None,
|
||||
def forward(self, x0, residual=None):
|
||||
return dropout_add_rms_norm(x0, residual, self.weight, None,
|
||||
self.p if self.training else 0.0, self.epsilon,
|
||||
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user