From 5db330519a9fe8037ba5eb2b67b9dd1848189342 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 12 Dec 2022 22:16:14 -0800 Subject: [PATCH] [LayerNorm] Support taking subset of input or subset of output --- csrc/layer_norm/ln.h | 3 + csrc/layer_norm/ln_api.cpp | 86 ++++++-- csrc/layer_norm/ln_bwd_kernels.cuh | 233 ++++++++++++---------- csrc/layer_norm/ln_fwd_kernels.cuh | 82 ++++---- flash_attn/ops/layer_norm.py | 123 +++++++++++- tests/ops/test_dropout_layer_norm.py | 284 +++++++++++++++++++++++++-- 6 files changed, 643 insertions(+), 168 deletions(-) diff --git a/csrc/layer_norm/ln.h b/csrc/layer_norm/ln.h index 1891255..4ecca02 100644 --- a/csrc/layer_norm/ln.h +++ b/csrc/layer_norm/ln.h @@ -66,11 +66,14 @@ struct ParamsBase { void *gamma; void *rowscale; void *colscale; + void *x0_subset; + void *z_subset; float inverse_cols; float dropout_keep_p; float dropout_scale; + float rowscale_const; // Multi-CTA workspace in gmem. void *workspace; diff --git a/csrc/layer_norm/ln_api.cpp b/csrc/layer_norm/ln_api.cpp index 0720b1a..8ab3fa0 100644 --- a/csrc/layer_norm/ln_api.cpp +++ b/csrc/layer_norm/ln_api.cpp @@ -84,9 +84,13 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: const at::Tensor &gamma, // hidden_size const at::Tensor &beta, // hidden_size c10::optional &rowscale_, // BxS - c10::optional &colscale_, // BxS + c10::optional &colscale_, // hidden_size + c10::optional &x0_subset_, // BxS + c10::optional &z_subset_, // BxS const float dropout_p, const float epsilon, + const float rowscale_const, + const int64_t z_numrows, c10::optional gen_, bool residual_in_fp32 ) { @@ -99,14 +103,19 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: auto ctype = torch::kFloat32; auto mtype = torch::kUInt8; - TORCH_CHECK(beta.scalar_type() == wtype); + TORCH_CHECK(beta.dtype() == wtype); TORCH_CHECK(x0.is_cuda()) TORCH_CHECK(gamma.is_cuda()) TORCH_CHECK(beta.is_cuda()) TORCH_CHECK(x0.is_contiguous()); - auto sizes = x0.sizes(); + // c10::IntArrayRef does not own the storage, so we need to construct a vector. + // Otherwise just constructing IntArrayRef({blah}) will cause unintialized memory because + // blah is then deallocated. + std::vector sizes_vec {!x0_subset_.has_value() ? 
x0.size(0) : x0_subset_.value().size(0), x0.size(1)}; + auto sizes = c10::IntArrayRef(sizes_vec); + TORCH_CHECK(x0.dim() == 2); TORCH_CHECK(sizes.size() == 2); const int rows = sizes[0]; @@ -124,7 +133,7 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: auto rowscale = rowscale_.value(); TORCH_CHECK(rowscale.is_cuda()) TORCH_CHECK(rowscale.is_contiguous()); - TORCH_CHECK(rowscale.sizes() == std::vector{rows}); + TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows}); TORCH_CHECK(rowscale.dtype() == itype); } @@ -132,10 +141,25 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: auto colscale = colscale_.value(); TORCH_CHECK(colscale.is_cuda()) TORCH_CHECK(colscale.is_contiguous()); - TORCH_CHECK(colscale.sizes() == std::vector{cols}); + TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols}); TORCH_CHECK(colscale.dtype() == wtype); } + if (x0_subset_.has_value()) { + auto x0_subset = x0_subset_.value(); + TORCH_CHECK(x0_subset.is_cuda()) + TORCH_CHECK(x0_subset.is_contiguous()); + TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows}); + TORCH_CHECK(x0_subset.dtype() == torch::kInt32); + + TORCH_CHECK(z_subset_.has_value()); + auto z_subset = z_subset_.value(); + TORCH_CHECK(z_subset.is_cuda()); + TORCH_CHECK(z_subset.is_contiguous()); + TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows}); + TORCH_CHECK(z_subset.dtype() == torch::kInt32); + } + TORCH_CHECK(gamma.sizes() == beta.sizes()); TORCH_CHECK(hidden_size == cols); TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144)); @@ -144,12 +168,12 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: auto opts = x0.options(); - bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || (itype != rtype); + bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype); at::Tensor x; if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); } at::Tensor dmask; - if (dropout_p > 0.f) { dmask = torch::empty(sizes, opts.dtype(mtype)); }; - auto z = torch::empty(sizes, opts.dtype(otype)); + if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), opts.dtype(mtype)); }; + auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype)); auto mu = torch::empty({ rows }, opts.dtype(ctype)); auto rsigma = torch::empty({ rows }, opts.dtype(ctype)); @@ -163,6 +187,8 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr; launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr; + launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr; + launch_params.params.z_subset = z_subset_.has_value() ? 
z_subset_.value().data_ptr() : nullptr; auto gen = at::get_generator_or_default( gen_, at::cuda::detail::getDefaultCUDAGenerator()); @@ -192,6 +218,7 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: params.epsilon = epsilon; params.dropout_scale = 1.f / (1.f - dropout_p); params.inverse_cols = 1.f / float(params.cols); + params.rowscale_const = rowscale_const; if (dropout_p > 0.f) { // number of times random will be generated per thread, to offset philox counter in thc random @@ -230,8 +257,12 @@ std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd const at::Tensor &rsigma, // BxS, FP32! const at::Tensor &gamma, // hidden_size c10::optional &rowscale_, // BxS - c10::optional &colscale_, // BxS + c10::optional &colscale_, // hidden_size + c10::optional &x0_subset_, // BxS + c10::optional &z_subset_, // BxS const float dropout_p, + const float rowscale_const, + const int64_t x0_numrows, const bool has_residual ) { @@ -259,9 +290,16 @@ std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd auto sizes = x.sizes(); TORCH_CHECK(sizes.size() == 2); - TORCH_CHECK(dz.sizes() == sizes); auto rows = sizes[0]; auto cols = sizes[1]; + TORCH_CHECK(dz.dim() == 2); + TORCH_CHECK(dz.size(1) == cols); + + // c10::IntArrayRef does not own the storage, so we need to construct a vector. + // Otherwise just constructing IntArrayRef({blah}) will cause unintialized memory because + // blah is then deallocated. + std::vector x0_sizes_vec {!x0_subset_.has_value() ? rows : x0_numrows, cols}; + auto x0_sizes = c10::IntArrayRef(x0_sizes_vec); if (dx_.has_value()) { auto dx = dx_.value(); @@ -276,14 +314,14 @@ std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd TORCH_CHECK(dmask.dtype() == mtype); TORCH_CHECK(dmask.is_cuda()); TORCH_CHECK(dmask.is_contiguous()); - TORCH_CHECK(dmask.sizes() == sizes); + TORCH_CHECK(dmask.sizes() == x0_sizes); } if (rowscale_.has_value()) { auto rowscale = rowscale_.value(); TORCH_CHECK(rowscale.is_cuda()) TORCH_CHECK(rowscale.is_contiguous()); - TORCH_CHECK(rowscale.sizes() == std::vector{rows}); + TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows}); TORCH_CHECK(rowscale.dtype() == itype); } @@ -291,17 +329,32 @@ std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd auto colscale = colscale_.value(); TORCH_CHECK(colscale.is_cuda()) TORCH_CHECK(colscale.is_contiguous()); - TORCH_CHECK(colscale.sizes() == std::vector{cols}); + TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols}); TORCH_CHECK(colscale.dtype() == wtype); TORCH_CHECK(x0_.has_value()); auto x0 = x0_.value(); TORCH_CHECK(x0.is_cuda()) TORCH_CHECK(x0.is_contiguous()); - TORCH_CHECK(x0.sizes() == sizes); + TORCH_CHECK(x0.sizes() == x0_sizes); TORCH_CHECK(x0.dtype() == itype); } + if (x0_subset_.has_value()) { + auto x0_subset = x0_subset_.value(); + TORCH_CHECK(x0_subset.is_cuda()) + TORCH_CHECK(x0_subset.is_contiguous()); + TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows}); + TORCH_CHECK(x0_subset.dtype() == torch::kInt32); + + TORCH_CHECK(z_subset_.has_value()); + auto z_subset = z_subset_.value(); + TORCH_CHECK(z_subset.is_cuda()); + TORCH_CHECK(z_subset.is_contiguous()); + TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows}); + TORCH_CHECK(z_subset.dtype() == torch::kInt32); + } + auto hidden_size = gamma.numel(); TORCH_CHECK(hidden_size == cols); TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144)); @@ -313,7 +366,7 @@ std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd auto opts = x.options(); - auto dx0 = 
torch::empty_like(x, opts.dtype(itype)); + auto dx0 = torch::empty(x0_sizes, opts.dtype(itype)); at::Tensor dx1; if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); } auto dgamma = torch::empty_like(gamma); @@ -331,6 +384,8 @@ std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr; launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr; + launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr; + launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr; auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024); @@ -366,6 +421,7 @@ std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr; params.dropout_scale = 1.f / (1.f - dropout_p); params.inverse_cols = 1.f / float(params.cols); + params.rowscale_const = rowscale_const; if( launch_params.barrier_size > 0 ) { // TODO Any way to avoid this? diff --git a/csrc/layer_norm/ln_bwd_kernels.cuh b/csrc/layer_norm/ln_bwd_kernels.cuh index 7d3193d..95be982 100644 --- a/csrc/layer_norm/ln_bwd_kernels.cuh +++ b/csrc/layer_norm/ln_bwd_kernels.cuh @@ -7,7 +7,7 @@ namespace layer_norm { -template +template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_kernel(layer_norm::BwdParams params) { @@ -37,6 +37,9 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { extern __shared__ char smem_[]; + const bool has_residual = params.dx1 != nullptr; + const bool prenorm = params.dx != nullptr; + const index_t tidx = threadIdx.x; const index_t bidn = blockIdx.x % CTAS_PER_ROW; const index_t bidm = blockIdx.x / CTAS_PER_ROW; @@ -51,6 +54,10 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); + const input_t *rowscale = static_cast(params.rowscale); + const index_t *x0_subset = static_cast(params.x0_subset); + const index_t *z_subset = static_cast(params.z_subset); + Cvec dzy_sum[LDGS]; Cvec dz_sum[LDGS]; Cvec dcolscale_sum[LDGS]; @@ -87,40 +94,62 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { const compute_t mu_r = static_cast(params.mu)[row]; const compute_t rs_r = static_cast(params.rs)[row]; - const compute_t rowscale_val = - params.rowscale == nullptr ? 1.0f : compute_t(static_cast(params.rowscale)[row]); + const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const; + const int row_z = !Has_subset ? row + 1 : z_subset[row]; + const int row_x0 = !Has_subset ? 
row + 1 : x0_subset[row]; + const bool load_dz = !Has_subset || row_z > 0; + const bool save_dx0 = !Has_subset || row_x0 > 0; Mvec dmask[LDGS]; Rvec dx[LDGS]; compute_t dy[LDGS * NUM_ELTS]; compute_t y[LDGS * NUM_ELTS]; compute_t mdy_local = 0.f; compute_t mdyy_local = 0.f; - index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - Rvec x; - Ovec dz; - dz.load_from(params.dz, idx); - if (Prenorm) { dx[it].load_from(params.dx, idx); } - x.load_from(params.x, idx); - if (Is_dropout) { dmask[it].load_from(params.dmask, idx); } - idx += Ktraits::VEC_COLS_PER_LDG; - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t x_tmp = x.data.elt[jt]; - compute_t y_tmp = rs_r * (x_tmp - mu_r); - compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]); - compute_t dz_tmp = dz.data.elt[jt]; + // If dz is not loaded, then dy should be 0 and we don't care about the value of y. + if (load_dz) { + index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; + index_t idx_z = !Has_subset ? idx_x : (load_dz ? (row_z - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); + index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Rvec x; + Ovec dz; + dz.load_from(params.dz, !Has_subset ? idx_x : idx_z); + if (prenorm) { dx[it].load_from(params.dx, idx_x); } + x.load_from(params.x, idx_x); + if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); } + idx_x += Ktraits::VEC_COLS_PER_LDG; + idx_z += Ktraits::VEC_COLS_PER_LDG; + idx_x0 += Ktraits::VEC_COLS_PER_LDG; + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t x_tmp = x.data.elt[jt]; + compute_t y_tmp = rs_r * (x_tmp - mu_r); + compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]); + compute_t dz_tmp = dz.data.elt[jt]; - mdy_local += dy_tmp; - mdyy_local += dy_tmp * y_tmp; + mdy_local += dy_tmp; + mdyy_local += dy_tmp * y_tmp; - dy[it * NUM_ELTS + jt] = dy_tmp; - y[it * NUM_ELTS + jt] = y_tmp; + dy[it * NUM_ELTS + jt] = dy_tmp; + y[it * NUM_ELTS + jt] = y_tmp; - dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; - dz_sum[it].data.elt[jt] += dz_tmp; + dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; + dz_sum[it].data.elt[jt] += dz_tmp; + } + } + } + } else { + index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; + index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + if (prenorm) { dx[it].load_from(params.dx, idx_x); } + if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); } + idx_x += Ktraits::VEC_COLS_PER_LDG; + idx_x0 += Ktraits::VEC_COLS_PER_LDG; } } } @@ -129,42 +158,51 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { mdy_local = layer_norm::Get<0>::of(result) * params.inverse_cols; mdyy_local = layer_norm::Get<1>::of(result) * params.inverse_cols; - idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; + index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; + index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? 
(row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { Ivec dx0; Rvec dx1; Ivec x0; - if (Has_colscale) { x0.load_from(params.x0, idx); } + if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); } #pragma unroll for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t dy_tmp = dy[it * NUM_ELTS + jt]; - compute_t y_tmp = y[it * NUM_ELTS + jt]; - compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); - compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp; - if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; } - compute_t dx0_tmp_res = dx_tmp_res * rowscale_val; - if (Is_dropout) { - dx0_tmp_res *= params.dropout_scale; - if (Has_colscale) { - dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f; - dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f; - } else { - dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f; - } + compute_t dx_tmp_res; + if (load_dz) { + compute_t dy_tmp = dy[it * NUM_ELTS + jt]; + compute_t y_tmp = y[it * NUM_ELTS + jt]; + compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); + dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp; } else { - if (Has_colscale) { - dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]); - dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]); + dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f; + } + if (has_residual) { dx1.data.elt[jt] = dx_tmp_res; } + if (save_dx0) { + compute_t dx0_tmp_res = dx_tmp_res * rowscale_val; + if (Is_dropout) { + dx0_tmp_res *= params.dropout_scale; + if (Has_colscale) { + dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f; + dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f; + } else { + dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f; + } } else { - dx0.data.elt[jt] = dx0_tmp_res; + if (Has_colscale) { + dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]); + dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]); + } else { + dx0.data.elt[jt] = dx0_tmp_res; + } } } } - if (Has_residual) { dx1.store_to(params.dx1, idx); } - dx0.store_to(params.dx0, idx); - idx += Ktraits::VEC_COLS_PER_LDG; + if (has_residual) { dx1.store_to(params.dx1, idx_x); } + if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? 
idx_x : idx_x0); } + idx_x += Ktraits::VEC_COLS_PER_LDG; + idx_x0 += Ktraits::VEC_COLS_PER_LDG; } } @@ -434,64 +472,61 @@ void launch_(LaunchParams &launch_params, const bool configure_params WARPS_N, BYTES_PER_LDG_MAIN >; - bool prenorm = launch_params.params.dx != nullptr; bool is_dropout = launch_params.params.dropout_keep_p < 1.f; - bool has_residual = launch_params.params.dx1 != nullptr; bool has_colscale = launch_params.params.colscale != nullptr; + bool has_subset = launch_params.params.x0_subset != nullptr; bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; - BOOL_SWITCH(prenorm, PrenormConst, [&] { - BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { - BOOL_SWITCH(has_residual, HasResidualConst, [&] { - BOOL_SWITCH(has_colscale, HasColscaleConst, [&] { - BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { - auto kernel = &ln_bwd_kernel; - if( configure_params ) { - int ctas_per_sm; - CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::reduce_t) - * 2; - } - return; + BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { + BOOL_SWITCH(has_colscale, HasColscaleConst, [&] { + BOOL_SWITCH(has_subset, HasSubsetConst, [&] { + BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { + auto kernel = &ln_bwd_kernel; + if( configure_params ) { + int ctas_per_sm; + CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); + launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if(Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW + * sizeof(typename Kernel_traits::reduce_t) + * 2; } + return; + } - if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; + if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); - } + if( Kernel_traits::CTAS_PER_ROW == 1 ) { + kernel<<>>(launch_params.params); + } else { + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + 
void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); + } - using Kernel_traits_f = layer_norm::Kernel_traits_finalize; + using Kernel_traits_f = layer_norm::Kernel_traits_finalize; - auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; - kernel_f<<>>(launch_params.params); - }); + auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; + kernel_f<<>>(launch_params.params); }); }); }); diff --git a/csrc/layer_norm/ln_fwd_kernels.cuh b/csrc/layer_norm/ln_fwd_kernels.cuh index 0c09a3b..4dc16eb 100644 --- a/csrc/layer_norm/ln_fwd_kernels.cuh +++ b/csrc/layer_norm/ln_fwd_kernels.cuh @@ -16,7 +16,7 @@ namespace layer_norm { -template +template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(FwdParams params) { @@ -46,7 +46,8 @@ void ln_fwd_kernel(FwdParams params) { using Stats = typename Ktraits::Stats; using stats_t = typename Stats::stats_t; - const bool save_x = Has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || !(std::is_same::value); + const bool has_residual = params.x1 != nullptr; + const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same::value); extern __shared__ char smem_[]; @@ -67,6 +68,8 @@ void ln_fwd_kernel(FwdParams params) { compute_t *rs_ptr = static_cast(params.rs); const input_t *rowscale = static_cast(params.rowscale); + const index_t *x0_subset = static_cast(params.x0_subset); + const index_t *z_subset = static_cast(params.z_subset); // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu curandStatePhilox4_32_10_t state; @@ -93,8 +96,12 @@ void ln_fwd_kernel(FwdParams params) { } for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - const compute_t rowscale_val = params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row]); - index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; + const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const; + const int row_x0 = !Has_subset ? row + 1 : x0_subset[row]; + const int row_z = !Has_subset ? row + 1 : z_subset[row]; + const bool load_x0 = !Has_subset || row_x0 > 0; + index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; + index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); compute_t xf[LDGS * NUM_ELTS]; #pragma unroll for( int it = 0; it < LDGS; it++ ) { @@ -103,24 +110,30 @@ void ln_fwd_kernel(FwdParams params) { Rvec x1; Rvec x; Mvec dmask; - x0.load_from(params.x0, idx); - if (Has_residual) { x1.load_from(params.x1, idx); } + if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); } + if (has_residual) { x1.load_from(params.x1, idx_x); } #pragma unroll for( int jt = 0; jt < NUM_ELTS; jt++ ) { // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use // the more efficient curand_uniform4. - mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p; - compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val; - x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f; - if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); } - compute_t x_ij = Has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij; + compute_t x_ij; + if (load_x0) { + mask_t keep = !Is_dropout ? 
true : curand_uniform(&state) <= params.dropout_keep_p; + if (Is_dropout) { dmask.data.elt[jt] = keep; } + compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val; + x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f; + if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); } + x_ij = has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij; + } else { + x_ij = has_residual ? compute_t(x1.data.elt[jt]) : 0.f; + } if (save_x) { x.data.elt[jt] = x_ij; } xf[it * NUM_ELTS + jt] = x_ij; - if (Is_dropout) { dmask.data.elt[jt] = keep; } } - if (save_x) { x.store_to(params.x, idx); } - if (Is_dropout) { dmask.store_to(params.dmask, idx); } - idx += VEC_COLS_PER_LDG; + if (save_x) { x.store_to(params.x, idx_x); } + if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); } + idx_x += VEC_COLS_PER_LDG; + idx_x0 += VEC_COLS_PER_LDG; } } @@ -152,20 +165,23 @@ void ln_fwd_kernel(FwdParams params) { rs_ptr[row] = rs; } - idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - if (Is_even_cols || (it < num_valid_ldgs)) { - Ovec z; - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - mu)); - compute_t g_ij = gamma[it].data.elt[jt]; - compute_t b_ij = beta[it].data.elt[jt]; - z.data.elt[jt] = output_t(g_ij * y_ij + b_ij); + const bool save_z = !Has_subset || row_z > 0; + if (save_z) { + index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Ovec z; + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - mu)); + compute_t g_ij = gamma[it].data.elt[jt]; + compute_t b_ij = beta[it].data.elt[jt]; + z.data.elt[jt] = output_t(g_ij * y_ij + b_ij); + } + z.store_to(params.z, idx_z); + idx_z += VEC_COLS_PER_LDG; } - z.store_to(params.z, idx); - idx += VEC_COLS_PER_LDG; } } @@ -203,14 +219,14 @@ void launch_(LaunchParams &launch_params, const bool configure_params WARPS_N, BYTES_PER_LDG >; - bool has_residual = launch_params.params.x1 != nullptr; bool has_colscale = launch_params.params.colscale != nullptr; + bool has_subset = launch_params.params.x0_subset != nullptr; bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] { - BOOL_SWITCH(has_residual, HasResidualConst, [&] { - BOOL_SWITCH(has_colscale, HasColscaleConst, [&] { - BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { - auto kernel = &ln_fwd_kernel; + BOOL_SWITCH(has_colscale, HasColscaleConst, [&] { + BOOL_SWITCH(has_subset, HasSubsetConst, [&] { + BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { + auto kernel = &ln_fwd_kernel; if( configure_params ) { int ctas_per_sm; CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( diff --git a/flash_attn/ops/layer_norm.py b/flash_attn/ops/layer_norm.py index a70088b..bd81ca0 100644 --- a/flash_attn/ops/layer_norm.py +++ b/flash_attn/ops/layer_norm.py @@ -16,7 +16,8 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dro x1mat = x1.view((-1, hidden_size)) if x1 is not None else None rowscale = rowscale.view(-1) if rowscale is not None else None zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( - x0mat, x1mat, gamma, beta, rowscale, colscale, dropout_p, epsilon, None, residual_in_fp32 + 
x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon, + 1.0, 0, None, residual_in_fp32 ) # dmask is None if dropout_p == 0.0 # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype @@ -36,12 +37,59 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro dxmat = dx.view(xmat.shape) if dx is not None else None x0mat = x0.view((-1, hidden_size)) if x0 is not None else None rowscale = rowscale.view(-1) if rowscale is not None else None - colscale = colscale.view(-1) if colscale is not None else None if colscale is not None: assert x0 is not None, 'x0 is required to compute the gradient of colscale' dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( - dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, - has_residual + dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None, + dropout_p, 1.0, 0, has_residual + ) + # dx1mat is None if not has_residual + if colscale is None: + return dx0mat, dx1mat, dgamma, dbeta + else: + dcolscale = rest[0] + return dx0mat, dx1mat, dgamma, dbeta, dcolscale + + +def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset, + dropout_p, epsilon, rowscale_const, out_numrows, + residual_in_fp32): + """ Assume that arguments are contiguous + """ + hidden_size = gamma.numel() + x0mat = x0.view((-1, hidden_size)) + x1mat = x1.view((-1, hidden_size)) if x1 is not None else None + x0_subset = x0_subset.view(-1) if x0_subset is not None else None + out_subset = out_subset.view(-1) if out_subset is not None else None + zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( + x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon, + rowscale_const, out_numrows, None, residual_in_fp32 + ) + # dmask is None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype + return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma + + +def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, + x0_subset, out_subset, dropout_p, rowscale_const, + x0_numrows, has_residual): + """ Assume that arguments are contiguous + dx == None means that it was a post-norm architecture + (x = drop(x0) + x1 was not returned in the fwd). + x0 must not be None if we have colscale. 
+ """ + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dzmat = dz.view(-1, hidden_size) + dxmat = dx.view(xmat.shape) if dx is not None else None + x0mat = x0.view((-1, hidden_size)) if x0 is not None else None + x0_subset = x0_subset.view(-1) if x0_subset is not None else None + out_subset = out_subset.view(-1) if out_subset is not None else None + if colscale is not None: + assert x0 is not None, 'x0 is required to compute the gradient of colscale' + dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset, + dropout_p, rowscale_const, x0_numrows, has_residual ) # dx1mat is None if not has_residual if colscale is None: @@ -98,6 +146,60 @@ class DropoutAddLayerNormFn(torch.autograd.Function): return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None +class DropoutAddLayerNormSubsetFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, + rowscale_const, out_numrows, residual_in_fp32, prenorm=False, return_dmask=False): + x0 = x0.contiguous() + x1 = x1.contiguous() if x1 is not None else None + gamma = gamma.contiguous() + beta = beta.contiguous() + colscale = colscale.contiguous() if colscale is not None else None + zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( + x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, + rowscale_const, out_numrows, residual_in_fp32 + ) + # Only need to save x0 if we need to compute gradient wrt colscale + x0_saved = x0 if colscale is not None else None + x_shape = (-1, *x0.shape[1:]) + ctx.save_for_backward(xmat.view(x_shape), x0, dmask, gamma, mu, rsigma, colscale, + x0_subset, out_subset) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.rowscale_const = rowscale_const + ctx.x0_numrows = x0.shape[:-1].numel() + ctx.has_residual = x1 is not None + z_shape = (-1, *x0.shape[1:]) + if not return_dmask: + return (zmat.view(z_shape) if not prenorm + else (zmat.view(z_shape), xmat.view(x0.shape))) + else: + z = zmat.view(z_shape) + dmask = (dmask.view(x0.shape) if dropout_p > 0. + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)) + ctx.mark_non_differentiable(dmask) + return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)) + + @staticmethod + def backward(ctx, dz, *args): + # assert dz.is_contiguous() + dz = dz.contiguous() # this happens! 
+ dx = args[0].contiguous() if ctx.prenorm else None + x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors + # x0 is None if colscale is None + dropout_p = ctx.dropout_p + has_residual = ctx.has_residual + dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward( + dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p, + ctx.rowscale_const, ctx.x0_numrows, has_residual + ) + dx0 = dx0mat.view(-1, *x.shape[1:]) + dx1 = dx1mat.view(x.shape) if dx1mat is not None else None + dcolscale = rest[0] if colscale is not None else None + return (dx0, dx1, dgamma, dbeta, dcolscale, None, None, None, None, None, None, None, + None, None) + + def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None, prenorm=False, residual_in_fp32=False, return_dropout_mask=False): @@ -110,6 +212,19 @@ def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=No ) +def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None, + x0_subset=None, out_subset=None, rowscale_const=1.0, + out_numrows=0, prenorm=False, residual_in_fp32=False, + return_dropout_mask=False): + """residual_in_fp32 only has an effect if x1 is None. + Otherwise residual dtype is x1.dtype. + """ + return DropoutAddLayerNormSubsetFn.apply( + x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon, + rowscale_const, out_numrows, residual_in_fp32, prenorm, return_dropout_mask + ) + + class DropoutAddLayerNorm(torch.nn.Module): def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False, device=None, dtype=None): diff --git a/tests/ops/test_dropout_layer_norm.py b/tests/ops/test_dropout_layer_norm.py index ebc0da4..9cd7e56 100644 --- a/tests/ops/test_dropout_layer_norm.py +++ b/tests/ops/test_dropout_layer_norm.py @@ -4,9 +4,10 @@ import torch import torch.nn.functional as F import pytest -from einops import rearrange +from einops import rearrange, repeat from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm +from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 @@ -130,6 +131,8 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh x1 = x1_pt.detach().clone().requires_grad_() x1_ref = x1_pt.detach().clone().float().requires_grad_() model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype) + torch.nn.init.normal_(model_pt.weight) + torch.nn.init.normal_(model_pt.bias) model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype) model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32) with torch.no_grad(): @@ -148,22 +151,23 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4 -@pytest.mark.parametrize('has_colscale', [True, False]) -@pytest.mark.parametrize('has_rowscale', [True, False]) -@pytest.mark.parametrize('has_residual', [True, False]) -@pytest.mark.parametrize('dropout_p', [0.37, 0.0]) -@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16]) -@pytest.mark.parametrize('input_dtype,residual_dtype', - [(torch.float16, torch.float16), (torch.float16, torch.float32), - (torch.float32, torch.float32)] - + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else [])) -# 
@pytest.mark.parametrize('has_colscale', [True]) -# @pytest.mark.parametrize('has_rowscale', [False]) -# @pytest.mark.parametrize('has_residual', [False]) -# @pytest.mark.parametrize('dropout_p', [0.0]) -# @pytest.mark.parametrize('weight_dtype', [torch.float32]) -# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)]) -@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144]) +# @pytest.mark.parametrize('has_colscale', [True, False]) +# @pytest.mark.parametrize('has_rowscale', [True, False]) +# @pytest.mark.parametrize('has_residual', [True, False]) +# @pytest.mark.parametrize('dropout_p', [0.37, 0.0]) +# @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16]) +# @pytest.mark.parametrize('input_dtype,residual_dtype', +# [(torch.float16, torch.float16), (torch.float16, torch.float32), +# (torch.float32, torch.float32)] +# + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else [])) +@pytest.mark.parametrize('has_colscale', [True]) +@pytest.mark.parametrize('has_rowscale', [False]) +@pytest.mark.parametrize('has_residual', [True]) +@pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize('weight_dtype', [torch.float32]) +@pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)]) +# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144]) +@pytest.mark.parametrize('hidden_size', [256]) def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_rowscale, has_colscale): if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: @@ -205,6 +209,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_ x0_scaled_pt = x0_scaled_pt * colscale_pt x0_scaled_ref = x0_scaled_ref * colscale_ref model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype) + torch.nn.init.normal_(model_pt.weight) + torch.nn.init.normal_(model_pt.bias) model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32) model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype) @@ -271,6 +277,8 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp x1 = x1_pt.detach().clone().requires_grad_() x1_ref = x1_pt.detach().clone().float().requires_grad_() model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype) + torch.nn.init.normal_(model_pt.weight) + torch.nn.init.normal_(model_pt.bias) model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype) model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32) @@ -289,3 +297,245 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp out_ref = model_ref(residual_ref) assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4 assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4 + + +@pytest.mark.parametrize('has_colscale', [True, False]) +@pytest.mark.parametrize('has_residual', [True, False]) +@pytest.mark.parametrize('dropout_p', [0.37, 0.0]) +@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16]) +@pytest.mark.parametrize('input_dtype,residual_dtype', + [(torch.float16, torch.float16), (torch.float16, torch.float32), + 
(torch.float32, torch.float32)] + + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else [])) +# @pytest.mark.parametrize('has_colscale', [True]) +# @pytest.mark.parametrize('has_residual', [True]) +# @pytest.mark.parametrize('dropout_p', [0.0]) +# @pytest.mark.parametrize('weight_dtype', [torch.float32]) +# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)]) +@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144]) +# @pytest.mark.parametrize('hidden_size', [256]) +def test_dropout_layer_norm_subset_training( + hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, + has_residual, has_colscale): + if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: + pytest.skip() # Not supported + device = 'cuda' + # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4) + rtol, atol = (1e-3, 2e-4) + # set seed + torch.random.manual_seed(0) + batch_size = 8 + seqlen = 512 + drop_path_rate = 0.4 + drop_path_scale = 1 / (1 - drop_path_rate) + def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device): + # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync + mask_batch = torch.rand(batch_size) < 1 - drop_path_rate + numrows = (mask_batch).sum().item() * seqlen + mask_batch = mask_batch.to(device=device, non_blocking=True) + mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen) + subset = torch.cumsum(mask_batch_seqlen, dim=0, + dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0) + return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size) + + x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(batch_size, seqlen, + drop_path_rate, device) + out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen, + drop_path_rate, device) + + x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, + requires_grad=True) + x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_() + x0_ref = x0_pt.detach().clone().float().requires_grad_() + if has_colscale: + colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) + colscale_pt = colscale.detach().clone().requires_grad_() + colscale_ref = colscale.detach().clone().float().requires_grad_() + else: + colscale = None + if has_residual: + x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True) + x1 = x1_pt.detach().clone().requires_grad_() + x1_ref = x1_pt.detach().clone().float().requires_grad_() + else: + x1 = None + + if has_colscale: + x0_scaled_pt = x0_pt * colscale_pt + x0_scaled_ref = x0_ref * colscale_ref + else: + x0_scaled_pt = x0_pt + x0_scaled_ref = x0_ref + + model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype) + torch.nn.init.normal_(model_pt.weight) + torch.nn.init.normal_(model_pt.bias) + model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32) + model = DropoutAddLayerNorm(hidden_size, prenorm=False, p=dropout_p, device=device, + dtype=weight_dtype) + with torch.no_grad(): + model.weight.copy_(model_pt.weight) + model.bias.copy_(model_pt.bias) + model_ref.weight.copy_(model_pt.weight) + model_ref.bias.copy_(model_pt.bias) + + residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 + out, dmask = dropout_add_layer_norm_subset( + x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale, + 
x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale, + out_numrows = out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32, + return_dropout_mask=True) + print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}') + + x0_scaled_pt = x0_scaled_pt.masked_fill( + repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0 + ) * drop_path_scale + x0_scaled_ref = x0_scaled_ref.masked_fill( + repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0 + ) * drop_path_scale + dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8) + dmask_expanded[x0_mask_batch] = dmask + if has_residual: + residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype) + residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref + else: + residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype) + residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch] + out_ref = model_ref(residual_ref)[out_mask_batch] + assert out.dtype == input_dtype + assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4 + + g = torch.randn_like(out) / batch_size + out_pt.backward(g) + out.backward(g) + out_ref.backward(g) + assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4 + if has_residual: + assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4 + assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4 + assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4 + if has_colscale: + assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4 + + +@pytest.mark.parametrize('has_colscale', [True, False]) +@pytest.mark.parametrize('has_residual', [True, False]) +@pytest.mark.parametrize('dropout_p', [0.37, 0.0]) +@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16]) +@pytest.mark.parametrize('input_dtype,residual_dtype', + [(torch.float16, torch.float16), (torch.float16, torch.float32), + (torch.float32, torch.float32)] + + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else [])) +# @pytest.mark.parametrize('has_colscale', [True]) +# @pytest.mark.parametrize('has_residual', [True]) +# @pytest.mark.parametrize('dropout_p', [0.0]) +# @pytest.mark.parametrize('weight_dtype', [torch.float32]) +# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)]) +@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144]) +# @pytest.mark.parametrize('hidden_size', [256]) +def test_dropout_layer_norm_subset_prenorm_training( + hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, + has_residual, has_colscale): + if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: + pytest.skip() # Not supported + device = 'cuda' + # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4) + rtol, atol = (1e-3, 2e-4) + # set seed + torch.random.manual_seed(0) + batch_size = 8 + seqlen = 512 + drop_path_rate = 0.4 + drop_path_scale = 1 / (1 - 
drop_path_rate) + def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device): + # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync + mask_batch = torch.rand(batch_size) < 1 - drop_path_rate + numrows = (mask_batch).sum().item() * seqlen + mask_batch = mask_batch.to(device=device, non_blocking=True) + mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen) + subset = torch.cumsum(mask_batch_seqlen, dim=0, + dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0) + return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size) + + x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(batch_size, seqlen, + drop_path_rate, device) + out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen, + drop_path_rate, device) + + x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, + requires_grad=True) + x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_() + x0_ref = x0_pt.detach().clone().float().requires_grad_() + if has_colscale: + colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) + colscale_pt = colscale.detach().clone().requires_grad_() + colscale_ref = colscale.detach().clone().float().requires_grad_() + else: + colscale = None + if has_residual: + x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True) + x1 = x1_pt.detach().clone().requires_grad_() + x1_ref = x1_pt.detach().clone().float().requires_grad_() + else: + x1 = None + + if has_colscale: + x0_scaled_pt = x0_pt * colscale_pt + x0_scaled_ref = x0_ref * colscale_ref + else: + x0_scaled_pt = x0_pt + x0_scaled_ref = x0_ref + + model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype) + torch.nn.init.normal_(model_pt.weight) + torch.nn.init.normal_(model_pt.bias) + model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32) + model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device, + dtype=weight_dtype) + with torch.no_grad(): + model.weight.copy_(model_pt.weight) + model.bias.copy_(model_pt.bias) + model_ref.weight.copy_(model_pt.weight) + model_ref.bias.copy_(model_pt.bias) + + residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 + out, residual, dmask = dropout_add_layer_norm_subset( + x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale, + x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale, + out_numrows = out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32, + return_dropout_mask=True) + print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}') + + x0_scaled_pt = x0_scaled_pt.masked_fill( + repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0 + ) * drop_path_scale + x0_scaled_ref = x0_scaled_ref.masked_fill( + repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0 + ) * drop_path_scale + dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8) + dmask_expanded[x0_mask_batch] = dmask + if has_residual: + residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype) + residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref + else: + residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype) + residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + out_pt = 
model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch] + out_ref = model_ref(residual_ref)[out_mask_batch] + assert out.dtype == input_dtype + assert residual.dtype == residual_dtype + assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4 + assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4 + + g = torch.randn_like(out) / batch_size + (out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(g) + (out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g) + (out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype)) + residual_ref.mean(0, keepdim=True)).backward(g) + assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4 + if has_residual: + assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4 + assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4 + assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4 + if has_colscale: + assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
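
Usage note (illustrative, not part of the commit): the sketch below shows how the new dropout_add_layer_norm_subset API added by this patch might be called for batch-level drop-path, mirroring the mask construction in test_dropout_layer_norm_subset_training above. The shapes, dtypes, dropout probability, and helper name droppath_subset are assumptions chosen for the example; the snippet assumes a CUDA device and a flash_attn build that includes this patch.

    import torch
    from einops import rearrange, repeat
    from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset

    torch.manual_seed(0)
    device = 'cuda'
    batch_size, seqlen, hidden_size = 8, 512, 256
    drop_path_rate = 0.4

    def droppath_subset(batch_size, seqlen, drop_path_rate):
        # Build the mask on CPU so .item() does not force a GPU-CPU sync (same trick as the tests).
        keep = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = keep.sum().item() * seqlen
        keep = keep.to(device)
        keep_rows = repeat(keep, 'b -> (b s)', s=seqlen)
        # 1-indexed positions of the kept rows; 0 marks a dropped row, as the kernels expect.
        subset = torch.cumsum(keep_rows, dim=0, dtype=torch.int32).masked_fill_(~keep_rows, 0)
        return keep, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)

    x0_keep, x0_numrows, x0_subset = droppath_subset(batch_size, seqlen, drop_path_rate)
    out_keep, out_numrows, out_subset = droppath_subset(batch_size, seqlen, drop_path_rate)

    weight = torch.ones(hidden_size, device=device, dtype=torch.float16)
    bias = torch.zeros(hidden_size, device=device, dtype=torch.float16)
    x0_full = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=torch.float16)
    x0 = x0_full[x0_keep]           # only the kept sequences are passed in (subset of the input)
    x1 = torch.randn_like(x0_full)  # residual stream always has the full batch size

    out = dropout_add_layer_norm_subset(
        x0, x1, weight, bias, 0.1, 1e-5,
        x0_subset=x0_subset, out_subset=out_subset,
        rowscale_const=1 / (1 - drop_path_rate), out_numrows=out_numrows,
    )
    # out only contains the rows selected by out_subset.
    print(out.shape)  # (kept_out_batches, seqlen, hidden_size)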