[LayerNorm] Fuse LayerScale

This commit is contained in:
Tri Dao 2022-12-10 20:29:05 -08:00
parent 8c6609ae1a
commit ae137ed17a
8 changed files with 310 additions and 328 deletions

View File

@ -40,6 +40,8 @@ struct ParamsBase {
, mu(nullptr)
, rs(nullptr)
, gamma(nullptr)
, rowscale(nullptr)
, colscale(nullptr)
, dropout_keep_p(1.f)
, dropout_scale(1.f)
, workspace(nullptr)
@ -63,6 +65,7 @@ struct ParamsBase {
void *rs;
void *gamma;
void *rowscale;
void *colscale;
float inverse_cols;
@ -106,10 +109,12 @@ struct BwdParams : public ParamsBase {
, dx(nullptr)
, dbeta_part(nullptr)
, dgamma_part(nullptr)
, dcolscale_part(nullptr)
, dx0(nullptr)
, dx1(nullptr)
, dbeta(nullptr)
, dgamma(nullptr)
, dcolscale(nullptr)
{
}
@ -121,6 +126,7 @@ struct BwdParams : public ParamsBase {
// Workspace for Wgrad pre-reduction.
void *dbeta_part;
void *dgamma_part;
void *dcolscale_part;
// Output: Dgrad.
void *dx0;
@ -128,13 +134,14 @@ struct BwdParams : public ParamsBase {
// Output: Wgrad.
void *dbeta;
void *dgamma;
void *dcolscale;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
using FunctionKey = uint64_t;
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;

View File

@ -84,6 +84,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
const at::Tensor &gamma, // hidden_size
const at::Tensor &beta, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
c10::optional<const at::Tensor> &colscale_, // BxS
const float dropout_p,
const float epsilon,
c10::optional<at::Generator> gen_,
@ -124,7 +125,15 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
TORCH_CHECK(rowscale.is_cuda())
TORCH_CHECK(rowscale.is_contiguous());
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
TORCH_CHECK(rowscale.scalar_type() == itype);
TORCH_CHECK(rowscale.dtype() == itype);
}
if (colscale_.has_value()) {
auto colscale = colscale_.value();
TORCH_CHECK(colscale.is_cuda())
TORCH_CHECK(colscale.is_contiguous());
TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
TORCH_CHECK(colscale.dtype() == wtype);
}
TORCH_CHECK(gamma.sizes() == beta.sizes());
@ -135,7 +144,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
auto opts = x0.options();
bool save_x = x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || (itype != rtype);
at::Tensor x;
if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
at::Tensor dmask;
@ -153,6 +162,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
launch_params.params.dropout_keep_p = 1.f - dropout_p;
launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
@ -212,12 +222,15 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidden_size
c10::optional<const at::Tensor> &dx_, // BxSxhidden_size
const at::Tensor &x, // BxSxhidden_size
c10::optional<const at::Tensor> &x0_, // BxSxhidden_size
c10::optional<const at::Tensor> &dmask_, // BxSxhidden_size
const at::Tensor &mu, // BxS, FP32!
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
c10::optional<const at::Tensor> &colscale_, // BxS
const float dropout_p,
const bool has_residual
) {
@ -250,6 +263,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
auto rows = sizes[0];
auto cols = sizes[1];
if (dx_.has_value()) {
auto dx = dx_.value();
TORCH_CHECK(dx.dtype() == rtype);
TORCH_CHECK(dx.is_cuda())
TORCH_CHECK(dx.is_contiguous());
TORCH_CHECK(dx.sizes() == sizes);
}
if (dmask_.has_value()) {
auto dmask = dmask_.value();
TORCH_CHECK(dmask.dtype() == mtype);
@ -263,7 +284,22 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
TORCH_CHECK(rowscale.is_cuda())
TORCH_CHECK(rowscale.is_contiguous());
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
TORCH_CHECK(rowscale.scalar_type() == itype);
TORCH_CHECK(rowscale.dtype() == itype);
}
if (colscale_.has_value()) {
auto colscale = colscale_.value();
TORCH_CHECK(colscale.is_cuda())
TORCH_CHECK(colscale.is_contiguous());
TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
TORCH_CHECK(colscale.dtype() == wtype);
TORCH_CHECK(x0_.has_value());
auto x0 = x0_.value();
TORCH_CHECK(x0.is_cuda())
TORCH_CHECK(x0.is_contiguous());
TORCH_CHECK(x0.sizes() == sizes);
TORCH_CHECK(x0.dtype() == itype);
}
auto hidden_size = gamma.numel();
@ -282,6 +318,10 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
auto dgamma = torch::empty_like(gamma);
auto dbeta = torch::empty_like(gamma);
at::Tensor dcolscale;
if (colscale_.has_value()) {
dcolscale = torch::empty_like(colscale_.value());
}
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
@ -290,31 +330,40 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
launch_params.params.dropout_keep_p = 1.f - dropout_p;
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
launcher(launch_params, true, /*prenorm=*/false);
launcher(launch_params, true);
auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
at::Tensor dcolscale_part;
if (colscale_.has_value()) {
dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
}
at::Tensor workspace, barrier;
layer_norm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr;
params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.dz = dz.data_ptr();
params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
params.dx0 = dx0.data_ptr();
params.dbeta = dbeta.data_ptr();
params.dgamma = dgamma.data_ptr();
params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr;
params.dbeta_part = dbeta_part.data_ptr();
params.dgamma_part = dgamma_part.data_ptr();
params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
params.dropout_scale = 1.f / (1.f - dropout_p);
params.inverse_cols = 1.f / float(params.cols);
@ -326,137 +375,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
params.barrier = barrier.data_ptr<int>();
}
launcher(launch_params, false, /*prenorm=*/false);
launcher(launch_params, false);
return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, // BxSxhidden_size
const at::Tensor &dx, // BxSxhidden_size
const at::Tensor &x, // BxSxhidden_size
c10::optional<const at::Tensor> &dmask_, // BxSxhidden_size
const at::Tensor &mu, // BxS, FP32!
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
const float dropout_p,
const bool has_residual
) {
auto itype = dz.scalar_type();
auto rtype = x.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = itype;
auto ctype = torch::kFloat32;
auto mtype = torch::kUInt8;
if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
TORCH_CHECK(dz.dtype() == otype);
TORCH_CHECK(dx.dtype() == rtype);
TORCH_CHECK(mu.dtype() == ctype);
TORCH_CHECK(rsigma.dtype() == ctype);
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(dz.is_cuda());
TORCH_CHECK(dx.is_cuda());
TORCH_CHECK(mu.is_cuda());
TORCH_CHECK(rsigma.is_cuda());
TORCH_CHECK(gamma.is_cuda());
TORCH_CHECK(x.is_contiguous());
TORCH_CHECK(dz.is_contiguous());
TORCH_CHECK(dx.is_contiguous());
auto sizes = x.sizes();
TORCH_CHECK(sizes.size() == 2);
TORCH_CHECK(dz.sizes() == sizes);
TORCH_CHECK(dx.sizes() == sizes);
auto rows = sizes[0];
auto cols = sizes[1];
if (dmask_.has_value()) {
auto dmask = dmask_.value();
TORCH_CHECK(dmask.dtype() == mtype);
TORCH_CHECK(dmask.is_cuda());
TORCH_CHECK(dmask.is_contiguous());
TORCH_CHECK(dmask.sizes() == sizes);
std::vector<at::Tensor> result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
if (colscale_.has_value()) {
result.push_back(dcolscale);
result.push_back(dcolscale_part);
}
if (rowscale_.has_value()) {
auto rowscale = rowscale_.value();
TORCH_CHECK(rowscale.is_cuda())
TORCH_CHECK(rowscale.is_contiguous());
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
TORCH_CHECK(rowscale.scalar_type() == itype);
}
auto hidden_size = gamma.numel();
TORCH_CHECK(hidden_size == cols);
TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
TORCH_CHECK(mu.numel() == rows);
TORCH_CHECK(mu.sizes() == rsigma.sizes());
TORCH_CHECK(gamma.numel() == cols);
auto opts = x.options();
auto dx0 = torch::empty_like(x, opts.dtype(itype));
at::Tensor dx1;
if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
auto dgamma = torch::empty_like(gamma);
auto dbeta = torch::empty_like(gamma);
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
launch_params.props = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dropout_p < 1.f);
launch_params.params.dropout_keep_p = 1.f - dropout_p;
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
launcher(launch_params, true, /*prenorm=*/true);
auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
at::Tensor workspace, barrier;
layer_norm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.dz = dz.data_ptr();
params.dx = dx.data_ptr();
params.dx0 = dx0.data_ptr();
params.dbeta = dbeta.data_ptr();
params.dgamma = dgamma.data_ptr();
params.dbeta_part = dbeta_part.data_ptr();
params.dgamma_part = dgamma_part.data_ptr();
params.dropout_scale = 1.f / (1.f - dropout_p);
params.inverse_cols = 1.f / float(params.cols);
if( launch_params.barrier_size > 0 ) {
// TODO Any way to avoid this?
barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}
launcher(launch_params, false, /*prenorm=*/true);
return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
return result;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -464,5 +390,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "CUDA DropoutAddLayerNorm";
m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel");
m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel");
m.def("dropout_add_ln_prenorm_bwd", &dropout_add_ln_prenorm_bwd, "Run Dropout + Add + LayerNorm (PreNorm version) backward kernel");
}

View File

@ -7,7 +7,7 @@
namespace layer_norm {
template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Is_even_cols>
template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_kernel(layer_norm::BwdParams params) {
@ -53,9 +53,11 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
Cvec dcolscale_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum));
if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); }
compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
@ -68,11 +70,13 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;
Wvec gamma[LDGS];
Wvec colscale[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
gamma[it].load_from(params.gamma, idx);
if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
idx += Ktraits::VEC_COLS_PER_LDG;
}
}
@ -131,6 +135,8 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Ivec dx0;
Rvec dx1;
Ivec x0;
if (Has_colscale) { x0.load_from(params.x0, idx); }
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
@ -140,9 +146,20 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
if (Is_dropout) {
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f;
dx0_tmp_res *= params.dropout_scale;
if (Has_colscale) {
dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
} else {
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
}
} else {
dx0.data.elt[jt] = dx0_tmp_res;
if (Has_colscale) {
dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
} else {
dx0.data.elt[jt] = dx0_tmp_res;
}
}
}
if (Has_residual) { dx1.store_to(params.dx1, idx); }
@ -160,6 +177,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
if (Is_even_cols || (it < num_valid_ldgs)) {
dz_sum[it].store_to(params.dbeta_part, idx);
dzy_sum[it].store_to(params.dgamma_part, idx);
if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); }
idx += Ktraits::VEC_COLS_PER_LDG;
}
}
@ -203,23 +221,46 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
}
}
compute_t cta_dcolscale_sum[NUM_RES];
if (Has_colscale) {
__syncthreads();
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dcolscale_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES);
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
for( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
}
const index_t num_valid_writes
= (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
compute_t *dcolscale_part = Has_colscale ? static_cast<compute_t *>(params.dcolscale_part) + bidm * params.cols + tidx : nullptr;
for( int jt = 0; jt < NUM_RES; jt++ ) {
if (Is_even_cols || (jt < num_valid_writes)) {
*dgamma_part = cta_dzy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
*dbeta_part = cta_dz_sum[jt];
dbeta_part += Ktraits::THREADS_PER_CTA;
if (Has_colscale) {
*dcolscale_part = cta_dcolscale_sum[jt];
dcolscale_part += Ktraits::THREADS_PER_CTA;
}
}
}
}
}
template<typename Kernel_traits, bool Is_even_cols>
template<typename Kernel_traits, bool Has_colscale, bool Is_even_cols>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
void ln_bwd_finalize_kernel(BwdParams params)
{
@ -250,26 +291,29 @@ void ln_bwd_finalize_kernel(BwdParams params)
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local, dcolscale_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
memset(&dbeta_local, 0, sizeof(dbeta_local));
if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
if (Is_even_cols || col < params.cols) {
for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
// index_t idx = row * Kernel_traits::COLS + col;
index_t idx = row * params.cols + col;
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part, dcolscale_part;
dbeta_part.load_from(params.dbeta_part, idx);
dgamma_part.load_from(params.dgamma_part, idx);
if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); }
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; }
}
}
}
void * smem_gamma = smem_;
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
const int write_row = warp;
const int write_col = lane ^ write_row;
@ -277,12 +321,14 @@ void ln_bwd_finalize_kernel(BwdParams params)
dgamma_local.store_to(smem_gamma, write_idx);
dbeta_local.store_to(smem_beta, write_idx);
if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); }
__syncthreads();
// It would be probably safe to reuse the first row of smem_beta and smem_gamma
void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE];
void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];
// More than one iter iff ROWS_PER_CTA < 32.
@ -293,11 +339,13 @@ void ln_bwd_finalize_kernel(BwdParams params)
memset(&dbeta_local, 0, sizeof(dbeta_local));
memset(&dgamma_local, 0, sizeof(dgamma_local));
if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }
// Load beta and gamma transposed
if(read_row < Kernel_traits::ROWS_PER_CTA){
dbeta_local.load_from(smem_beta, read_idx);
dgamma_local.load_from(smem_gamma, read_idx);
if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); }
}
// Call reducer on the loaded value(s) and convert.
@ -310,12 +358,18 @@ void ln_bwd_finalize_kernel(BwdParams params)
dgamma_local.data.elt[it] = g_i;
dbeta_local.data.elt[it] = b_i;
if (Has_colscale) {
compute_t cs_i = dcolscale_local.data.elt[it];
cs_i = reducer.allreduce(cs_i, sum);
dcolscale_local.data.elt[it] = cs_i;
}
}
// Leader stores the result at the current column.
if(lane == 0){
dgamma_local.store_to(smem_gamma_out, w);
dbeta_local.store_to(smem_beta_out, w);
if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); }
}
}
@ -329,19 +383,21 @@ void ln_bwd_finalize_kernel(BwdParams params)
using src_t = typename TypeToVec2<compute_t>::Type;
using dst_t = typename TypeToVec2<weight_t>::Type;
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2, dcolscale_vec2;
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2, dcolscale_out2;
dgamma_vec2.load_from(smem_gamma_out, lane);
dbeta_vec2.load_from(smem_beta_out, lane);
if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); }
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dcolscale_vec2.data.elt[it]); }
}
dgamma_out2.store_to(params.dgamma, col_out);
dbeta_out2.store_to(params.dbeta, col_out);
if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); }
}
}
}
@ -364,7 +420,7 @@ template<
int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL
>
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
@ -378,59 +434,64 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
WARPS_N,
BYTES_PER_LDG_MAIN
>;
bool prenorm = launch_params.params.dx != nullptr;
bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
bool has_residual = launch_params.params.dx1 != nullptr;
bool has_colscale = launch_params.params.colscale != nullptr;
bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
BOOL_SWITCH(prenorm, PrenormConst, [&] {
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::reduce_t)
* 2;
BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::reduce_t)
* 2;
}
return;
}
return;
}
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
}
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
}
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
HasColscaleConst,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, IsEvenColsConst>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
});
});
});
});

