[LayerNorm] Support taking subset of input or subset of output

Tri Dao 2022-12-12 22:16:14 -08:00
parent ae137ed17a
commit 5db330519a
6 changed files with 643 additions and 168 deletions
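For orientation, a rough, unfused PyTorch sketch of the semantics the new x0_subset / z_subset arguments add (at the Python level the output subset is called out_subset). The helper name is made up for illustration, colscale/layerscale is omitted, and dtype handling is simplified; this is not the fused kernel.

import torch
import torch.nn.functional as F

# Hypothetical reference helper for illustration only; not part of the library.
def dropout_add_ln_subset_reference(x0, x1, weight, bias, x0_subset, out_subset,
                                    dropout_p, eps, rowscale_const, out_numrows):
    # x0:         (x0_numrows, H) packed input rows (only the kept rows)
    # x1:         (rows, H) full residual, or None
    # x0_subset:  (rows,) int32, 1-based index into x0 for each full row, 0 = row absent
    # out_subset: (rows,) int32, 1-based index into the packed output, 0 = row not emitted
    rows, H = x0_subset.shape[0], weight.shape[0]
    # Scatter the packed x0 rows back to the full layout, scaled by rowscale_const.
    x0_full = torch.zeros(rows, H, dtype=x0.dtype, device=x0.device)
    present = x0_subset > 0
    x0_full[present] = x0[x0_subset[present].long() - 1]
    x0_full = F.dropout(x0_full * rowscale_const, p=dropout_p)
    x = x0_full if x1 is None else x0_full + x1
    # LayerNorm runs over all full rows; only the rows in out_subset are written out, packed.
    z_full = F.layer_norm(x.float(), (H,), weight.float(), bias.float(), eps)
    z = torch.zeros(out_numrows, H, dtype=x0.dtype, device=x0.device)
    emit = out_subset > 0
    z[out_subset[emit].long() - 1] = z_full[emit].to(x0.dtype)
    return z, x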

View File

@ -66,11 +66,14 @@ struct ParamsBase {
void *gamma;
void *rowscale;
void *colscale;
void *x0_subset;
void *z_subset;
float inverse_cols;
float dropout_keep_p;
float dropout_scale;
float rowscale_const;
// Multi-CTA workspace in gmem.
void *workspace;

View File

@ -84,9 +84,13 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
const at::Tensor &gamma, // hidden_size
const at::Tensor &beta, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
c10::optional<const at::Tensor> &colscale_, // BxS
c10::optional<const at::Tensor> &colscale_, // hidden_size
c10::optional<const at::Tensor> &x0_subset_, // BxS
c10::optional<const at::Tensor> &z_subset_, // BxS
const float dropout_p,
const float epsilon,
const float rowscale_const,
const int64_t z_numrows,
c10::optional<at::Generator> gen_,
bool residual_in_fp32
) {
@ -99,14 +103,19 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
auto ctype = torch::kFloat32;
auto mtype = torch::kUInt8;
TORCH_CHECK(beta.scalar_type() == wtype);
TORCH_CHECK(beta.dtype() == wtype);
TORCH_CHECK(x0.is_cuda())
TORCH_CHECK(gamma.is_cuda())
TORCH_CHECK(beta.is_cuda())
TORCH_CHECK(x0.is_contiguous());
auto sizes = x0.sizes();
// c10::IntArrayRef does not own the storage, so we need to construct a vector.
// Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
// blah is then deallocated.
std::vector<int64_t> sizes_vec {!x0_subset_.has_value() ? x0.size(0) : x0_subset_.value().size(0), x0.size(1)};
auto sizes = c10::IntArrayRef(sizes_vec);
TORCH_CHECK(x0.dim() == 2);
TORCH_CHECK(sizes.size() == 2);
const int rows = sizes[0];
@ -124,7 +133,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
auto rowscale = rowscale_.value();
TORCH_CHECK(rowscale.is_cuda())
TORCH_CHECK(rowscale.is_contiguous());
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
TORCH_CHECK(rowscale.dtype() == itype);
}
@ -132,10 +141,25 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
auto colscale = colscale_.value();
TORCH_CHECK(colscale.is_cuda())
TORCH_CHECK(colscale.is_contiguous());
TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
TORCH_CHECK(colscale.dtype() == wtype);
}
if (x0_subset_.has_value()) {
auto x0_subset = x0_subset_.value();
TORCH_CHECK(x0_subset.is_cuda())
TORCH_CHECK(x0_subset.is_contiguous());
TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
TORCH_CHECK(z_subset_.has_value());
auto z_subset = z_subset_.value();
TORCH_CHECK(z_subset.is_cuda());
TORCH_CHECK(z_subset.is_contiguous());
TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
TORCH_CHECK(z_subset.dtype() == torch::kInt32);
}
TORCH_CHECK(gamma.sizes() == beta.sizes());
TORCH_CHECK(hidden_size == cols);
TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
@ -144,12 +168,12 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
auto opts = x0.options();
bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || (itype != rtype);
bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
at::Tensor x;
if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
at::Tensor dmask;
if (dropout_p > 0.f) { dmask = torch::empty(sizes, opts.dtype(mtype)); };
auto z = torch::empty(sizes, opts.dtype(otype));
if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), opts.dtype(mtype)); };
auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype));
auto mu = torch::empty({ rows }, opts.dtype(ctype));
auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
@ -163,6 +187,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
@ -192,6 +218,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
params.epsilon = epsilon;
params.dropout_scale = 1.f / (1.f - dropout_p);
params.inverse_cols = 1.f / float(params.cols);
params.rowscale_const = rowscale_const;
if (dropout_p > 0.f) {
// number of times random will be generated per thread, to offset philox counter in thc random
@ -230,8 +257,12 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
c10::optional<const at::Tensor> &colscale_, // BxS
c10::optional<const at::Tensor> &colscale_, // hidden_size
c10::optional<const at::Tensor> &x0_subset_, // BxS
c10::optional<const at::Tensor> &z_subset_, // BxS
const float dropout_p,
const float rowscale_const,
const int64_t x0_numrows,
const bool has_residual
) {
@ -259,9 +290,16 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
auto sizes = x.sizes();
TORCH_CHECK(sizes.size() == 2);
TORCH_CHECK(dz.sizes() == sizes);
auto rows = sizes[0];
auto cols = sizes[1];
TORCH_CHECK(dz.dim() == 2);
TORCH_CHECK(dz.size(1) == cols);
// c10::IntArrayRef does not own the storage, so we need to construct a vector.
// Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because
// blah is then deallocated.
std::vector<int64_t> x0_sizes_vec {!x0_subset_.has_value() ? rows : x0_numrows, cols};
auto x0_sizes = c10::IntArrayRef(x0_sizes_vec);
if (dx_.has_value()) {
auto dx = dx_.value();
@ -276,14 +314,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
TORCH_CHECK(dmask.dtype() == mtype);
TORCH_CHECK(dmask.is_cuda());
TORCH_CHECK(dmask.is_contiguous());
TORCH_CHECK(dmask.sizes() == sizes);
TORCH_CHECK(dmask.sizes() == x0_sizes);
}
if (rowscale_.has_value()) {
auto rowscale = rowscale_.value();
TORCH_CHECK(rowscale.is_cuda())
TORCH_CHECK(rowscale.is_contiguous());
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
TORCH_CHECK(rowscale.dtype() == itype);
}
@ -291,17 +329,32 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
auto colscale = colscale_.value();
TORCH_CHECK(colscale.is_cuda())
TORCH_CHECK(colscale.is_contiguous());
TORCH_CHECK(colscale.sizes() == std::vector<int64_t>{cols});
TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
TORCH_CHECK(colscale.dtype() == wtype);
TORCH_CHECK(x0_.has_value());
auto x0 = x0_.value();
TORCH_CHECK(x0.is_cuda())
TORCH_CHECK(x0.is_contiguous());
TORCH_CHECK(x0.sizes() == sizes);
TORCH_CHECK(x0.sizes() == x0_sizes);
TORCH_CHECK(x0.dtype() == itype);
}
if (x0_subset_.has_value()) {
auto x0_subset = x0_subset_.value();
TORCH_CHECK(x0_subset.is_cuda())
TORCH_CHECK(x0_subset.is_contiguous());
TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
TORCH_CHECK(x0_subset.dtype() == torch::kInt32);
TORCH_CHECK(z_subset_.has_value());
auto z_subset = z_subset_.value();
TORCH_CHECK(z_subset.is_cuda());
TORCH_CHECK(z_subset.is_contiguous());
TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
TORCH_CHECK(z_subset.dtype() == torch::kInt32);
}
auto hidden_size = gamma.numel();
TORCH_CHECK(hidden_size == cols);
TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
@ -313,7 +366,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
auto opts = x.options();
auto dx0 = torch::empty_like(x, opts.dtype(itype));
auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
at::Tensor dx1;
if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
auto dgamma = torch::empty_like(gamma);
@ -331,6 +384,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
@ -366,6 +421,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
params.dropout_scale = 1.f / (1.f - dropout_p);
params.inverse_cols = 1.f / float(params.cols);
params.rowscale_const = rowscale_const;
if( launch_params.barrier_size > 0 ) {
// TODO Any way to avoid this?

View File

@ -7,7 +7,7 @@
namespace layer_norm {
template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_kernel(layer_norm::BwdParams params) {
@ -37,6 +37,9 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
extern __shared__ char smem_[];
const bool has_residual = params.dx1 != nullptr;
const bool prenorm = params.dx != nullptr;
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
@ -51,6 +54,10 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
const input_t *rowscale = static_cast<input_t *>(params.rowscale);
const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
const index_t *z_subset = static_cast<index_t *>(params.z_subset);
Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
Cvec dcolscale_sum[LDGS];
@ -87,40 +94,62 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
const compute_t rowscale_val =
params.rowscale == nullptr ? 1.0f : compute_t(static_cast<const input_t *>(params.rowscale)[row]);
const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
const int row_z = !Has_subset ? row + 1 : z_subset[row];
const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
const bool load_dz = !Has_subset || row_z > 0;
const bool save_dx0 = !Has_subset || row_x0 > 0;
Mvec dmask[LDGS];
Rvec dx[LDGS];
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdy_local = 0.f;
compute_t mdyy_local = 0.f;
index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Rvec x;
Ovec dz;
dz.load_from(params.dz, idx);
if (Prenorm) { dx[it].load_from(params.dx, idx); }
x.load_from(params.x, idx);
if (Is_dropout) { dmask[it].load_from(params.dmask, idx); }
idx += Ktraits::VEC_COLS_PER_LDG;
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_tmp = x.data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
compute_t dz_tmp = dz.data.elt[jt];
// If dz is not loaded, then dy should be 0 and we don't care about the value of y.
if (load_dz) {
index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
index_t idx_z = !Has_subset ? idx_x : (load_dz ? (row_z - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Rvec x;
Ovec dz;
dz.load_from(params.dz, !Has_subset ? idx_x : idx_z);
if (prenorm) { dx[it].load_from(params.dx, idx_x); }
x.load_from(params.x, idx_x);
if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
idx_x += Ktraits::VEC_COLS_PER_LDG;
idx_z += Ktraits::VEC_COLS_PER_LDG;
idx_x0 += Ktraits::VEC_COLS_PER_LDG;
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_tmp = x.data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
compute_t dz_tmp = dz.data.elt[jt];
mdy_local += dy_tmp;
mdyy_local += dy_tmp * y_tmp;
mdy_local += dy_tmp;
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
dz_sum[it].data.elt[jt] += dz_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
dz_sum[it].data.elt[jt] += dz_tmp;
}
}
}
} else {
index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
if (prenorm) { dx[it].load_from(params.dx, idx_x); }
if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }
idx_x += Ktraits::VEC_COLS_PER_LDG;
idx_x0 += Ktraits::VEC_COLS_PER_LDG;
}
}
}
@ -129,42 +158,51 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;
idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Ivec dx0;
Rvec dx1;
Ivec x0;
if (Has_colscale) { x0.load_from(params.x0, idx); }
if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
if (Is_dropout) {
dx0_tmp_res *= params.dropout_scale;
if (Has_colscale) {
dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
} else {
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
}
compute_t dx_tmp_res;
if (load_dz) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
} else {
if (Has_colscale) {
dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
}
if (has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
if (save_dx0) {
compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
if (Is_dropout) {
dx0_tmp_res *= params.dropout_scale;
if (Has_colscale) {
dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;
} else {
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f;
}
} else {
dx0.data.elt[jt] = dx0_tmp_res;
if (Has_colscale) {
dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);
dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);
} else {
dx0.data.elt[jt] = dx0_tmp_res;
}
}
}
}
if (Has_residual) { dx1.store_to(params.dx1, idx); }
dx0.store_to(params.dx0, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
if (has_residual) { dx1.store_to(params.dx1, idx_x); }
if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }
idx_x += Ktraits::VEC_COLS_PER_LDG;
idx_x0 += Ktraits::VEC_COLS_PER_LDG;
}
}
@ -434,64 +472,61 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
WARPS_N,
BYTES_PER_LDG_MAIN
>;
bool prenorm = launch_params.params.dx != nullptr;
bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
bool has_residual = launch_params.params.dx1 != nullptr;
bool has_colscale = launch_params.params.colscale != nullptr;
bool has_subset = launch_params.params.x0_subset != nullptr;
bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
BOOL_SWITCH(prenorm, PrenormConst, [&] {
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::reduce_t)
* 2;
}
return;
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
auto kernel = &ln_bwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::reduce_t)
* 2;
}
return;
}
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
}
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
}
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
HasColscaleConst,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
HasColscaleConst,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
});
auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
});
});
});

View File

@ -16,7 +16,7 @@
namespace layer_norm {
template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_colscale, bool Is_even_cols>
template<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_kernel(FwdParams params) {
@ -46,7 +46,8 @@ void ln_fwd_kernel(FwdParams params) {
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
const bool save_x = Has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || !(std::is_same<input_t, residual_t>::value);
const bool has_residual = params.x1 != nullptr;
const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);
extern __shared__ char smem_[];
@ -67,6 +68,8 @@ void ln_fwd_kernel(FwdParams params) {
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
const input_t *rowscale = static_cast<input_t *>(params.rowscale);
const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);
const index_t *z_subset = static_cast<index_t *>(params.z_subset);
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
curandStatePhilox4_32_10_t state;
@ -93,8 +96,12 @@ void ln_fwd_kernel(FwdParams params) {
}
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t rowscale_val = params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row]);
index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
const int row_z = !Has_subset ? row + 1 : z_subset[row];
const bool load_x0 = !Has_subset || row_x0 > 0;
index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
@ -103,24 +110,30 @@ void ln_fwd_kernel(FwdParams params) {
Rvec x1;
Rvec x;
Mvec dmask;
x0.load_from(params.x0, idx);
if (Has_residual) { x1.load_from(params.x1, idx); }
if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
if (has_residual) { x1.load_from(params.x1, idx_x); }
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
// TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
// the more efficient curand_uniform4.
mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
compute_t x_ij = Has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
compute_t x_ij;
if (load_x0) {
mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
if (Is_dropout) { dmask.data.elt[jt] = keep; }
compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
x_ij = has_residual ? x0_ij + compute_t(x1.data.elt[jt]) : x0_ij;
} else {
x_ij = has_residual ? compute_t(x1.data.elt[jt]) : 0.f;
}
if (save_x) { x.data.elt[jt] = x_ij; }
xf[it * NUM_ELTS + jt] = x_ij;
if (Is_dropout) { dmask.data.elt[jt] = keep; }
}
if (save_x) { x.store_to(params.x, idx); }
if (Is_dropout) { dmask.store_to(params.dmask, idx); }
idx += VEC_COLS_PER_LDG;
if (save_x) { x.store_to(params.x, idx_x); }
if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); }
idx_x += VEC_COLS_PER_LDG;
idx_x0 += VEC_COLS_PER_LDG;
}
}
@ -152,20 +165,23 @@ void ln_fwd_kernel(FwdParams params) {
rs_ptr[row] = rs;
}
idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Ovec z;
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - mu));
compute_t g_ij = gamma[it].data.elt[jt];
compute_t b_ij = beta[it].data.elt[jt];
z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
const bool save_z = !Has_subset || row_z > 0;
if (save_z) {
index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Ovec z;
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - mu));
compute_t g_ij = gamma[it].data.elt[jt];
compute_t b_ij = beta[it].data.elt[jt];
z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
}
z.store_to(params.z, idx_z);
idx_z += VEC_COLS_PER_LDG;
}
z.store_to(params.z, idx);
idx += VEC_COLS_PER_LDG;
}
}
@ -203,14 +219,14 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
WARPS_N,
BYTES_PER_LDG
>;
bool has_residual = launch_params.params.x1 != nullptr;
bool has_colscale = launch_params.params.colscale != nullptr;
bool has_subset = launch_params.params.x0_subset != nullptr;
bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasColscaleConst, IsEvenColsConst>;
BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(

View File

@ -16,7 +16,8 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dro
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
rowscale = rowscale.view(-1) if rowscale is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
x0mat, x1mat, gamma, beta, rowscale, colscale, dropout_p, epsilon, None, residual_in_fp32
x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
1.0, 0, None, residual_in_fp32
)
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
@ -36,12 +37,59 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
dxmat = dx.view(xmat.shape) if dx is not None else None
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
rowscale = rowscale.view(-1) if rowscale is not None else None
colscale = colscale.view(-1) if colscale is not None else None
if colscale is not None:
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p,
has_residual
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
dropout_p, 1.0, 0, has_residual
)
# dx1mat is None if not has_residual
if colscale is None:
return dx0mat, dx1mat, dgamma, dbeta
else:
dcolscale = rest[0]
return dx0mat, dx1mat, dgamma, dbeta, dcolscale
def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset,
dropout_p, epsilon, rowscale_const, out_numrows,
residual_in_fp32):
""" Assume that arguments are contiguous
"""
hidden_size = gamma.numel()
x0mat = x0.view((-1, hidden_size))
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
out_subset = out_subset.view(-1) if out_subset is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, None, residual_in_fp32
)
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
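For reference, the shapes this wrapper passes to the extension, with illustrative sizes (assuming batch_size=8, seqlen=512, i.e. 4096 full rows, and hidden_size=1024); this is a summary of the checks in the C++ binding above, not additional behavior:
# x0mat:       (x0_numrows, 1024)   packed input rows, x0_numrows <= 4096
# x1mat:       (4096, 1024)         full residual rows, or None
# x0_subset:   (4096,) int32        1-based position of each full row in x0mat, 0 = row dropped
# out_subset:  (4096,) int32        1-based position of each full row in zmat, 0 = row not output
# zmat:        (out_numrows, 1024)  packed LayerNorm output
# dmask:       same shape as x0mat  (None when dropout_p == 0.0)
# xmat:        (4096, 1024) full rows; mu, rsigma: (4096,)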
def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
x0_subset, out_subset, dropout_p, rowscale_const,
x0_numrows, has_residual):
""" Assume that arguments are contiguous
dx == None means that it was a post-norm architecture
(x = drop(x0) + x1 was not returned in the fwd).
x0 must not be None if we have colscale.
"""
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dzmat = dz.view(-1, hidden_size)
dxmat = dx.view(xmat.shape) if dx is not None else None
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
out_subset = out_subset.view(-1) if out_subset is not None else None
if colscale is not None:
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
dropout_p, rowscale_const, x0_numrows, has_residual
)
# dx1mat is None if not has_residual
if colscale is None:
@ -98,6 +146,60 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, prenorm=False, return_dmask=False):
x0 = x0.contiguous()
x1 = x1.contiguous() if x1 is not None else None
gamma = gamma.contiguous()
beta = beta.contiguous()
colscale = colscale.contiguous() if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32
)
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved = x0 if colscale is not None else None
x_shape = (-1, *x0.shape[1:])
ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
x0_subset, out_subset)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
ctx.rowscale_const = rowscale_const
ctx.x0_numrows = x0.shape[:-1].numel()
ctx.has_residual = x1 is not None
z_shape = (-1, *x0.shape[1:])
if not return_dmask:
return (zmat.view(z_shape) if not prenorm
else (zmat.view(z_shape), xmat.view(x_shape)))
else:
z = zmat.view(z_shape)
dmask = (dmask.view(x0.shape) if dropout_p > 0.
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
ctx.mark_non_differentiable(dmask)
return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask))
@staticmethod
def backward(ctx, dz, *args):
# assert dz.is_contiguous()
dz = dz.contiguous() # this happens!
dx = args[0].contiguous() if ctx.prenorm else None
x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
# x0 is None if colscale is None
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
ctx.rowscale_const, ctx.x0_numrows, has_residual
)
dx0 = dx0mat.view(-1, *x.shape[1:])
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
dcolscale = rest[0] if colscale is not None else None
return (dx0, dx1, dgamma, dbeta, dcolscale, None, None, None, None, None, None, None,
None, None)
def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
@ -110,6 +212,19 @@ def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=No
)
def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
x0_subset=None, out_subset=None, rowscale_const=1.0,
out_numrows=0, prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
"""residual_in_fp32 only has an effect if x1 is None.
Otherwise residual dtype is x1.dtype.
"""
return DropoutAddLayerNormSubsetFn.apply(
x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, prenorm, return_dropout_mask
)
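A minimal usage sketch of the new API, mirroring the subset construction used in the tests below; sizes and values are illustrative:

import torch
from einops import repeat
from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset

batch, seqlen, hidden = 8, 512, 1024
drop_path_rate = 0.4
# Subset entries are 1-based positions in the packed tensor; 0 marks a dropped row.
mask_batch = torch.rand(batch) < 1 - drop_path_rate
numrows = mask_batch.sum().item() * seqlen
mask_rows = repeat(mask_batch, 'b -> (b s)', s=seqlen)
subset = torch.cumsum(mask_rows, dim=0, dtype=torch.int32).masked_fill_(~mask_rows, 0).cuda()

weight = torch.ones(hidden, device='cuda')
bias = torch.zeros(hidden, device='cuda')
x0 = torch.randn(numrows, hidden, device='cuda')          # packed input rows
x1 = torch.randn(batch * seqlen, hidden, device='cuda')   # full residual
out = dropout_add_layer_norm_subset(
    x0, x1, weight, bias, 0.1, 1e-5,
    x0_subset=subset, out_subset=subset, rowscale_const=1 / (1 - drop_path_rate),
    out_numrows=numrows)
# out: (numrows, hidden) = LayerNorm(dropout(x0 * rowscale_const) + x1), computed over
# all batch * seqlen rows but returned only for the rows selected by out_subset.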
class DropoutAddLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
device=None, dtype=None):

View File

@ -4,9 +4,10 @@ import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from einops import rearrange, repeat
from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm
from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@ -130,6 +131,8 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
with torch.no_grad():
@ -148,22 +151,23 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_rowscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [False])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('has_colscale', [True, False])
# @pytest.mark.parametrize('has_rowscale', [True, False])
# @pytest.mark.parametrize('has_residual', [True, False])
# @pytest.mark.parametrize('dropout_p', [0.37, 0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
# @pytest.mark.parametrize('input_dtype,residual_dtype',
# [(torch.float16, torch.float16), (torch.float16, torch.float32),
# (torch.float32, torch.float32)]
# + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('has_colscale', [True])
@pytest.mark.parametrize('has_rowscale', [False])
@pytest.mark.parametrize('has_residual', [True])
@pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32])
@pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
@pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_residual, has_rowscale, has_colscale):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
@ -205,6 +209,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
x0_scaled_pt = x0_scaled_pt * colscale_pt
x0_scaled_ref = x0_scaled_ref * colscale_ref
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
dtype=weight_dtype)
@ -271,6 +277,8 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
dtype=weight_dtype)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
@ -289,3 +297,245 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
out_ref = model_ref(residual_ref)
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_training(
hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p,
has_residual, has_colscale):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
device = 'cuda'
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 2e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
drop_path_rate = 0.4
drop_path_scale = 1 / (1 - drop_path_rate)
def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
# Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
numrows = (mask_batch).sum().item() * seqlen
mask_batch = mask_batch.to(device=device, non_blocking=True)
mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
subset = torch.cumsum(mask_batch_seqlen, dim=0,
dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)
x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(batch_size, seqlen,
drop_path_rate, device)
out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen,
drop_path_rate, device)
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_colscale:
colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
colscale_pt = colscale.detach().clone().requires_grad_()
colscale_ref = colscale.detach().clone().float().requires_grad_()
else:
colscale = None
if has_residual:
x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
else:
x1 = None
if has_colscale:
x0_scaled_pt = x0_pt * colscale_pt
x0_scaled_ref = x0_ref * colscale_ref
else:
x0_scaled_pt = x0_pt
x0_scaled_ref = x0_ref
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
model = DropoutAddLayerNorm(hidden_size, prenorm=False, p=dropout_p, device=device,
dtype=weight_dtype)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
model_ref.weight.copy_(model_pt.weight)
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, dmask = dropout_add_layer_norm_subset(
x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
out_numrows = out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
x0_scaled_pt = x0_scaled_pt.masked_fill(
repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
) * drop_path_scale
x0_scaled_ref = x0_scaled_ref.masked_fill(
repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
) * drop_path_scale
dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
dmask_expanded[x0_mask_batch] = dmask
if has_residual:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
out_ref = model_ref(residual_ref)[out_mask_batch]
assert out.dtype == input_dtype
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
g = torch.randn_like(out) / batch_size
out_pt.backward(g)
out.backward(g)
out_ref.backward(g)
assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
if has_residual:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_prenorm_training(
hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p,
has_residual, has_colscale):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
device = 'cuda'
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 2e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
drop_path_rate = 0.4
drop_path_scale = 1 / (1 - drop_path_rate)
def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
# Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
numrows = (mask_batch).sum().item() * seqlen
mask_batch = mask_batch.to(device=device, non_blocking=True)
mask_batch_seqlen = repeat(mask_batch, 'b -> (b s)', s=seqlen)
subset = torch.cumsum(mask_batch_seqlen, dim=0,
dtype=torch.int32).masked_fill_(~mask_batch_seqlen, 0)
return mask_batch, numrows, rearrange(subset, '(b s) -> b s', b=batch_size)
x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(batch_size, seqlen,
drop_path_rate, device)
out_mask_batch, out_numrows, out_subset = generate_droppath_masks(batch_size, seqlen,
drop_path_rate, device)
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_colscale:
colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
colscale_pt = colscale.detach().clone().requires_grad_()
colscale_ref = colscale.detach().clone().float().requires_grad_()
else:
colscale = None
if has_residual:
x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
else:
x1 = None
if has_colscale:
x0_scaled_pt = x0_pt * colscale_pt
x0_scaled_ref = x0_ref * colscale_ref
else:
x0_scaled_pt = x0_pt
x0_scaled_ref = x0_ref
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
dtype=weight_dtype)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
model_ref.weight.copy_(model_pt.weight)
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, residual, dmask = dropout_add_layer_norm_subset(
x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
out_numrows = out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
x0_scaled_pt = x0_scaled_pt.masked_fill(
repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
) * drop_path_scale
x0_scaled_ref = x0_scaled_ref.masked_fill(
repeat(~x0_mask_batch, 'b -> b s d', s=seqlen, d=hidden_size), 0
) * drop_path_scale
dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
dmask_expanded[x0_mask_batch] = dmask
if has_residual:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
out_ref = model_ref(residual_ref)[out_mask_batch]
assert out.dtype == input_dtype
assert residual.dtype == residual_dtype
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
g = torch.randn_like(out) / batch_size
(out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(g)
(out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g)
(out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype)) + residual_ref.mean(0, keepdim=True)).backward(g)
assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
if has_residual:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4