Bug fix: wrong smem_o write pointer for d=16
This commit is contained in:
parent
765741c1ee
commit
eeca63a72a
@ -101,7 +101,7 @@ void set_params(Fused_multihead_attention_fprop_params ¶ms,
|
||||
// Set this to probability of keeping an element to simplify things.
|
||||
params.p_dropout = 1.f - p_dropout;
|
||||
// Convert p from float to int so we don't have to convert the random uint to float to compare.
|
||||
// [Minor] We want to round down since when we do the comparison we use <= instead <
|
||||
// [Minor] We want to round down since when we do the comparison we use <= instead of <
|
||||
params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
|
||||
params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
|
||||
params.rp_dropout = 1.f / params.p_dropout;
|
||||
|
||||
@ -1204,6 +1204,8 @@ struct Smem_tile_o {
|
||||
this->smem_write_ ^= 7 * 32;
|
||||
} else if( Mma_tile::MMAS_N >= 2 ) {
|
||||
this->smem_write_ ^= 3 * 32;
|
||||
} else {
|
||||
this->smem_write_ ^= 3 * 32;
|
||||
}
|
||||
// this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user