View File

@ -16,7 +16,7 @@
namespace layer_norm {
template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Is_even_cols>
template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_kernel(FwdParams params) {
@ -46,7 +46,7 @@ void ln_fwd_kernel(FwdParams params) {
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
constexpr bool save_x = Has_residual || Is_dropout || !(std::is_same<input_t, residual_t>::value);
const bool save_x = Has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || !(std::is_same<input_t, residual_t>::value);
extern __shared__ char smem_[];
@ -80,12 +80,14 @@ void ln_fwd_kernel(FwdParams params) {
Wvec gamma[LDGS];
Wvec beta[LDGS];
Wvec colscale[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
gamma[it].load_from(params.gamma, idx);
beta[it].load_from(params.beta, idx);
if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
idx += VEC_COLS_PER_LDG;
}
}
@ -109,13 +111,9 @@ void ln_fwd_kernel(FwdParams params) {
// the more efficient curand_uniform4.
mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
compute_t x_ij;
if (Has_residual) {
compute_t x1_ij = compute_t(x1.data.elt[jt]);
x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
} else {
x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
}
x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
compute_t x_ij = Has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
if (save_x) { x.data.elt[jt] = x_ij; }
xf[it * NUM_ELTS + jt] = x_ij;
if (Is_dropout) { dmask.data.elt[jt] = keep; }
@ -130,8 +128,8 @@ void ln_fwd_kernel(FwdParams params) {
const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
// Need to convert to int, otherwise the subtraction will wrap around.
auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
// Need to convert to int, otherwise the subtraction will wrap around.
const index_t valid_partial_vecs_in_warp =
std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
int(THREADS_PER_WARP));
@ -206,45 +204,48 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
BYTES_PER_LDG
>;
bool has_residual = launch_params.params.x1 != nullptr;
bool has_colscale = launch_params.params.colscale != nullptr;
bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::Stats::stats_t)
* 2;
BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::Stats::stats_t)
* 2;
}
return;
}
return;
}
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
}
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
}
});
});
});
});

