Bug fix: wrong smem_o write pointer for d=16

This commit is contained in:
Tri Dao 2022-06-25 15:17:39 -07:00
parent 765741c1ee
commit eeca63a72a
2 changed files with 4 additions and 2 deletions

View File

@ -101,7 +101,7 @@ void set_params(Fused_multihead_attention_fprop_params &params,
// Set this to probability of keeping an element to simplify things.
params.p_dropout = 1.f - p_dropout;
// Convert p from float to int so we don't have to convert the random uint to float to compare.
// [Minor] We want to round down since when we do the comparison we use <= instead <
// [Minor] We want to round down since when we do the comparison we use <= instead of <
params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
params.rp_dropout = 1.f / params.p_dropout;

View File

@ -1204,6 +1204,8 @@ struct Smem_tile_o {
this->smem_write_ ^= 7 * 32;
} else if( Mma_tile::MMAS_N >= 2 ) {
this->smem_write_ ^= 3 * 32;
} else {
this->smem_write_ ^= 3 * 32;
}
// this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {