diff --git a/csrc/layer_norm/ln.h b/csrc/layer_norm/ln.h index 25ac64e..1891255 100644 --- a/csrc/layer_norm/ln.h +++ b/csrc/layer_norm/ln.h @@ -40,6 +40,8 @@ struct ParamsBase { , mu(nullptr) , rs(nullptr) , gamma(nullptr) + , rowscale(nullptr) + , colscale(nullptr) , dropout_keep_p(1.f) , dropout_scale(1.f) , workspace(nullptr) @@ -63,6 +65,7 @@ struct ParamsBase { void *rs; void *gamma; void *rowscale; + void *colscale; float inverse_cols; @@ -106,10 +109,12 @@ struct BwdParams : public ParamsBase { , dx(nullptr) , dbeta_part(nullptr) , dgamma_part(nullptr) + , dcolscale_part(nullptr) , dx0(nullptr) , dx1(nullptr) , dbeta(nullptr) , dgamma(nullptr) + , dcolscale(nullptr) { } @@ -121,6 +126,7 @@ struct BwdParams : public ParamsBase { // Workspace for Wgrad pre-reduction. void *dbeta_part; void *dgamma_part; + void *dcolscale_part; // Output: Dgrad. void *dx0; @@ -128,13 +134,14 @@ struct BwdParams : public ParamsBase { // Output: Wgrad. void *dbeta; void *dgamma; + void *dcolscale; }; //////////////////////////////////////////////////////////////////////////////////////////////////// using FwdFunction = std::function&, const bool)>; -using BwdFunction = std::function&, const bool, const bool)>; +using BwdFunction = std::function&, const bool)>; using FunctionKey = uint64_t; using FwdRegistry = std::unordered_map; using BwdRegistry = std::unordered_map; diff --git a/csrc/layer_norm/ln_api.cpp b/csrc/layer_norm/ln_api.cpp index 68962f3..0720b1a 100644 --- a/csrc/layer_norm/ln_api.cpp +++ b/csrc/layer_norm/ln_api.cpp @@ -84,6 +84,7 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: const at::Tensor &gamma, // hidden_size const at::Tensor &beta, // hidden_size c10::optional &rowscale_, // BxS + c10::optional &colscale_, // BxS const float dropout_p, const float epsilon, c10::optional gen_, @@ -124,7 +125,15 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: TORCH_CHECK(rowscale.is_cuda()) TORCH_CHECK(rowscale.is_contiguous()); TORCH_CHECK(rowscale.sizes() == std::vector{rows}); - TORCH_CHECK(rowscale.scalar_type() == itype); + TORCH_CHECK(rowscale.dtype() == itype); + } + + if (colscale_.has_value()) { + auto colscale = colscale_.value(); + TORCH_CHECK(colscale.is_cuda()) + TORCH_CHECK(colscale.is_contiguous()); + TORCH_CHECK(colscale.sizes() == std::vector{cols}); + TORCH_CHECK(colscale.dtype() == wtype); } TORCH_CHECK(gamma.sizes() == beta.sizes()); @@ -135,7 +144,7 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: auto opts = x0.options(); - bool save_x = x1_.has_value() || (dropout_p > 0.f) || (itype != rtype); + bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || (itype != rtype); at::Tensor x; if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); } at::Tensor dmask; @@ -153,6 +162,7 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: launch_params.params.dropout_keep_p = 1.f - dropout_p; launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr; launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; + launch_params.params.colscale = colscale_.has_value() ? 
colscale_.value().data_ptr() : nullptr; auto gen = at::get_generator_or_default( gen_, at::cuda::detail::getDefaultCUDAGenerator()); @@ -212,12 +222,15 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: //////////////////////////////////////////////////////////////////////////////////////////////////// std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidden_size + c10::optional &dx_, // BxSxhidden_size const at::Tensor &x, // BxSxhidden_size + c10::optional &x0_, // BxSxhidden_size c10::optional &dmask_, // BxSxhidden_size const at::Tensor &mu, // BxS, FP32! const at::Tensor &rsigma, // BxS, FP32! const at::Tensor &gamma, // hidden_size c10::optional &rowscale_, // BxS + c10::optional &colscale_, // BxS const float dropout_p, const bool has_residual ) { @@ -250,6 +263,14 @@ std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd auto rows = sizes[0]; auto cols = sizes[1]; + if (dx_.has_value()) { + auto dx = dx_.value(); + TORCH_CHECK(dx.dtype() == rtype); + TORCH_CHECK(dx.is_cuda()) + TORCH_CHECK(dx.is_contiguous()); + TORCH_CHECK(dx.sizes() == sizes); + } + if (dmask_.has_value()) { auto dmask = dmask_.value(); TORCH_CHECK(dmask.dtype() == mtype); @@ -263,7 +284,22 @@ std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd TORCH_CHECK(rowscale.is_cuda()) TORCH_CHECK(rowscale.is_contiguous()); TORCH_CHECK(rowscale.sizes() == std::vector{rows}); - TORCH_CHECK(rowscale.scalar_type() == itype); + TORCH_CHECK(rowscale.dtype() == itype); + } + + if (colscale_.has_value()) { + auto colscale = colscale_.value(); + TORCH_CHECK(colscale.is_cuda()) + TORCH_CHECK(colscale.is_contiguous()); + TORCH_CHECK(colscale.sizes() == std::vector{cols}); + TORCH_CHECK(colscale.dtype() == wtype); + + TORCH_CHECK(x0_.has_value()); + auto x0 = x0_.value(); + TORCH_CHECK(x0.is_cuda()) + TORCH_CHECK(x0.is_contiguous()); + TORCH_CHECK(x0.sizes() == sizes); + TORCH_CHECK(x0.dtype() == itype); } auto hidden_size = gamma.numel(); @@ -282,6 +318,10 @@ std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); } auto dgamma = torch::empty_like(gamma); auto dbeta = torch::empty_like(gamma); + at::Tensor dcolscale; + if (colscale_.has_value()) { + dcolscale = torch::empty_like(colscale_.value()); + } layer_norm::LaunchParams launch_params; launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); @@ -290,31 +330,40 @@ std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd launch_params.params.dropout_keep_p = 1.f - dropout_p; launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr; launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; + launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr; auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 
512 : 1024); auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple)); - launcher(launch_params, true, /*prenorm=*/false); + launcher(launch_params, true); auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); + at::Tensor dcolscale_part; + if (colscale_.has_value()) { + dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); + } at::Tensor workspace, barrier; layer_norm::BwdParams ¶ms = launch_params.params; params.rows = rows; params.cols = cols; params.x = x.data_ptr(); + params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr; params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr; params.mu = mu.data_ptr(); params.rs = rsigma.data_ptr(); params.gamma = gamma.data_ptr(); params.dz = dz.data_ptr(); + params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr; params.dx0 = dx0.data_ptr(); params.dbeta = dbeta.data_ptr(); params.dgamma = dgamma.data_ptr(); + params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr; params.dbeta_part = dbeta_part.data_ptr(); params.dgamma_part = dgamma_part.data_ptr(); + params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr; params.dropout_scale = 1.f / (1.f - dropout_p); params.inverse_cols = 1.f / float(params.cols); @@ -326,137 +375,14 @@ std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd params.barrier = barrier.data_ptr(); } - launcher(launch_params, false, /*prenorm=*/false); + launcher(launch_params, false); - return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part }; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -std::vector dropout_add_ln_prenorm_bwd(const at::Tensor &dz, // BxSxhidden_size - const at::Tensor &dx, // BxSxhidden_size - const at::Tensor &x, // BxSxhidden_size - c10::optional &dmask_, // BxSxhidden_size - const at::Tensor &mu, // BxS, FP32! - const at::Tensor &rsigma, // BxS, FP32! 
- const at::Tensor &gamma, // hidden_size - c10::optional &rowscale_, // BxS - const float dropout_p, - const bool has_residual -) { - - auto itype = dz.scalar_type(); - auto rtype = x.scalar_type(); - auto wtype = gamma.scalar_type(); - auto otype = itype; - auto ctype = torch::kFloat32; - auto mtype = torch::kUInt8; - - if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); } - - TORCH_CHECK(dz.dtype() == otype); - TORCH_CHECK(dx.dtype() == rtype); - TORCH_CHECK(mu.dtype() == ctype); - TORCH_CHECK(rsigma.dtype() == ctype); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(dz.is_cuda()); - TORCH_CHECK(dx.is_cuda()); - TORCH_CHECK(mu.is_cuda()); - TORCH_CHECK(rsigma.is_cuda()); - TORCH_CHECK(gamma.is_cuda()); - - TORCH_CHECK(x.is_contiguous()); - TORCH_CHECK(dz.is_contiguous()); - TORCH_CHECK(dx.is_contiguous()); - - auto sizes = x.sizes(); - TORCH_CHECK(sizes.size() == 2); - TORCH_CHECK(dz.sizes() == sizes); - TORCH_CHECK(dx.sizes() == sizes); - auto rows = sizes[0]; - auto cols = sizes[1]; - - if (dmask_.has_value()) { - auto dmask = dmask_.value(); - TORCH_CHECK(dmask.dtype() == mtype); - TORCH_CHECK(dmask.is_cuda()); - TORCH_CHECK(dmask.is_contiguous()); - TORCH_CHECK(dmask.sizes() == sizes); + std::vector result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part }; + if (colscale_.has_value()) { + result.push_back(dcolscale); + result.push_back(dcolscale_part); } - - if (rowscale_.has_value()) { - auto rowscale = rowscale_.value(); - TORCH_CHECK(rowscale.is_cuda()) - TORCH_CHECK(rowscale.is_contiguous()); - TORCH_CHECK(rowscale.sizes() == std::vector{rows}); - TORCH_CHECK(rowscale.scalar_type() == itype); - } - - auto hidden_size = gamma.numel(); - TORCH_CHECK(hidden_size == cols); - TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144)); - - TORCH_CHECK(mu.numel() == rows); - TORCH_CHECK(mu.sizes() == rsigma.sizes()); - - TORCH_CHECK(gamma.numel() == cols); - - auto opts = x.options(); - - auto dx0 = torch::empty_like(x, opts.dtype(itype)); - at::Tensor dx1; - if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); } - auto dgamma = torch::empty_like(gamma); - auto dbeta = torch::empty_like(gamma); - - layer_norm::LaunchParams launch_params; - launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); - launch_params.props = at::cuda::getCurrentDeviceProperties(); - TORCH_CHECK(dropout_p < 1.f); - launch_params.params.dropout_keep_p = 1.f - dropout_p; - launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr; - launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024); - auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple)); - - launcher(launch_params, true, /*prenorm=*/true); - - auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); - auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); - at::Tensor workspace, barrier; - - layer_norm::BwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data_ptr(); - params.dmask = dropout_p > 0.f ? 
dmask_.value().data_ptr() : nullptr; - params.mu = mu.data_ptr(); - params.rs = rsigma.data_ptr(); - params.gamma = gamma.data_ptr(); - params.dz = dz.data_ptr(); - params.dx = dx.data_ptr(); - params.dx0 = dx0.data_ptr(); - params.dbeta = dbeta.data_ptr(); - params.dgamma = dgamma.data_ptr(); - params.dbeta_part = dbeta_part.data_ptr(); - params.dgamma_part = dgamma_part.data_ptr(); - params.dropout_scale = 1.f / (1.f - dropout_p); - params.inverse_cols = 1.f / float(params.cols); - - if( launch_params.barrier_size > 0 ) { - // TODO Any way to avoid this? - barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32)); - workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar)); - params.workspace = workspace.data_ptr(); - params.barrier = barrier.data_ptr(); - } - - launcher(launch_params, false, /*prenorm=*/true); - - return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part }; + return result; } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -464,5 +390,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "CUDA DropoutAddLayerNorm"; m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel"); m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel"); - m.def("dropout_add_ln_prenorm_bwd", &dropout_add_ln_prenorm_bwd, "Run Dropout + Add + LayerNorm (PreNorm version) backward kernel"); } diff --git a/csrc/layer_norm/ln_bwd_kernels.cuh b/csrc/layer_norm/ln_bwd_kernels.cuh index d567ce8..7d3193d 100644 --- a/csrc/layer_norm/ln_bwd_kernels.cuh +++ b/csrc/layer_norm/ln_bwd_kernels.cuh @@ -7,7 +7,7 @@ namespace layer_norm { -template +template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel(layer_norm::BwdParams params) { @@ -53,9 +53,11 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { Cvec dzy_sum[LDGS]; Cvec dz_sum[LDGS]; + Cvec dcolscale_sum[LDGS]; memset(dzy_sum, 0, sizeof(dzy_sum)); memset(dz_sum, 0, sizeof(dz_sum)); + if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); } compute_t * smem_wgrad = reinterpret_cast(smem_); char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; @@ -68,11 +70,13 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG; Wvec gamma[LDGS]; + Wvec colscale[LDGS]; index_t idx = c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { gamma[it].load_from(params.gamma, idx); + if (Has_colscale) { colscale[it].load_from(params.colscale, idx); } idx += Ktraits::VEC_COLS_PER_LDG; } } @@ -131,6 +135,8 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { if (Is_even_cols || (it < num_valid_ldgs)) { Ivec dx0; Rvec dx1; + Ivec x0; + if (Has_colscale) { x0.load_from(params.x0, idx); } #pragma unroll for( int jt = 0; jt < NUM_ELTS; jt++ ) { compute_t dy_tmp = dy[it * NUM_ELTS + jt]; @@ -140,9 +146,20 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; } compute_t dx0_tmp_res = dx_tmp_res * rowscale_val; if (Is_dropout) { - dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f; + dx0_tmp_res *= params.dropout_scale; + if (Has_colscale) { + dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f; + dx0.data.elt[jt] = dmask[it].data.elt[jt] ? 
dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f; + } else { + dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f; + } } else { - dx0.data.elt[jt] = dx0_tmp_res; + if (Has_colscale) { + dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]); + dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]); + } else { + dx0.data.elt[jt] = dx0_tmp_res; + } } } if (Has_residual) { dx1.store_to(params.dx1, idx); } @@ -160,6 +177,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { if (Is_even_cols || (it < num_valid_ldgs)) { dz_sum[it].store_to(params.dbeta_part, idx); dzy_sum[it].store_to(params.dgamma_part, idx); + if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); } idx += Ktraits::VEC_COLS_PER_LDG; } } @@ -203,23 +221,46 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { } } + compute_t cta_dcolscale_sum[NUM_RES]; + if (Has_colscale) { + __syncthreads(); + idx = warp_m * Ktraits::VEC_COLS + tid_r; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dcolscale_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES); + for( int it = 0; it < ROWS_PER_CTA; it++ ) { + for( int jt = 0; jt < NUM_RES; jt++ ) { + cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + } + const index_t num_valid_writes = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA; compute_t *dgamma_part = static_cast(params.dgamma_part) + bidm * params.cols + tidx; compute_t *dbeta_part = static_cast(params.dbeta_part) + bidm * params.cols + tidx; + compute_t *dcolscale_part = Has_colscale ? static_cast(params.dcolscale_part) + bidm * params.cols + tidx : nullptr; for( int jt = 0; jt < NUM_RES; jt++ ) { if (Is_even_cols || (jt < num_valid_writes)) { *dgamma_part = cta_dzy_sum[jt]; dgamma_part += Ktraits::THREADS_PER_CTA; *dbeta_part = cta_dz_sum[jt]; dbeta_part += Ktraits::THREADS_PER_CTA; + if (Has_colscale) { + *dcolscale_part = cta_dcolscale_sum[jt]; + dcolscale_part += Ktraits::THREADS_PER_CTA; + } } } } } -template +template __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_kernel(BwdParams params) { @@ -250,26 +291,29 @@ void ln_bwd_finalize_kernel(BwdParams params) constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) { // Each thread sums over NUM_ELT columns. 
- Vec dbeta_local, dgamma_local; + Vec dbeta_local, dgamma_local, dcolscale_local; memset(&dgamma_local, 0, sizeof(dgamma_local)); memset(&dbeta_local, 0, sizeof(dbeta_local)); + if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); } if (Is_even_cols || col < params.cols) { for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) { - // index_t idx = row * Kernel_traits::COLS + col; index_t idx = row * params.cols + col; - Vec dbeta_part, dgamma_part; + Vec dbeta_part, dgamma_part, dcolscale_part; dbeta_part.load_from(params.dbeta_part, idx); dgamma_part.load_from(params.dgamma_part, idx); + if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); } #pragma unroll for( int it = 0; it < NUM_ELT; it++ ) { dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; + if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; } } } } void * smem_gamma = smem_; void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; + void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; const int write_row = warp; const int write_col = lane ^ write_row; @@ -277,12 +321,14 @@ void ln_bwd_finalize_kernel(BwdParams params) dgamma_local.store_to(smem_gamma, write_idx); dbeta_local.store_to(smem_beta, write_idx); + if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); } __syncthreads(); // It would be probably safe to reuse the first row of smem_beta and smem_gamma - void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; - void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; + void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE]; + void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; + void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT]; // More than one iter iff ROWS_PER_CTA < 32. @@ -293,11 +339,13 @@ void ln_bwd_finalize_kernel(BwdParams params) memset(&dbeta_local, 0, sizeof(dbeta_local)); memset(&dgamma_local, 0, sizeof(dgamma_local)); + if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); } // Load beta and gamma transposed if(read_row < Kernel_traits::ROWS_PER_CTA){ dbeta_local.load_from(smem_beta, read_idx); dgamma_local.load_from(smem_gamma, read_idx); + if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); } } // Call reducer on the loaded value(s) and convert. @@ -310,12 +358,18 @@ void ln_bwd_finalize_kernel(BwdParams params) dgamma_local.data.elt[it] = g_i; dbeta_local.data.elt[it] = b_i; + if (Has_colscale) { + compute_t cs_i = dcolscale_local.data.elt[it]; + cs_i = reducer.allreduce(cs_i, sum); + dcolscale_local.data.elt[it] = cs_i; + } } // Leader stores the result at the current column. 
if(lane == 0){ dgamma_local.store_to(smem_gamma_out, w); dbeta_local.store_to(smem_beta_out, w); + if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); } } } @@ -329,19 +383,21 @@ void ln_bwd_finalize_kernel(BwdParams params) using src_t = typename TypeToVec2::Type; using dst_t = typename TypeToVec2::Type; - Vec dbeta_vec2, dgamma_vec2; - Vec dbeta_out2, dgamma_out2; + Vec dbeta_vec2, dgamma_vec2, dcolscale_vec2; + Vec dbeta_out2, dgamma_out2, dcolscale_out2; dgamma_vec2.load_from(smem_gamma_out, lane); dbeta_vec2.load_from(smem_beta_out, lane); + if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); } #pragma unroll for( int it = 0; it < NUM_ELT; it++ ) { dgamma_out2.data.elt[it] = Converter::convert(dgamma_vec2.data.elt[it]); dbeta_out2.data.elt[it] = Converter::convert(dbeta_vec2.data.elt[it]); + if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter::convert(dcolscale_vec2.data.elt[it]); } } dgamma_out2.store_to(params.dgamma, col_out); dbeta_out2.store_to(params.dbeta, col_out); - + if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); } } } } @@ -364,7 +420,7 @@ template< int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL > -void launch_(LaunchParams &launch_params, const bool configure_params, const bool prenorm){ +void launch_(LaunchParams &launch_params, const bool configure_params){ using Kernel_traits = Kernel_traits &launch_params, const bool configure_params WARPS_N, BYTES_PER_LDG_MAIN >; + bool prenorm = launch_params.params.dx != nullptr; bool is_dropout = launch_params.params.dropout_keep_p < 1.f; bool has_residual = launch_params.params.dx1 != nullptr; + bool has_colscale = launch_params.params.colscale != nullptr; bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; BOOL_SWITCH(prenorm, PrenormConst, [&] { BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { BOOL_SWITCH(has_residual, HasResidualConst, [&] { - BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { - auto kernel = &ln_bwd_kernel; - if( configure_params ) { - int ctas_per_sm; - CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::reduce_t) - * 2; + BOOL_SWITCH(has_colscale, HasColscaleConst, [&] { + BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { + auto kernel = &ln_bwd_kernel; + if( configure_params ) { + int ctas_per_sm; + CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); + launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if(Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW + * sizeof(typename Kernel_traits::reduce_t) + * 2; + } + return; } - return; - } - if( Kernel_traits::SMEM_BYTES >= 48 * 1024 
) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; + if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); - } + if( Kernel_traits::CTAS_PER_ROW == 1 ) { + kernel<<>>(launch_params.params); + } else { + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); + } - using Kernel_traits_f = layer_norm::Kernel_traits_finalize; + using Kernel_traits_f = layer_norm::Kernel_traits_finalize; - auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; - kernel_f<<>>(launch_params.params); + auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; + kernel_f<<>>(launch_params.params); + }); }); }); }); diff --git a/csrc/layer_norm/ln_fwd_kernels.cuh b/csrc/layer_norm/ln_fwd_kernels.cuh index 4f90cb3..0c09a3b 100644 --- a/csrc/layer_norm/ln_fwd_kernels.cuh +++ b/csrc/layer_norm/ln_fwd_kernels.cuh @@ -16,7 +16,7 @@ namespace layer_norm { -template +template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(FwdParams params) { @@ -46,7 +46,7 @@ void ln_fwd_kernel(FwdParams params) { using Stats = typename Ktraits::Stats; using stats_t = typename Stats::stats_t; - constexpr bool save_x = Has_residual || Is_dropout || !(std::is_same::value); + const bool save_x = Has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || !(std::is_same::value); extern __shared__ char smem_[]; @@ -80,12 +80,14 @@ void ln_fwd_kernel(FwdParams params) { Wvec gamma[LDGS]; Wvec beta[LDGS]; + Wvec colscale[LDGS]; index_t idx = c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { gamma[it].load_from(params.gamma, idx); beta[it].load_from(params.beta, idx); + if (Has_colscale) { colscale[it].load_from(params.colscale, idx); } idx += VEC_COLS_PER_LDG; } } @@ -109,13 +111,9 @@ void ln_fwd_kernel(FwdParams params) { // the more efficient curand_uniform4. mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p; compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val; - compute_t x_ij; - if (Has_residual) { - compute_t x1_ij = compute_t(x1.data.elt[jt]); - x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij; - } else { - x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f; - } + x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f; + if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); } + compute_t x_ij = Has_residual ? 
x0_ij + compute_t(x1.data.elt[jt]) : x0_ij; if (save_x) { x.data.elt[jt] = x_ij; } xf[it * NUM_ELTS + jt] = x_ij; if (Is_dropout) { dmask.data.elt[jt] = keep; } @@ -130,8 +128,8 @@ void ln_fwd_kernel(FwdParams params) { const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG; const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG; const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG; - // Need to convert to int, otherwise the subtraction will wrap around. auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int { + // Need to convert to int, otherwise the subtraction will wrap around. const index_t valid_partial_vecs_in_warp = std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)), int(THREADS_PER_WARP)); @@ -206,45 +204,48 @@ void launch_(LaunchParams &launch_params, const bool configure_params BYTES_PER_LDG >; bool has_residual = launch_params.params.x1 != nullptr; + bool has_colscale = launch_params.params.colscale != nullptr; bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] { BOOL_SWITCH(has_residual, HasResidualConst, [&] { - BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { - auto kernel = &ln_fwd_kernel; - if( configure_params ) { - int ctas_per_sm; - CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA; - launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::Stats::stats_t) - * 2; + BOOL_SWITCH(has_colscale, HasColscaleConst, [&] { + BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { + auto kernel = &ln_fwd_kernel; + if( configure_params ) { + int ctas_per_sm; + CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); + launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA; + launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if(Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW + * sizeof(typename Kernel_traits::Stats::stats_t) + * 2; + } + return; } - return; - } - if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = 
launch_params.params.ctas_per_col; + if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); - } + if( Kernel_traits::CTAS_PER_ROW == 1 ) { + kernel<<>>(launch_params.params); + } else { + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); + } + }); }); }); }); diff --git a/csrc/layer_norm/ln_kernel_traits.h b/csrc/layer_norm/ln_kernel_traits.h index aa855b8..77de6bf 100644 --- a/csrc/layer_norm/ln_kernel_traits.h +++ b/csrc/layer_norm/ln_kernel_traits.h @@ -38,6 +38,7 @@ template< typename output_t_, typename compute_t_, typename index_t_, + bool Has_colscale, uint32_t THREADS_PER_CTA_, uint32_t BYTES_PER_LDG_, typename Base = Kernel_traits_base; diff --git a/csrc/layer_norm/ln_utils.cuh b/csrc/layer_norm/ln_utils.cuh index 1047cef..2e089de 100644 --- a/csrc/layer_norm/ln_utils.cuh +++ b/csrc/layer_norm/ln_utils.cuh @@ -45,7 +45,7 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) { #define REGISTER_BWD_LAUNCHER( \ HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ - const bool configure_params, const bool prenorm) { \ + const bool configure_params) { \ launch_(launch_params, configure_params, prenorm); \ + BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \ } \ static BwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) diff --git a/flash_attn/ops/layer_norm.py b/flash_attn/ops/layer_norm.py index 5f37c7d..a70088b 100644 --- a/flash_attn/ops/layer_norm.py +++ b/flash_attn/ops/layer_norm.py @@ -1,11 +1,13 @@ +# Copyright (c) 2022, Tri Dao. 
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py + import torch from torch.nn import init import dropout_layer_norm -def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, epsilon, +def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32): """ Assume that arguments are contiguous """ @@ -14,133 +16,98 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, ep x1mat = x1.view((-1, hidden_size)) if x1 is not None else None rowscale = rowscale.view(-1) if rowscale is not None else None zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, x1mat, gamma, beta, rowscale, dropout_p, epsilon, None, residual_in_fp32 + x0mat, x1mat, gamma, beta, rowscale, colscale, dropout_p, epsilon, None, residual_in_fp32 ) # dmask is None if dropout_p == 0.0 # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma -def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, - has_residual): +def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, + dropout_p, has_residual): """ Assume that arguments are contiguous + dx == None means that it was a post-norm architecture + (x = drop(x0) + x1 was not returned in the fwd). + x0 must not be None if we have colscale. """ - # dmask is None if dropout_p == 0.0 hidden_size = gamma.numel() xmat = x.view((-1, hidden_size)) dzmat = dz.view(xmat.shape) + dxmat = dx.view(xmat.shape) if dx is not None else None + x0mat = x0.view((-1, hidden_size)) if x0 is not None else None rowscale = rowscale.view(-1) if rowscale is not None else None - dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_bwd( - dzmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual + colscale = colscale.view(-1) if colscale is not None else None + if colscale is not None: + assert x0 is not None, 'x0 is required to compute the gradient of colscale' + dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, + has_residual ) # dx1mat is None if not has_residual - return dx0mat, dx1mat, dgamma, dbeta + if colscale is None: + return dx0mat, dx1mat, dgamma, dbeta + else: + dcolscale = rest[0] + return dx0mat, dx1mat, dgamma, dbeta, dcolscale -def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, rowscale, - dropout_p, has_residual): - """ Assume that arguments are contiguous - """ - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dzmat = dz.view(xmat.shape) - dxmat = dx.view(xmat.shape) - rowscale = rowscale.view(-1) if rowscale is not None else None - dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_prenorm_bwd( - dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual - ) - return dx0mat, dx1mat, dgamma, dbeta - - -class DropoutAddLayerNormFN(torch.autograd.Function): +class DropoutAddLayerNormFn(torch.autograd.Function): @staticmethod - def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32, - return_dmask=False): + def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32, + prenorm=False, return_dmask=False): x0 = x0.contiguous() x1 = x1.contiguous() if x1 is not 
None else None gamma = gamma.contiguous() beta = beta.contiguous() rowscale = rowscale.contiguous() if rowscale is not None else None + colscale = colscale.contiguous() if colscale is not None else None zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( - x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32 + x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32 ) - ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale) + # Only need to save x0 if we need to compute gradient wrt colscale + x0_saved = x0 if colscale is not None else None + ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale) + ctx.prenorm = prenorm ctx.dropout_p = dropout_p ctx.has_residual = x1 is not None if not return_dmask: - return zmat.view(x0.shape) + return (zmat.view(x0.shape) if not prenorm + else (zmat.view(x0.shape), xmat.view(x0.shape))) else: dmask = (dmask.view(x0.shape) if dropout_p > 0. else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)) ctx.mark_non_differentiable(dmask) - return zmat.view(x0.shape), dmask + return ((zmat.view(x0.shape), dmask) if not prenorm + else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)) @staticmethod def backward(ctx, dz, *args): # assert dz.is_contiguous() dz = dz.contiguous() # this happens! - x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors + dx = args[0].contiguous() if ctx.prenorm else None + x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors + # x0 is None if colscale is None dropout_p = ctx.dropout_p has_residual = ctx.has_residual - dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward( - dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual + dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( + dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual ) dx0 = dx0mat.view(x.shape) dx1 = dx1mat.view(x.shape) if dx1mat is not None else None - return dx0, dx1, dgamma, dbeta, None, None, None, None, None + dcolscale = rest[0] if colscale is not None else None + return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None -class DropoutAddLayerNormPrenormFN(torch.autograd.Function): - @staticmethod - def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32, - return_dmask=False): - x0 = x0.contiguous() - x1 = x1.contiguous() if x1 is not None else None - gamma = gamma.contiguous() - beta = beta.contiguous() - rowscale = rowscale.contiguous() if rowscale is not None else None - zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( - x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32 - ) - ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale) - ctx.dropout_p = dropout_p - ctx.has_residual = x1 is not None - if not return_dmask: - return zmat.view(x0.shape), xmat.view(x0.shape) - else: - dmask = (dmask.view(x0.shape) if dropout_p > 0. - else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)) - ctx.mark_non_differentiable(dmask) - return zmat.view(x0.shape), xmat.view(x0.shape), dmask - - @staticmethod - def backward(ctx, dz, dx, *args): - # assert dz.is_contiguous() - dz = dz.contiguous() # this happens! - dx = dx.contiguous() # this happens!
- x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors - dropout_p = ctx.dropout_p - has_residual = ctx.has_residual - dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward( - dz, dx, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual - ) - dx0 = dx0mat.view(x.shape) - dx1 = dx1mat.view(x.shape) if dx1mat is not None else None - return dx0, dx1, dgamma, dbeta, None, None, None, None, None - - -def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, +def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None, prenorm=False, residual_in_fp32=False, return_dropout_mask=False): """residual_in_fp32 only has an effect if x1 is None. Otherwise residual dtype is x1.dtype. """ - args = (x0, x1, weight, bias, rowscale, dropout_p, epsilon, residual_in_fp32, - return_dropout_mask) - if not prenorm: - return DropoutAddLayerNormFN.apply(*args) - else: - return DropoutAddLayerNormPrenormFN.apply(*args) + return DropoutAddLayerNormFn.apply( + x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm, + return_dropout_mask + ) class DropoutAddLayerNorm(torch.nn.Module): diff --git a/tests/ops/test_dropout_layer_norm.py b/tests/ops/test_dropout_layer_norm.py index fa0c06c..ebc0da4 100644 --- a/tests/ops/test_dropout_layer_norm.py +++ b/tests/ops/test_dropout_layer_norm.py @@ -11,6 +11,7 @@ from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_nor is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 +@pytest.mark.parametrize('has_colscale', [True, False]) @pytest.mark.parametrize('has_rowscale', [True, False]) # @pytest.mark.parametrize('has_rowscale', [True]) @pytest.mark.parametrize('has_residual', [True, False]) @@ -26,12 +27,9 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)]) @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144]) def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype, - dropout_p, has_residual, has_rowscale): + dropout_p, has_residual, has_rowscale, has_colscale): if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: pytest.skip() # Not supported - # Backward numerical error is high, and this case isn't used - if has_rowscale and not has_residual: - pytest.skip() device = 'cuda' # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4) rtol, atol = (1e-3, 1e-4) @@ -43,6 +41,12 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w requires_grad=True) x0 = x0_pt.detach().clone().requires_grad_() x0_ref = x0_pt.detach().clone().float().requires_grad_() + if has_colscale: + colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) + colscale_pt = colscale.detach().clone().requires_grad_() + colscale_ref = colscale.detach().clone().float().requires_grad_() + else: + colscale = None if has_residual: x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True) x1 = x1_pt.detach().clone().requires_grad_() @@ -59,6 +63,9 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w rowscale = None x0_scaled_pt = x0_pt x0_scaled_ref = x0_ref + if has_colscale: + x0_scaled_pt = x0_scaled_pt * colscale_pt + x0_scaled_ref = x0_scaled_ref * colscale_ref model_pt = torch.nn.LayerNorm(hidden_size, 
device=device, dtype=weight_dtype) torch.nn.init.normal_(model_pt.weight) torch.nn.init.normal_(model_pt.bias) @@ -71,7 +78,7 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w model_ref.bias.copy_(model_pt.bias) residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p, - model.epsilon, rowscale=rowscale, + model.epsilon, rowscale=rowscale, layerscale=colscale, residual_in_fp32=residual_in_fp32, return_dropout_mask=True) assert out.dtype == input_dtype print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}') @@ -94,6 +101,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4 assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5 assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5 + if has_colscale: + assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4 @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16]) @@ -139,6 +148,7 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4 +@pytest.mark.parametrize('has_colscale', [True, False]) @pytest.mark.parametrize('has_rowscale', [True, False]) @pytest.mark.parametrize('has_residual', [True, False]) @pytest.mark.parametrize('dropout_p', [0.37, 0.0]) @@ -147,20 +157,17 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)] + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else [])) +# @pytest.mark.parametrize('has_colscale', [True]) # @pytest.mark.parametrize('has_rowscale', [False]) -# @pytest.mark.parametrize('has_residual', [True]) +# @pytest.mark.parametrize('has_residual', [False]) # @pytest.mark.parametrize('dropout_p', [0.0]) # @pytest.mark.parametrize('weight_dtype', [torch.float32]) # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)]) -# @pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120]) @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144]) def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype, - dropout_p, has_residual, has_rowscale): + dropout_p, has_residual, has_rowscale, has_colscale): if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: pytest.skip() # Not supported - # Backward numerical error is high, and this case isn't used - if has_rowscale and not has_residual: - pytest.skip() device = 'cuda' # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4) rtol, atol = (1e-3, 2e-4) @@ -172,6 +179,12 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_ requires_grad=True) x0 = x0_pt.detach().clone().requires_grad_() x0_ref = x0_pt.detach().clone().float().requires_grad_() + if has_colscale: + colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) + colscale_pt = 
colscale.detach().clone().requires_grad_() + colscale_ref = colscale.detach().clone().float().requires_grad_() + else: + colscale = None if has_residual: x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True) x1 = x1_pt.detach().clone().requires_grad_() @@ -188,6 +201,9 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_ rowscale = None x0_scaled_pt = x0_pt x0_scaled_ref = x0_ref + if has_colscale: + x0_scaled_pt = x0_scaled_pt * colscale_pt + x0_scaled_ref = x0_scaled_ref * colscale_ref model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype) model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32) model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device, @@ -199,7 +215,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_ model_ref.bias.copy_(model_pt.bias) residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p, - model.epsilon, rowscale=rowscale, prenorm=True, + model.epsilon, rowscale=rowscale, + layerscale=colscale, prenorm=True, residual_in_fp32=residual_in_fp32, return_dropout_mask=True) print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}') @@ -225,6 +242,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_ assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4 assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4 assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4 + if has_colscale: + assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4 @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
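
For reference, a minimal usage sketch of the Python API extended by this patch (the sketch itself is not part of the diff). It assumes a CUDA device and the dropout_layer_norm extension built with this change; the batch size, sequence length, hidden size, dtypes, and dropout_p below are illustrative only. The new layerscale argument carries the per-feature colscale tensor, which the forward kernel applies to x0 after dropout and before the residual add and LayerNorm, roughly out = LayerNorm(dropout(x0) * colscale + x1); with prenorm=True the merged backward kernel also consumes the gradient of that pre-norm residual, so the separate dropout_add_ln_prenorm_bwd entry point is no longer needed.

import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm

batch, seqlen, hidden = 4, 128, 1024   # hidden must be a multiple of 8 and <= 6144
device, dtype = 'cuda', torch.float16

x0 = torch.randn(batch, seqlen, hidden, device=device, dtype=dtype, requires_grad=True)
x1 = torch.randn_like(x0, requires_grad=True)            # residual branch
weight = torch.ones(hidden, device=device, dtype=dtype, requires_grad=True)
bias = torch.zeros(hidden, device=device, dtype=dtype, requires_grad=True)
colscale = torch.randn(hidden, device=device, dtype=dtype, requires_grad=True)

# prenorm=True returns both the normalized output and the pre-norm residual
# x = dropout(x0) * colscale + x1, computed in a single fused kernel.
out, residual = dropout_add_layer_norm(
    x0, x1, weight, bias, dropout_p=0.1, epsilon=1e-5,
    layerscale=colscale, prenorm=True, residual_in_fp32=False,
)
(out.float().sum() + residual.float().sum()).backward()
# colscale.grad is produced by the new dcolscale / dcolscale_part reduction path.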