View File

@ -38,6 +38,7 @@ template<
typename output_t_,
typename compute_t_,
typename index_t_,
bool Has_colscale,
uint32_t THREADS_PER_CTA_,
uint32_t BYTES_PER_LDG_,
typename Base = Kernel_traits_base<HIDDEN_SIZE_,
@ -69,7 +70,8 @@ struct Kernel_traits_finalize : public Base {
// Shared memory size to coalsece the CTA result.
enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
// Shared memory requirement per CTA.
enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2;
enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT };
// The type of the reducer.
using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;

View File

@ -45,7 +45,7 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) {
#define REGISTER_BWD_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
const bool configure_params, const bool prenorm) { \
const bool configure_params) { \
launch_<WTYPE, \
ITYPE, \
RTYPE, \
@ -57,7 +57,7 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) {
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params, prenorm); \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)

View File

@ -1,11 +1,13 @@
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init
import dropout_layer_norm
def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, epsilon,
def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
residual_in_fp32):
""" Assume that arguments are contiguous
"""
@ -14,133 +16,98 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, ep
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
rowscale = rowscale.view(-1) if rowscale is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
x0mat, x1mat, gamma, beta, rowscale, dropout_p, epsilon, None, residual_in_fp32
x0mat, x1mat, gamma, beta, rowscale, colscale, dropout_p, epsilon, None, residual_in_fp32
)
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p,
has_residual):
def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
dropout_p, has_residual):
""" Assume that arguments are contiguous
dx == None means that it was a post-norm architecture
(x = drop(x0) + x1 was not returned in the fwd).
x0 must not be None if we have colscale.
"""
# dmask is None if dropout_p == 0.0
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dzmat = dz.view(xmat.shape)
dxmat = dx.view(xmat.shape) if dx is not None else None
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
rowscale = rowscale.view(-1) if rowscale is not None else None
dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_bwd(
dzmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
colscale = colscale.view(-1) if colscale is not None else None
if colscale is not None:
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p,
has_residual
)
# dx1mat is None if not has_residual
return dx0mat, dx1mat, dgamma, dbeta
if colscale is None:
return dx0mat, dx1mat, dgamma, dbeta
else:
dcolscale = rest[0]
return dx0mat, dx1mat, dgamma, dbeta, dcolscale
def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, rowscale,
dropout_p, has_residual):
""" Assume that arguments are contiguous
"""
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dzmat = dz.view(xmat.shape)
dxmat = dx.view(xmat.shape)
rowscale = rowscale.view(-1) if rowscale is not None else None
dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_prenorm_bwd(
dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
)
return dx0mat, dx1mat, dgamma, dbeta
class DropoutAddLayerNormFN(torch.autograd.Function):
class DropoutAddLayerNormFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
return_dmask=False):
def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32,
prenorm=False, return_dmask=False):
x0 = x0.contiguous()
x1 = x1.contiguous() if x1 is not None else None
gamma = gamma.contiguous()
beta = beta.contiguous()
rowscale = rowscale.contiguous() if rowscale is not None else None
colscale = colscale.contiguous() if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32
)
ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved = x0 if colscale is not None else None
ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
ctx.has_residual = x1 is not None
if not return_dmask:
return zmat.view(x0.shape)
return (zmat.view(x0.shape) if not prenorm
else (zmat.view(x0.shape), xmat.view(x0.shape)))
else:
dmask = (dmask.view(x0.shape) if dropout_p > 0.
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
ctx.mark_non_differentiable(dmask)
return zmat.view(x0.shape), dmask
return ((zmat.view(x0.shape), dmask) if not prenorm
else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))
@staticmethod
def backward(ctx, dz, *args):
# assert dz.is_contiguous()
dz = dz.contiguous() # this happens!
x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
dx = args[0].contiguous() if ctx.prenorm else None
x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
# x0 is None if colscale is None
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward(
dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual
)
dx0 = dx0mat.view(x.shape)
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
return dx0, dx1, dgamma, dbeta, None, None, None, None, None
dcolscale = rest[0] if colscale is not None else None
return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None
class DropoutAddLayerNormPrenormFN(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
return_dmask=False):
x0 = x0.contiguous()
x1 = x1.contiguous() if x1 is not None else None
gamma = gamma.contiguous()
beta = beta.contiguous()
rowscale = rowscale.contiguous() if rowscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
)
ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
ctx.dropout_p = dropout_p
ctx.has_residual = x1 is not None
if not return_dmask:
return zmat.view(x0.shape), xmat.view(x0.shape)
else:
dmask = (dmask.view(x0.shape) if dropout_p > 0.
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
ctx.mark_non_differentiable(dmask)
return zmat.view(x0.shape), xmat.view(x0.shape), dmask
@staticmethod
def backward(ctx, dz, dx, *args):
# assert dz.is_contiguous()
dz = dz.contiguous() # this happens!
dx = dx.contiguous() # this happens!
x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward(
dz, dx, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
)
dx0 = dx0mat.view(x.shape)
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
return dx0, dx1, dgamma, dbeta, None, None, None, None, None
def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
"""residual_in_fp32 only has an effect if x1 is None.
Otherwise residual dtype is x1.dtype.
"""
args = (x0, x1, weight, bias, rowscale, dropout_p, epsilon, residual_in_fp32,
return_dropout_mask)
if not prenorm:
return DropoutAddLayerNormFN.apply(*args)
else:
return DropoutAddLayerNormPrenormFN.apply(*args)
return DropoutAddLayerNormFn.apply(
x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
return_dropout_mask
)
class DropoutAddLayerNorm(torch.nn.Module):

View File

@ -11,6 +11,7 @@ from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_nor
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_rowscale', [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize('has_residual', [True, False])
@ -26,12 +27,9 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_residual, has_rowscale):
dropout_p, has_residual, has_rowscale, has_colscale):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
# Backward numerical error is high, and this case isn't used
if has_rowscale and not has_residual:
pytest.skip()
device = 'cuda'
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
@ -43,6 +41,12 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
requires_grad=True)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_colscale:
colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
colscale_pt = colscale.detach().clone().requires_grad_()
colscale_ref = colscale.detach().clone().float().requires_grad_()
else:
colscale = None
if has_residual:
x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
@ -59,6 +63,9 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
rowscale = None
x0_scaled_pt = x0_pt
x0_scaled_ref = x0_ref
if has_colscale:
x0_scaled_pt = x0_scaled_pt * colscale_pt
x0_scaled_ref = x0_scaled_ref * colscale_ref
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
@ -71,7 +78,7 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
model.epsilon, rowscale=rowscale,
model.epsilon, rowscale=rowscale, layerscale=colscale,
residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
assert out.dtype == input_dtype
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
@ -94,6 +101,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@ -139,6 +148,7 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_rowscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@ -147,20 +157,17 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('has_residual', [False])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
# @pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_residual, has_rowscale):
dropout_p, has_residual, has_rowscale, has_colscale):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
# Backward numerical error is high, and this case isn't used
if has_rowscale and not has_residual:
pytest.skip()
device = 'cuda'
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 2e-4)
@ -172,6 +179,12 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
requires_grad=True)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_colscale:
colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
colscale_pt = colscale.detach().clone().requires_grad_()
colscale_ref = colscale.detach().clone().float().requires_grad_()
else:
colscale = None
if has_residual:
x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
@ -188,6 +201,9 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
rowscale = None
x0_scaled_pt = x0_pt
x0_scaled_ref = x0_ref
if has_colscale:
x0_scaled_pt = x0_scaled_pt * colscale_pt
x0_scaled_ref = x0_scaled_ref * colscale_ref
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
@ -199,7 +215,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
model.epsilon, rowscale=rowscale, prenorm=True,
model.epsilon, rowscale=rowscale,
layerscale=colscale, prenorm=True,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
@ -225,6 +242,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])