[LayerNorm] Fuse LayerScale
commit ae137ed17a (parent 8c6609ae1a)
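This commit fuses LayerScale — a learnable per-channel scale, exposed as `colscale` in the C++/CUDA code and `layerscale` in the Python API — into the existing dropout + residual-add + LayerNorm kernels. As orientation for the diff below, here is a minimal unfused PyTorch sketch of what the fused forward computes after this change; the helper name and the float32 upcast are illustrative assumptions, not code from this commit.

```python
import torch
import torch.nn.functional as F

def dropout_add_ln_reference(x0, x1, weight, bias, rowscale, colscale, dropout_p, eps):
    """Unfused sketch: z = LayerNorm(dropout(x0 * rowscale * colscale) + x1)."""
    h = x0.float()
    if rowscale is not None:
        h = h * rowscale.unsqueeze(-1)      # one scale per row (token)
    if colscale is not None:
        h = h * colscale.float()            # LayerScale: one scale per hidden channel
    h = F.dropout(h, p=dropout_p)           # inverted dropout, as in the kernel
    if x1 is not None:
        h = h + x1.float()                  # residual add
    return F.layer_norm(h, (h.shape[-1],), weight.float(), bias.float(), eps)
```

Dropout and both scales are elementwise, so applying colscale before or after the dropout mask gives the same result; the backward pass, however, needs the raw input `x0` (together with the dropout mask) to form the colscale gradient, which is why `x0` is now threaded through to the backward entry point in the diff below.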
@ -40,6 +40,8 @@ struct ParamsBase {
, mu(nullptr)
, rs(nullptr)
, gamma(nullptr)
, rowscale(nullptr)
, colscale(nullptr)
, dropout_keep_p(1.f)
, dropout_scale(1.f)
, workspace(nullptr)
@ -63,6 +65,7 @@ struct ParamsBase {
void *rs;
void *gamma;
void *rowscale;
void *colscale;

float inverse_cols;

@ -106,10 +109,12 @@ struct BwdParams : public ParamsBase {
, dx(nullptr)
, dbeta_part(nullptr)
, dgamma_part(nullptr)
, dcolscale_part(nullptr)
, dx0(nullptr)
, dx1(nullptr)
, dbeta(nullptr)
, dgamma(nullptr)
, dcolscale(nullptr)
{
}

@ -121,6 +126,7 @@ struct BwdParams : public ParamsBase {
// Workspace for Wgrad pre-reduction.
void *dbeta_part;
void *dgamma_part;
void *dcolscale_part;

// Output: Dgrad.
void *dx0;
@ -128,13 +134,14 @@ struct BwdParams : public ParamsBase {
// Output: Wgrad.
void *dbeta;
void *dgamma;
void *dcolscale;

};

////////////////////////////////////////////////////////////////////////////////////////////////////

using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
using FunctionKey = uint64_t;
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;

@ -84,6 +84,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
|
||||
const at::Tensor &gamma, // hidden_size
|
||||
const at::Tensor &beta, // hidden_size
|
||||
c10::optional<const at::Tensor> &rowscale_, // BxS
|
||||
c10::optional<const at::Tensor> &colscale_, // hidden_size
|
||||
const float dropout_p,
|
||||
const float epsilon,
|
||||
c10::optional<at::Generator> gen_,
|
||||
@ -124,7 +125,15 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
|
||||
TORCH_CHECK(rowscale.is_cuda())
|
||||
TORCH_CHECK(rowscale.is_contiguous());
|
||||
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
|
||||
TORCH_CHECK(rowscale.scalar_type() == itype);
|
||||
TORCH_CHECK(rowscale.dtype() == itype);
|
||||
}
|
||||
|
||||
if (colscale_.has_value()) {
|
||||
auto colscale = colscale_.value();
|
||||
TORCH_CHECK(colscale.is_cuda())
|
||||
TORCH_CHECK(colscale.is_contiguous());
|
||||
TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
|
||||
TORCH_CHECK(colscale.dtype() == wtype);
|
||||
}
|
||||
|
||||
TORCH_CHECK(gamma.sizes() == beta.sizes());
|
||||
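The forward entry point now validates `colscale` alongside `rowscale`: the row scale carries one entry per row (B·S) in the input dtype, while the LayerScale vector carries one entry per hidden channel in the weight dtype, and both must be contiguous CUDA tensors. A small Python restatement of that contract (the helper is hypothetical, written only to summarize the checks above):

```python
def check_scales(x0mat, gamma, rowscale=None, colscale=None):
    rows, cols = x0mat.shape                 # x0 flattened to (B*S, hidden_size)
    if rowscale is not None:
        assert rowscale.is_cuda and rowscale.is_contiguous()
        assert rowscale.shape == (rows,) and rowscale.dtype == x0mat.dtype
    if colscale is not None:
        assert colscale.is_cuda and colscale.is_contiguous()
        assert colscale.shape == (cols,) and colscale.dtype == gamma.dtype
```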
@ -135,7 +144,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
|
||||
|
||||
auto opts = x0.options();
|
||||
|
||||
bool save_x = x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
|
||||
bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || (itype != rtype);
|
||||
at::Tensor x;
|
||||
if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
|
||||
at::Tensor dmask;
|
||||
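`save_x` decides whether the pre-LayerNorm sum `x = dropout(scaled x0) + x1` is written out for the backward pass. With the scales fused in, `x` can no longer be recovered from `x0` alone, so the condition gains the rowscale/colscale terms. Restated in Python (variable names follow the Python wrapper rather than the C++ locals):

```python
save_x = (
    x1 is not None                    # a residual was added
    or dropout_p > 0.0                # dropout was applied
    or rowscale is not None           # per-row scaling was applied
    or colscale is not None           # LayerScale was applied (new in this commit)
    or input_dtype != residual_dtype  # x is kept in the residual dtype
)
# If none of these hold, x == x0 and the backward pass can simply reuse x0.
```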
@ -153,6 +162,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
|
||||
launch_params.params.dropout_keep_p = 1.f - dropout_p;
|
||||
launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
|
||||
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
|
||||
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
|
||||
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
@ -212,12 +222,15 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidden_size
|
||||
c10::optional<const at::Tensor> &dx_, // BxSxhidden_size
|
||||
const at::Tensor &x, // BxSxhidden_size
|
||||
c10::optional<const at::Tensor> &x0_, // BxSxhidden_size
|
||||
c10::optional<const at::Tensor> &dmask_, // BxSxhidden_size
|
||||
const at::Tensor &mu, // BxS, FP32!
|
||||
const at::Tensor &rsigma, // BxS, FP32!
|
||||
const at::Tensor &gamma, // hidden_size
|
||||
c10::optional<const at::Tensor> &rowscale_, // BxS
|
||||
c10::optional<const at::Tensor> &colscale_, // hidden_size
|
||||
const float dropout_p,
|
||||
const bool has_residual
|
||||
) {
|
||||
@ -250,6 +263,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
|
||||
auto rows = sizes[0];
|
||||
auto cols = sizes[1];
|
||||
|
||||
if (dx_.has_value()) {
|
||||
auto dx = dx_.value();
|
||||
TORCH_CHECK(dx.dtype() == rtype);
|
||||
TORCH_CHECK(dx.is_cuda())
|
||||
TORCH_CHECK(dx.is_contiguous());
|
||||
TORCH_CHECK(dx.sizes() == sizes);
|
||||
}
|
||||
|
||||
if (dmask_.has_value()) {
|
||||
auto dmask = dmask_.value();
|
||||
TORCH_CHECK(dmask.dtype() == mtype);
|
||||
@ -263,7 +284,22 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
|
||||
TORCH_CHECK(rowscale.is_cuda())
|
||||
TORCH_CHECK(rowscale.is_contiguous());
|
||||
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
|
||||
TORCH_CHECK(rowscale.scalar_type() == itype);
|
||||
TORCH_CHECK(rowscale.dtype() == itype);
|
||||
}
|
||||
|
||||
if (colscale_.has_value()) {
|
||||
auto colscale = colscale_.value();
|
||||
TORCH_CHECK(colscale.is_cuda())
|
||||
TORCH_CHECK(colscale.is_contiguous());
|
||||
TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
|
||||
TORCH_CHECK(colscale.dtype() == wtype);
|
||||
|
||||
TORCH_CHECK(x0_.has_value());
|
||||
auto x0 = x0_.value();
|
||||
TORCH_CHECK(x0.is_cuda())
|
||||
TORCH_CHECK(x0.is_contiguous());
|
||||
TORCH_CHECK(x0.sizes() == sizes);
|
||||
TORCH_CHECK(x0.dtype() == itype);
|
||||
}
|
||||
|
||||
auto hidden_size = gamma.numel();
|
||||
@ -282,6 +318,10 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
|
||||
if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
|
||||
auto dgamma = torch::empty_like(gamma);
|
||||
auto dbeta = torch::empty_like(gamma);
|
||||
at::Tensor dcolscale;
|
||||
if (colscale_.has_value()) {
|
||||
dcolscale = torch::empty_like(colscale_.value());
|
||||
}
|
||||
|
||||
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
|
||||
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
@ -290,31 +330,40 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
|
||||
launch_params.params.dropout_keep_p = 1.f - dropout_p;
|
||||
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
|
||||
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
|
||||
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
|
||||
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
|
||||
|
||||
launcher(launch_params, true, /*prenorm=*/false);
|
||||
launcher(launch_params, true);
|
||||
|
||||
auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
||||
auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
||||
at::Tensor dcolscale_part;
|
||||
if (colscale_.has_value()) {
|
||||
dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
||||
}
|
||||
at::Tensor workspace, barrier;
|
||||
|
||||
layer_norm::BwdParams ¶ms = launch_params.params;
|
||||
params.rows = rows;
|
||||
params.cols = cols;
|
||||
params.x = x.data_ptr();
|
||||
params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr;
|
||||
params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
|
||||
params.mu = mu.data_ptr();
|
||||
params.rs = rsigma.data_ptr();
|
||||
params.gamma = gamma.data_ptr();
|
||||
params.dz = dz.data_ptr();
|
||||
params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
|
||||
params.dx0 = dx0.data_ptr();
|
||||
params.dbeta = dbeta.data_ptr();
|
||||
params.dgamma = dgamma.data_ptr();
|
||||
params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr;
|
||||
params.dbeta_part = dbeta_part.data_ptr();
|
||||
params.dgamma_part = dgamma_part.data_ptr();
|
||||
params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
|
||||
params.dropout_scale = 1.f / (1.f - dropout_p);
|
||||
params.inverse_cols = 1.f / float(params.cols);
|
||||
|
||||
@ -326,137 +375,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
|
||||
params.barrier = barrier.data_ptr<int>();
|
||||
}
|
||||
|
||||
launcher(launch_params, false, /*prenorm=*/false);
|
||||
launcher(launch_params, false);
|
||||
|
||||
return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, // BxSxhidden_size
|
||||
const at::Tensor &dx, // BxSxhidden_size
|
||||
const at::Tensor &x, // BxSxhidden_size
|
||||
c10::optional<const at::Tensor> &dmask_, // BxSxhidden_size
|
||||
const at::Tensor &mu, // BxS, FP32!
|
||||
const at::Tensor &rsigma, // BxS, FP32!
|
||||
const at::Tensor &gamma, // hidden_size
|
||||
c10::optional<const at::Tensor> &rowscale_, // BxS
|
||||
const float dropout_p,
|
||||
const bool has_residual
|
||||
) {
|
||||
|
||||
auto itype = dz.scalar_type();
|
||||
auto rtype = x.scalar_type();
|
||||
auto wtype = gamma.scalar_type();
|
||||
auto otype = itype;
|
||||
auto ctype = torch::kFloat32;
|
||||
auto mtype = torch::kUInt8;
|
||||
|
||||
if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
|
||||
|
||||
TORCH_CHECK(dz.dtype() == otype);
|
||||
TORCH_CHECK(dx.dtype() == rtype);
|
||||
TORCH_CHECK(mu.dtype() == ctype);
|
||||
TORCH_CHECK(rsigma.dtype() == ctype);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(dz.is_cuda());
|
||||
TORCH_CHECK(dx.is_cuda());
|
||||
TORCH_CHECK(mu.is_cuda());
|
||||
TORCH_CHECK(rsigma.is_cuda());
|
||||
TORCH_CHECK(gamma.is_cuda());
|
||||
|
||||
TORCH_CHECK(x.is_contiguous());
|
||||
TORCH_CHECK(dz.is_contiguous());
|
||||
TORCH_CHECK(dx.is_contiguous());
|
||||
|
||||
auto sizes = x.sizes();
|
||||
TORCH_CHECK(sizes.size() == 2);
|
||||
TORCH_CHECK(dz.sizes() == sizes);
|
||||
TORCH_CHECK(dx.sizes() == sizes);
|
||||
auto rows = sizes[0];
|
||||
auto cols = sizes[1];
|
||||
|
||||
if (dmask_.has_value()) {
|
||||
auto dmask = dmask_.value();
|
||||
TORCH_CHECK(dmask.dtype() == mtype);
|
||||
TORCH_CHECK(dmask.is_cuda());
|
||||
TORCH_CHECK(dmask.is_contiguous());
|
||||
TORCH_CHECK(dmask.sizes() == sizes);
|
||||
std::vector<at::Tensor> result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
|
||||
if (colscale_.has_value()) {
|
||||
result.push_back(dcolscale);
|
||||
result.push_back(dcolscale_part);
|
||||
}
|
||||
|
||||
if (rowscale_.has_value()) {
|
||||
auto rowscale = rowscale_.value();
|
||||
TORCH_CHECK(rowscale.is_cuda())
|
||||
TORCH_CHECK(rowscale.is_contiguous());
|
||||
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
|
||||
TORCH_CHECK(rowscale.scalar_type() == itype);
|
||||
}
|
||||
|
||||
auto hidden_size = gamma.numel();
|
||||
TORCH_CHECK(hidden_size == cols);
|
||||
TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
|
||||
|
||||
TORCH_CHECK(mu.numel() == rows);
|
||||
TORCH_CHECK(mu.sizes() == rsigma.sizes());
|
||||
|
||||
TORCH_CHECK(gamma.numel() == cols);
|
||||
|
||||
auto opts = x.options();
|
||||
|
||||
auto dx0 = torch::empty_like(x, opts.dtype(itype));
|
||||
at::Tensor dx1;
|
||||
if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
|
||||
auto dgamma = torch::empty_like(gamma);
|
||||
auto dbeta = torch::empty_like(gamma);
|
||||
|
||||
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
|
||||
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
launch_params.props = at::cuda::getCurrentDeviceProperties();
|
||||
TORCH_CHECK(dropout_p < 1.f);
|
||||
launch_params.params.dropout_keep_p = 1.f - dropout_p;
|
||||
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
|
||||
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
|
||||
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
|
||||
|
||||
launcher(launch_params, true, /*prenorm=*/true);
|
||||
|
||||
auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
||||
auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
||||
at::Tensor workspace, barrier;
|
||||
|
||||
layer_norm::BwdParams ¶ms = launch_params.params;
|
||||
params.rows = rows;
|
||||
params.cols = cols;
|
||||
params.x = x.data_ptr();
|
||||
params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
|
||||
params.mu = mu.data_ptr();
|
||||
params.rs = rsigma.data_ptr();
|
||||
params.gamma = gamma.data_ptr();
|
||||
params.dz = dz.data_ptr();
|
||||
params.dx = dx.data_ptr();
|
||||
params.dx0 = dx0.data_ptr();
|
||||
params.dbeta = dbeta.data_ptr();
|
||||
params.dgamma = dgamma.data_ptr();
|
||||
params.dbeta_part = dbeta_part.data_ptr();
|
||||
params.dgamma_part = dgamma_part.data_ptr();
|
||||
params.dropout_scale = 1.f / (1.f - dropout_p);
|
||||
params.inverse_cols = 1.f / float(params.cols);
|
||||
|
||||
if( launch_params.barrier_size > 0 ) {
|
||||
// TODO Any way to avoid this?
|
||||
barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
|
||||
workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
|
||||
params.workspace = workspace.data_ptr();
|
||||
params.barrier = barrier.data_ptr<int>();
|
||||
}
|
||||
|
||||
launcher(launch_params, false, /*prenorm=*/true);
|
||||
|
||||
return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
|
||||
return result;
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -464,5 +390,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "CUDA DropoutAddLayerNorm";
|
||||
m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel");
|
||||
m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel");
|
||||
m.def("dropout_add_ln_prenorm_bwd", &dropout_add_ln_prenorm_bwd, "Run Dropout + Add + LayerNorm (PreNorm version) backward kernel");
|
||||
}
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
|
||||
namespace layer_norm {
|
||||
|
||||
template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Is_even_cols>
|
||||
template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
|
||||
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
|
||||
void ln_bwd_kernel(layer_norm::BwdParams params) {
|
||||
|
||||
@ -53,9 +53,11 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
|
||||
|
||||
Cvec dzy_sum[LDGS];
|
||||
Cvec dz_sum[LDGS];
|
||||
Cvec dcolscale_sum[LDGS];
|
||||
|
||||
memset(dzy_sum, 0, sizeof(dzy_sum));
|
||||
memset(dz_sum, 0, sizeof(dz_sum));
|
||||
if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); }
|
||||
|
||||
compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
|
||||
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
|
||||
@ -68,11 +70,13 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
|
||||
((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;
|
||||
|
||||
Wvec gamma[LDGS];
|
||||
Wvec colscale[LDGS];
|
||||
index_t idx = c;
|
||||
#pragma unroll
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
if (Is_even_cols || (it < num_valid_ldgs)) {
|
||||
gamma[it].load_from(params.gamma, idx);
|
||||
if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
|
||||
idx += Ktraits::VEC_COLS_PER_LDG;
|
||||
}
|
||||
}
|
||||
@ -131,6 +135,8 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
|
||||
if (Is_even_cols || (it < num_valid_ldgs)) {
|
||||
Ivec dx0;
|
||||
Rvec dx1;
|
||||
Ivec x0;
|
||||
if (Has_colscale) { x0.load_from(params.x0, idx); }
|
||||
#pragma unroll
|
||||
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
|
||||
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
|
||||
@ -140,9 +146,20 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
|
||||
if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
|
||||
compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
|
||||
if (Is_dropout) {
|
||||
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f;
|
||||
dx0_tmp_res *= params.dropout_scale;
|
||||
if (Has_colscale) {
|
||||
dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
|
||||
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
|
||||
} else {
|
||||
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
|
||||
}
|
||||
} else {
|
||||
dx0.data.elt[jt] = dx0_tmp_res;
|
||||
if (Has_colscale) {
|
||||
dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
|
||||
dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
|
||||
} else {
|
||||
dx0.data.elt[jt] = dx0_tmp_res;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Has_residual) { dx1.store_to(params.dx1, idx); }
|
||||
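The loop above accumulates, per element, the gradient through the residual add, dropout, rowscale, and the new colscale. Written out as an unfused PyTorch sketch of the same math, where `dx` denotes the gradient arriving at the pre-LayerNorm sum `x` and `dmask` is the saved keep-mask (names are illustrative, not kernel identifiers):

```python
# Forward (per element): x = dropout(rowscale * x0) * colscale + x1
dropout_scale = 1.0 / (1.0 - dropout_p)
g = dx * rowscale.unsqueeze(-1) * dropout_scale * dmask  # grad at the dropped, row-scaled x0
dx1 = dx                                                 # residual grad passes straight through
dcolscale_partial = (g * x0.float()).sum(dim=0)          # per-CTA partial sum in the kernel
dx0 = g * colscale.float()                               # grad w.r.t. the original input x0
```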
@ -160,6 +177,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
|
||||
if (Is_even_cols || (it < num_valid_ldgs)) {
|
||||
dz_sum[it].store_to(params.dbeta_part, idx);
|
||||
dzy_sum[it].store_to(params.dgamma_part, idx);
|
||||
if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); }
|
||||
idx += Ktraits::VEC_COLS_PER_LDG;
|
||||
}
|
||||
}
|
||||
@ -203,23 +221,46 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
|
||||
}
|
||||
}
|
||||
|
||||
compute_t cta_dcolscale_sum[NUM_RES];
|
||||
if (Has_colscale) {
|
||||
__syncthreads();
|
||||
idx = warp_m * Ktraits::VEC_COLS + tid_r;
|
||||
#pragma unroll
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
dcolscale_sum[it].store_to(smem_wgrad, idx);
|
||||
idx += THREADS_PER_ROW;
|
||||
}
|
||||
__syncthreads();
|
||||
memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES);
|
||||
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
|
||||
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
||||
cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const index_t num_valid_writes
|
||||
= (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
|
||||
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
|
||||
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
|
||||
compute_t *dcolscale_part = Has_colscale ? static_cast<compute_t *>(params.dcolscale_part) + bidm * params.cols + tidx : nullptr;
|
||||
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
||||
if (Is_even_cols || (jt < num_valid_writes)) {
|
||||
*dgamma_part = cta_dzy_sum[jt];
|
||||
dgamma_part += Ktraits::THREADS_PER_CTA;
|
||||
*dbeta_part = cta_dz_sum[jt];
|
||||
dbeta_part += Ktraits::THREADS_PER_CTA;
|
||||
if (Has_colscale) {
|
||||
*dcolscale_part = cta_dcolscale_sum[jt];
|
||||
dcolscale_part += Ktraits::THREADS_PER_CTA;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_even_cols>
|
||||
template<typename Kernel_traits, bool Has_colscale, bool Is_even_cols>
|
||||
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
|
||||
void ln_bwd_finalize_kernel(BwdParams params)
|
||||
{
|
||||
@ -250,26 +291,29 @@ void ln_bwd_finalize_kernel(BwdParams params)
|
||||
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
|
||||
for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
|
||||
// Each thread sums over NUM_ELT columns.
|
||||
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
|
||||
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local, dcolscale_local;
|
||||
memset(&dgamma_local, 0, sizeof(dgamma_local));
|
||||
memset(&dbeta_local, 0, sizeof(dbeta_local));
|
||||
if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
|
||||
if (Is_even_cols || col < params.cols) {
|
||||
for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
|
||||
// index_t idx = row * Kernel_traits::COLS + col;
|
||||
index_t idx = row * params.cols + col;
|
||||
|
||||
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
|
||||
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part, dcolscale_part;
|
||||
dbeta_part.load_from(params.dbeta_part, idx);
|
||||
dgamma_part.load_from(params.dgamma_part, idx);
|
||||
if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); }
|
||||
#pragma unroll
|
||||
for( int it = 0; it < NUM_ELT; it++ ) {
|
||||
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
|
||||
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
|
||||
if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; }
|
||||
}
|
||||
}
|
||||
}
|
||||
void * smem_gamma = smem_;
|
||||
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
|
||||
void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
|
||||
|
||||
const int write_row = warp;
|
||||
const int write_col = lane ^ write_row;
|
||||
@ -277,12 +321,14 @@ void ln_bwd_finalize_kernel(BwdParams params)
|
||||
|
||||
dgamma_local.store_to(smem_gamma, write_idx);
|
||||
dbeta_local.store_to(smem_beta, write_idx);
|
||||
if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); }
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// It would probably be safe to reuse the first row of smem_beta and smem_gamma
|
||||
void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
|
||||
void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
|
||||
void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE];
|
||||
void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
|
||||
void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];
|
||||
|
||||
|
||||
// More than one iter iff ROWS_PER_CTA < 32.
|
||||
@ -293,11 +339,13 @@ void ln_bwd_finalize_kernel(BwdParams params)
|
||||
|
||||
memset(&dbeta_local, 0, sizeof(dbeta_local));
|
||||
memset(&dgamma_local, 0, sizeof(dgamma_local));
|
||||
if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
|
||||
|
||||
// Load beta and gamma transposed
|
||||
if(read_row < Kernel_traits::ROWS_PER_CTA){
|
||||
dbeta_local.load_from(smem_beta, read_idx);
|
||||
dgamma_local.load_from(smem_gamma, read_idx);
|
||||
if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); }
|
||||
}
|
||||
|
||||
// Call reducer on the loaded value(s) and convert.
|
||||
@ -310,12 +358,18 @@ void ln_bwd_finalize_kernel(BwdParams params)
|
||||
|
||||
dgamma_local.data.elt[it] = g_i;
|
||||
dbeta_local.data.elt[it] = b_i;
|
||||
if (Has_colscale) {
|
||||
compute_t cs_i = dcolscale_local.data.elt[it];
|
||||
cs_i = reducer.allreduce(cs_i, sum);
|
||||
dcolscale_local.data.elt[it] = cs_i;
|
||||
}
|
||||
}
|
||||
|
||||
// Leader stores the result at the current column.
|
||||
if(lane == 0){
|
||||
dgamma_local.store_to(smem_gamma_out, w);
|
||||
dbeta_local.store_to(smem_beta_out, w);
|
||||
if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); }
|
||||
}
|
||||
|
||||
}
|
||||
@ -329,19 +383,21 @@ void ln_bwd_finalize_kernel(BwdParams params)
|
||||
|
||||
using src_t = typename TypeToVec2<compute_t>::Type;
|
||||
using dst_t = typename TypeToVec2<weight_t>::Type;
|
||||
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
|
||||
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
|
||||
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2, dcolscale_vec2;
|
||||
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2, dcolscale_out2;
|
||||
|
||||
dgamma_vec2.load_from(smem_gamma_out, lane);
|
||||
dbeta_vec2.load_from(smem_beta_out, lane);
|
||||
if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); }
|
||||
#pragma unroll
|
||||
for( int it = 0; it < NUM_ELT; it++ ) {
|
||||
dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
|
||||
dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
|
||||
if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dcolscale_vec2.data.elt[it]); }
|
||||
}
|
||||
dgamma_out2.store_to(params.dgamma, col_out);
|
||||
dbeta_out2.store_to(params.dbeta, col_out);
|
||||
|
||||
if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); }
|
||||
}
|
||||
}
|
||||
}
|
||||
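As with `dgamma` and `dbeta`, the colscale gradient is produced in two stages: each CTA writes a float32 partial row into `dcolscale_part` (shape `ctas_per_col x hidden_size`), and the finalize kernel above sums those rows and converts back to the weight dtype. At tensor level the reduction amounts to (a sketch, not the kernel code):

```python
dgamma = dgamma_part.sum(dim=0).to(gamma.dtype)
dbeta  = dbeta_part.sum(dim=0).to(gamma.dtype)
if colscale is not None:
    dcolscale = dcolscale_part.sum(dim=0).to(colscale.dtype)   # new in this commit
```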
@ -364,7 +420,7 @@ template<
|
||||
int BYTES_PER_LDG_MAIN,
|
||||
int BYTES_PER_LDG_FINAL
|
||||
>
|
||||
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){
|
||||
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
|
||||
|
||||
using Kernel_traits = Kernel_traits<weight_t,
|
||||
input_t,
|
||||
@ -378,59 +434,64 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
|
||||
WARPS_N,
|
||||
BYTES_PER_LDG_MAIN
|
||||
>;
|
||||
bool prenorm = launch_params.params.dx != nullptr;
|
||||
bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
|
||||
bool has_residual = launch_params.params.dx1 != nullptr;
|
||||
bool has_colscale = launch_params.params.colscale != nullptr;
|
||||
bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
|
||||
BOOL_SWITCH(prenorm, PrenormConst, [&] {
|
||||
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
|
||||
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
|
||||
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
|
||||
auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
|
||||
if( configure_params ) {
|
||||
int ctas_per_sm;
|
||||
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
|
||||
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
|
||||
launch_params.barrier_size = 0;
|
||||
launch_params.workspace_bytes = 0;
|
||||
if(Kernel_traits::CTAS_PER_ROW > 1) {
|
||||
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
|
||||
launch_params.workspace_bytes = launch_params.params.ctas_per_col
|
||||
* Kernel_traits::WARPS_M
|
||||
* Kernel_traits::CTAS_PER_ROW
|
||||
* sizeof(typename Kernel_traits::reduce_t)
|
||||
* 2;
|
||||
BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
|
||||
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
|
||||
auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
|
||||
if( configure_params ) {
|
||||
int ctas_per_sm;
|
||||
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
|
||||
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
|
||||
launch_params.barrier_size = 0;
|
||||
launch_params.workspace_bytes = 0;
|
||||
if(Kernel_traits::CTAS_PER_ROW > 1) {
|
||||
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
|
||||
launch_params.workspace_bytes = launch_params.params.ctas_per_col
|
||||
* Kernel_traits::WARPS_M
|
||||
* Kernel_traits::CTAS_PER_ROW
|
||||
* sizeof(typename Kernel_traits::reduce_t)
|
||||
* 2;
|
||||
}
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
|
||||
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
|
||||
}
|
||||
auto stream = launch_params.stream;
|
||||
auto ctas_per_col = launch_params.params.ctas_per_col;
|
||||
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
|
||||
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
|
||||
}
|
||||
auto stream = launch_params.stream;
|
||||
auto ctas_per_col = launch_params.params.ctas_per_col;
|
||||
|
||||
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
|
||||
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
|
||||
} else {
|
||||
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
|
||||
dim3 block(Kernel_traits::THREADS_PER_CTA);
|
||||
void *params_ = (void *)&launch_params.params;
|
||||
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream);
|
||||
}
|
||||
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
|
||||
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
|
||||
} else {
|
||||
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
|
||||
dim3 block(Kernel_traits::THREADS_PER_CTA);
|
||||
void *params_ = (void *)&launch_params.params;
|
||||
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream);
|
||||
}
|
||||
|
||||
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
|
||||
weight_t,
|
||||
input_t,
|
||||
residual_t,
|
||||
output_t,
|
||||
compute_t,
|
||||
index_t,
|
||||
32 * 32, // THREADS_PER_CTA
|
||||
BYTES_PER_LDG_FINAL>;
|
||||
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
|
||||
weight_t,
|
||||
input_t,
|
||||
residual_t,
|
||||
output_t,
|
||||
compute_t,
|
||||
index_t,
|
||||
HasColscaleConst,
|
||||
32 * 32, // THREADS_PER_CTA
|
||||
BYTES_PER_LDG_FINAL>;
|
||||
|
||||
auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, IsEvenColsConst>;
|
||||
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
|
||||
auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
|
||||
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
|
||||
namespace layer_norm {
|
||||
|
||||
template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Is_even_cols>
|
||||
template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
|
||||
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
|
||||
void ln_fwd_kernel(FwdParams params) {
|
||||
|
||||
@ -46,7 +46,7 @@ void ln_fwd_kernel(FwdParams params) {
|
||||
using Stats = typename Ktraits::Stats;
|
||||
using stats_t = typename Stats::stats_t;
|
||||
|
||||
constexpr bool save_x = Has_residual || Is_dropout || !(std::is_same<input_t, residual_t>::value);
|
||||
const bool save_x = Has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || !(std::is_same<input_t, residual_t>::value);
|
||||
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
@ -80,12 +80,14 @@ void ln_fwd_kernel(FwdParams params) {
|
||||
|
||||
Wvec gamma[LDGS];
|
||||
Wvec beta[LDGS];
|
||||
Wvec colscale[LDGS];
|
||||
index_t idx = c;
|
||||
#pragma unroll
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
if (Is_even_cols || (it < num_valid_ldgs)) {
|
||||
gamma[it].load_from(params.gamma, idx);
|
||||
beta[it].load_from(params.beta, idx);
|
||||
if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
|
||||
idx += VEC_COLS_PER_LDG;
|
||||
}
|
||||
}
|
||||
@ -109,13 +111,9 @@ void ln_fwd_kernel(FwdParams params) {
|
||||
// the more efficient curand_uniform4.
|
||||
mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
|
||||
compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
|
||||
compute_t x_ij;
|
||||
if (Has_residual) {
|
||||
compute_t x1_ij = compute_t(x1.data.elt[jt]);
|
||||
x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
|
||||
} else {
|
||||
x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
|
||||
}
|
||||
x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
|
||||
if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
|
||||
compute_t x_ij = Has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
|
||||
if (save_x) { x.data.elt[jt] = x_ij; }
|
||||
xf[it * NUM_ELTS + jt] = x_ij;
|
||||
if (Is_dropout) { dmask.data.elt[jt] = keep; }
|
||||
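The rewritten inner loop applies the keep/drop decision to the row-scaled `x0` first and the LayerScale after it, then adds the residual; because all of these steps are elementwise, this is equivalent to scaling `x0` up front, which is what the updated test's reference computation does. The same sequence in PyTorch, as a hedged sketch:

```python
keep = torch.rand_like(x0, dtype=torch.float32) <= (1.0 - dropout_p)   # curand_uniform <= keep_p
h = x0.float() * rowscale.unsqueeze(-1)                                # per-row scale
h = torch.where(keep, h / (1.0 - dropout_p), torch.zeros_like(h))      # inverted dropout
if colscale is not None:
    h = h * colscale.float()                                           # LayerScale
x = h + x1.float() if x1 is not None else h                            # saved when save_x is true
```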
@ -130,8 +128,8 @@ void ln_fwd_kernel(FwdParams params) {
|
||||
const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
|
||||
const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
|
||||
const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
|
||||
// Need to convert to int, otherwise the subtraction will wrap around.
|
||||
auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
|
||||
// Need to convert to int, otherwise the subtraction will wrap around.
|
||||
const index_t valid_partial_vecs_in_warp =
|
||||
std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
|
||||
int(THREADS_PER_WARP));
|
||||
@ -206,45 +204,48 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
|
||||
BYTES_PER_LDG
|
||||
>;
|
||||
bool has_residual = launch_params.params.x1 != nullptr;
|
||||
bool has_colscale = launch_params.params.colscale != nullptr;
|
||||
bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
|
||||
BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
|
||||
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
|
||||
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
|
||||
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
|
||||
if( configure_params ) {
|
||||
int ctas_per_sm;
|
||||
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
|
||||
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
|
||||
const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
|
||||
launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
|
||||
launch_params.barrier_size = 0;
|
||||
launch_params.workspace_bytes = 0;
|
||||
if(Kernel_traits::CTAS_PER_ROW > 1) {
|
||||
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
|
||||
launch_params.workspace_bytes = launch_params.params.ctas_per_col
|
||||
* Kernel_traits::WARPS_M
|
||||
* Kernel_traits::CTAS_PER_ROW
|
||||
* sizeof(typename Kernel_traits::Stats::stats_t)
|
||||
* 2;
|
||||
BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
|
||||
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
|
||||
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
|
||||
if( configure_params ) {
|
||||
int ctas_per_sm;
|
||||
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
|
||||
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
|
||||
const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
|
||||
launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
|
||||
launch_params.barrier_size = 0;
|
||||
launch_params.workspace_bytes = 0;
|
||||
if(Kernel_traits::CTAS_PER_ROW > 1) {
|
||||
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
|
||||
launch_params.workspace_bytes = launch_params.params.ctas_per_col
|
||||
* Kernel_traits::WARPS_M
|
||||
* Kernel_traits::CTAS_PER_ROW
|
||||
* sizeof(typename Kernel_traits::Stats::stats_t)
|
||||
* 2;
|
||||
}
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
|
||||
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
|
||||
}
|
||||
auto stream = launch_params.stream;
|
||||
auto ctas_per_col = launch_params.params.ctas_per_col;
|
||||
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
|
||||
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
|
||||
}
|
||||
auto stream = launch_params.stream;
|
||||
auto ctas_per_col = launch_params.params.ctas_per_col;
|
||||
|
||||
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
|
||||
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
|
||||
} else {
|
||||
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
|
||||
dim3 block(Kernel_traits::THREADS_PER_CTA);
|
||||
void *params_ = (void *)&launch_params.params;
|
||||
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream);
|
||||
}
|
||||
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
|
||||
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
|
||||
} else {
|
||||
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
|
||||
dim3 block(Kernel_traits::THREADS_PER_CTA);
|
||||
void *params_ = (void *)&launch_params.params;
|
||||
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@ -38,6 +38,7 @@ template<
|
||||
typename output_t_,
|
||||
typename compute_t_,
|
||||
typename index_t_,
|
||||
bool Has_colscale,
|
||||
uint32_t THREADS_PER_CTA_,
|
||||
uint32_t BYTES_PER_LDG_,
|
||||
typename Base = Kernel_traits_base<HIDDEN_SIZE_,
|
||||
@ -69,7 +70,8 @@ struct Kernel_traits_finalize : public Base {
|
||||
// Shared memory size to coalesce the CTA result.
|
||||
enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
|
||||
// Shared memory requirement per CTA.
|
||||
enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
|
||||
static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2;
|
||||
enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT };
|
||||
|
||||
// The type of the reducer.
|
||||
using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
|
||||
|
||||
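With colscale present, the finalize kernel shuffles three factors (dgamma, dbeta, dcolscale) through shared memory instead of two, so the traits gain `NUM_FACTORS` and size the buffers with it. The resulting layout, written as byte-offset arithmetic (illustrative only, using the traits' `SMEM_BYTES_TRANSPOSE`/`SMEM_BYTES_OUTPUT` as inputs):

```python
def finalize_smem_layout(T, O, has_colscale):
    """T = SMEM_BYTES_TRANSPOSE, O = SMEM_BYTES_OUTPUT (per factor)."""
    n = 3 if has_colscale else 2                 # NUM_FACTORS
    offsets = {'smem_gamma': 0, 'smem_beta': T}
    if has_colscale:
        offsets['smem_colscale'] = 2 * T
    offsets['smem_gamma_out'] = n * T
    offsets['smem_beta_out'] = n * T + O
    if has_colscale:
        offsets['smem_colscale_out'] = n * T + 2 * O
    smem_bytes_per_cta = n * T + n * O           # SMEM_BYTES_PER_CTA
    return offsets, smem_bytes_per_cta
```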
@ -45,7 +45,7 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) {
|
||||
#define REGISTER_BWD_LAUNCHER( \
|
||||
HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
|
||||
void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
|
||||
const bool configure_params, const bool prenorm) { \
|
||||
const bool configure_params) { \
|
||||
launch_<WTYPE, \
|
||||
ITYPE, \
|
||||
RTYPE, \
|
||||
@ -57,7 +57,7 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) {
|
||||
WARPS_M, \
|
||||
WARPS_N, \
|
||||
BYTES_PER_LDG, \
|
||||
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params, prenorm); \
|
||||
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
|
||||
} \
|
||||
static BwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
|
||||
ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
|
||||
|
||||
import torch
|
||||
from torch.nn import init
|
||||
|
||||
import dropout_layer_norm
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, epsilon,
|
||||
def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
|
||||
residual_in_fp32):
|
||||
""" Assume that arguments are contiguous
|
||||
"""
|
||||
@ -14,133 +16,98 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, ep
|
||||
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
|
||||
rowscale = rowscale.view(-1) if rowscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
x0mat, x1mat, gamma, beta, rowscale, dropout_p, epsilon, None, residual_in_fp32
|
||||
x0mat, x1mat, gamma, beta, rowscale, colscale, dropout_p, epsilon, None, residual_in_fp32
|
||||
)
|
||||
# dmask is None if dropout_p == 0.0
|
||||
# xmat is None if dropout_p == 0.0, x1 is None, there is no rowscale/colscale, and residual_dtype == input_dtype
|
||||
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p,
|
||||
has_residual):
|
||||
def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
|
||||
dropout_p, has_residual):
|
||||
""" Assume that arguments are contiguous
|
||||
dx == None means that it was a post-norm architecture
|
||||
(x = drop(x0) + x1 was not returned in the fwd).
|
||||
x0 must not be None if we have colscale.
|
||||
"""
|
||||
# dmask is None if dropout_p == 0.0
|
||||
hidden_size = gamma.numel()
|
||||
xmat = x.view((-1, hidden_size))
|
||||
dzmat = dz.view(xmat.shape)
|
||||
dxmat = dx.view(xmat.shape) if dx is not None else None
|
||||
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
|
||||
rowscale = rowscale.view(-1) if rowscale is not None else None
|
||||
dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_bwd(
|
||||
dzmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
|
||||
colscale = colscale.view(-1) if colscale is not None else None
|
||||
if colscale is not None:
|
||||
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
|
||||
dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
|
||||
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p,
|
||||
has_residual
|
||||
)
|
||||
# dx1mat is None if not has_residual
|
||||
return dx0mat, dx1mat, dgamma, dbeta
|
||||
if colscale is None:
|
||||
return dx0mat, dx1mat, dgamma, dbeta
|
||||
else:
|
||||
dcolscale = rest[0]
|
||||
return dx0mat, dx1mat, dgamma, dbeta, dcolscale
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, rowscale,
|
||||
dropout_p, has_residual):
|
||||
""" Assume that arguments are contiguous
|
||||
"""
|
||||
hidden_size = gamma.numel()
|
||||
xmat = x.view((-1, hidden_size))
|
||||
dzmat = dz.view(xmat.shape)
|
||||
dxmat = dx.view(xmat.shape)
|
||||
rowscale = rowscale.view(-1) if rowscale is not None else None
|
||||
dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_prenorm_bwd(
|
||||
dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
|
||||
)
|
||||
return dx0mat, dx1mat, dgamma, dbeta
|
||||
|
||||
|
||||
class DropoutAddLayerNormFN(torch.autograd.Function):
|
||||
class DropoutAddLayerNormFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
|
||||
return_dmask=False):
|
||||
def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32,
|
||||
prenorm=False, return_dmask=False):
|
||||
x0 = x0.contiguous()
|
||||
x1 = x1.contiguous() if x1 is not None else None
|
||||
gamma = gamma.contiguous()
|
||||
beta = beta.contiguous()
|
||||
rowscale = rowscale.contiguous() if rowscale is not None else None
|
||||
colscale = colscale.contiguous() if colscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
|
||||
x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
|
||||
x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32
|
||||
)
|
||||
ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
|
||||
# Only need to save x0 if we need to compute gradient wrt colscale
|
||||
x0_saved = x0 if colscale is not None else None
|
||||
ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
|
||||
ctx.prenorm = prenorm
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.has_residual = x1 is not None
|
||||
if not return_dmask:
|
||||
return zmat.view(x0.shape)
|
||||
return (zmat.view(x0.shape) if not prenorm
|
||||
else (zmat.view(x0.shape), xmat.view(x0.shape)))
|
||||
else:
|
||||
dmask = (dmask.view(x0.shape) if dropout_p > 0.
|
||||
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
|
||||
ctx.mark_non_differentiable(dmask)
|
||||
return zmat.view(x0.shape), dmask
|
||||
return ((zmat.view(x0.shape), dmask) if not prenorm
|
||||
else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dz, *args):
|
||||
# assert dz.is_contiguous()
|
||||
dz = dz.contiguous() # this happens!
|
||||
x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
|
||||
dx = args[0].contiguous() if ctx.prenorm else None
|
||||
x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
|
||||
# x0 is None if colscale is None
|
||||
dropout_p = ctx.dropout_p
|
||||
has_residual = ctx.has_residual
|
||||
dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward(
|
||||
dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
|
||||
dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
|
||||
dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual
|
||||
)
|
||||
dx0 = dx0mat.view(x.shape)
|
||||
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
|
||||
return dx0, dx1, dgamma, dbeta, None, None, None, None, None
|
||||
dcolscale = rest[0] if colscale is not None else None
|
||||
return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None
|
||||
|
||||
|
||||
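`backward` has to return one gradient per argument of `forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32, prenorm, return_dmask)`, so the returned tuple grows to eleven entries, with `dcolscale` in the `colscale` slot and `None` for everything that is not a differentiable tensor. Spelled out:

```python
grads = (
    dx0, dx1, dgamma, dbeta,
    None,         # rowscale: not differentiated
    dcolscale,    # colscale / LayerScale (None when it was not passed)
    None, None,   # dropout_p, epsilon
    None, None,   # residual_in_fp32, prenorm
    None,         # return_dmask
)
```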
class DropoutAddLayerNormPrenormFN(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
|
||||
return_dmask=False):
|
||||
x0 = x0.contiguous()
|
||||
x1 = x1.contiguous() if x1 is not None else None
|
||||
gamma = gamma.contiguous()
|
||||
beta = beta.contiguous()
|
||||
rowscale = rowscale.contiguous() if rowscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
|
||||
x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
|
||||
)
|
||||
ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.has_residual = x1 is not None
|
||||
if not return_dmask:
|
||||
return zmat.view(x0.shape), xmat.view(x0.shape)
|
||||
else:
|
||||
dmask = (dmask.view(x0.shape) if dropout_p > 0.
|
||||
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
|
||||
ctx.mark_non_differentiable(dmask)
|
||||
return zmat.view(x0.shape), xmat.view(x0.shape), dmask
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dz, dx, *args):
|
||||
# assert dz.is_contiguous()
|
||||
dz = dz.contiguous() # this happens!
|
||||
dx = dx.contiguous() # this happens!
|
||||
x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
|
||||
dropout_p = ctx.dropout_p
|
||||
has_residual = ctx.has_residual
|
||||
dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward(
|
||||
dz, dx, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
|
||||
)
|
||||
dx0 = dx0mat.view(x.shape)
|
||||
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
|
||||
return dx0, dx1, dgamma, dbeta, None, None, None, None, None
|
||||
|
||||
|
||||
def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
|
||||
def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
|
||||
prenorm=False, residual_in_fp32=False,
|
||||
return_dropout_mask=False):
|
||||
"""residual_in_fp32 only has an effect if x1 is None.
|
||||
Otherwise residual dtype is x1.dtype.
|
||||
"""
|
||||
args = (x0, x1, weight, bias, rowscale, dropout_p, epsilon, residual_in_fp32,
|
||||
return_dropout_mask)
|
||||
if not prenorm:
|
||||
return DropoutAddLayerNormFN.apply(*args)
|
||||
else:
|
||||
return DropoutAddLayerNormPrenormFN.apply(*args)
|
||||
return DropoutAddLayerNormFn.apply(
|
||||
x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
|
||||
return_dropout_mask
|
||||
)
|
||||
|
||||
|
||||
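For reference, a hedged usage sketch of the extended API, mirroring how the updated test calls it; the shapes, the 1e-4 LayerScale initialization, and the dropout probability are illustrative assumptions, while the import path comes from the test file.

```python
import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm

batch, seqlen, hidden = 8, 512, 1024
x0 = torch.randn(batch, seqlen, hidden, device='cuda', dtype=torch.float16, requires_grad=True)
x1 = torch.randn_like(x0)                      # optional residual branch
weight = torch.ones(hidden, device='cuda', dtype=torch.float16, requires_grad=True)
bias = torch.zeros(hidden, device='cuda', dtype=torch.float16, requires_grad=True)
layerscale = torch.full((hidden,), 1e-4, device='cuda', dtype=torch.float16, requires_grad=True)

out = dropout_add_layer_norm(x0, x1, weight, bias, dropout_p=0.1, epsilon=1e-5,
                             layerscale=layerscale)
out.sum().backward()          # layerscale.grad is now populated by the fused backward
```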
class DropoutAddLayerNorm(torch.nn.Module):
|
||||
|
||||
@ -11,6 +11,7 @@ from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_nor
|
||||
|
||||
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
|
||||
|
||||
@pytest.mark.parametrize('has_colscale', [True, False])
|
||||
@pytest.mark.parametrize('has_rowscale', [True, False])
|
||||
# @pytest.mark.parametrize('has_rowscale', [True])
|
||||
@pytest.mark.parametrize('has_residual', [True, False])
|
||||
@ -26,12 +27,9 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
|
||||
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
|
||||
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
|
||||
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
|
||||
dropout_p, has_residual, has_rowscale):
|
||||
dropout_p, has_residual, has_rowscale, has_colscale):
|
||||
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
|
||||
pytest.skip() # Not supported
|
||||
# Backward numerical error is high, and this case isn't used
|
||||
if has_rowscale and not has_residual:
|
||||
pytest.skip()
|
||||
device = 'cuda'
|
||||
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
|
||||
rtol, atol = (1e-3, 1e-4)
|
||||
@ -43,6 +41,12 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
|
||||
requires_grad=True)
|
||||
x0 = x0_pt.detach().clone().requires_grad_()
|
||||
x0_ref = x0_pt.detach().clone().float().requires_grad_()
|
||||
if has_colscale:
|
||||
colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
|
||||
colscale_pt = colscale.detach().clone().requires_grad_()
|
||||
colscale_ref = colscale.detach().clone().float().requires_grad_()
|
||||
else:
|
||||
colscale = None
|
||||
if has_residual:
|
||||
x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
|
||||
x1 = x1_pt.detach().clone().requires_grad_()
|
||||
@ -59,6 +63,9 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
|
||||
rowscale = None
|
||||
x0_scaled_pt = x0_pt
|
||||
x0_scaled_ref = x0_ref
|
||||
if has_colscale:
|
||||
x0_scaled_pt = x0_scaled_pt * colscale_pt
|
||||
x0_scaled_ref = x0_scaled_ref * colscale_ref
|
||||
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
|
||||
torch.nn.init.normal_(model_pt.weight)
|
||||
torch.nn.init.normal_(model_pt.bias)
|
||||
@ -71,7 +78,7 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
|
||||
model_ref.bias.copy_(model_pt.bias)
|
||||
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
|
||||
out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
|
||||
model.epsilon, rowscale=rowscale,
|
||||
model.epsilon, rowscale=rowscale, layerscale=colscale,
|
||||
residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
|
||||
assert out.dtype == input_dtype
|
||||
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
|
||||
@ -94,6 +101,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
|
||||
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
|
||||
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
|
||||
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
|
||||
if has_colscale:
|
||||
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
|
||||
|
||||
|
||||
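The new `colscale.grad` assertion reuses the test file's tolerance pattern: the fused kernel may deviate from the float32 reference by at most a small constant plus a multiple of the deviation that plain PyTorch shows in the same low precision. As a hypothetical helper (not in the test file):

```python
def close_enough(fused, pt, ref, mult=2.0, abs_tol=2e-4):
    return (fused - ref).abs().max() <= mult * (pt - ref).abs().max() + abs_tol
```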
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
|
||||
@ -139,6 +148,7 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
|
||||
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
|
||||
|
||||
|
||||
@pytest.mark.parametrize('has_colscale', [True, False])
|
||||
@pytest.mark.parametrize('has_rowscale', [True, False])
|
||||
@pytest.mark.parametrize('has_residual', [True, False])
|
||||
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
|
||||
@ -147,20 +157,17 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
|
||||
[(torch.float16, torch.float16), (torch.float16, torch.float32),
|
||||
(torch.float32, torch.float32)]
|
||||
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
|
||||
# @pytest.mark.parametrize('has_colscale', [True])
|
||||
# @pytest.mark.parametrize('has_rowscale', [False])
|
||||
# @pytest.mark.parametrize('has_residual', [True])
|
||||
# @pytest.mark.parametrize('has_residual', [False])
|
||||
# @pytest.mark.parametrize('dropout_p', [0.0])
|
||||
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
|
||||
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
|
||||
# @pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
|
||||
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
|
||||
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
|
||||
dropout_p, has_residual, has_rowscale):
|
||||
dropout_p, has_residual, has_rowscale, has_colscale):
|
||||
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
|
||||
pytest.skip() # Not supported
|
||||
# Backward numerical error is high, and this case isn't used
|
||||
if has_rowscale and not has_residual:
|
||||
pytest.skip()
|
||||
device = 'cuda'
|
||||
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
|
||||
rtol, atol = (1e-3, 2e-4)
|
||||
@ -172,6 +179,12 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
|
||||
requires_grad=True)
|
||||
x0 = x0_pt.detach().clone().requires_grad_()
|
||||
x0_ref = x0_pt.detach().clone().float().requires_grad_()
|
||||
if has_colscale:
|
||||
colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
|
||||
colscale_pt = colscale.detach().clone().requires_grad_()
|
||||
colscale_ref = colscale.detach().clone().float().requires_grad_()
|
||||
else:
|
||||
colscale = None
|
||||
if has_residual:
|
||||
x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
|
||||
x1 = x1_pt.detach().clone().requires_grad_()
|
||||
@ -188,6 +201,9 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
|
||||
rowscale = None
|
||||
x0_scaled_pt = x0_pt
|
||||
x0_scaled_ref = x0_ref
|
||||
if has_colscale:
|
||||
x0_scaled_pt = x0_scaled_pt * colscale_pt
|
||||
x0_scaled_ref = x0_scaled_ref * colscale_ref
|
||||
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
|
||||
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
|
||||
model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
|
||||
@ -199,7 +215,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
|
||||
model_ref.bias.copy_(model_pt.bias)
|
||||
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
|
||||
out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
|
||||
model.epsilon, rowscale=rowscale, prenorm=True,
|
||||
model.epsilon, rowscale=rowscale,
|
||||
layerscale=colscale, prenorm=True,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
return_dropout_mask=True)
|
||||
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
|
||||
@ -225,6 +242,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
|
||||
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
|
||||
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
|
||||
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
|
||||
if has_colscale:
|
||||
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
|
||||
|
||||
|
||||
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
|
||||