diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp
index 1621d67..36a3692 100644
--- a/csrc/flash_attn/fmha_api.cpp
+++ b/csrc/flash_attn/fmha_api.cpp
@@ -101,7 +101,7 @@ void set_params(Fused_multihead_attention_fprop_params &params,
     // Set this to probability of keeping an element to simplify things.
     params.p_dropout = 1.f - p_dropout;
     // Convert p from float to int so we don't have to convert the random uint to float to compare.
-    // [Minor] We want to round down since when we do the comparison we use <= instead <
+    // [Minor] We want to round down since when we do the comparison we use <= instead of <
     params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
     params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
     params.rp_dropout = 1.f / params.p_dropout;
@@ -111,7 +111,7 @@ void set_params(Fused_multihead_attention_fprop_params &params,
     params.is_causal = is_causal;
 }
 
-std::vector<at::Tensor>
+std::vector<at::Tensor>
 mha_fwd(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
         const at::Tensor &cu_seqlens,  // b+1
         const float p_dropout,
diff --git a/csrc/flash_attn/src/fmha/smem_tile.h b/csrc/flash_attn/src/fmha/smem_tile.h
index 22b307d..18579e5 100644
--- a/csrc/flash_attn/src/fmha/smem_tile.h
+++ b/csrc/flash_attn/src/fmha/smem_tile.h
@@ -1204,6 +1204,8 @@ struct Smem_tile_o {
             this->smem_write_ ^= 7 * 32;
         } else if( Mma_tile::MMAS_N >= 2 ) {
             this->smem_write_ ^= 3 * 32;
+        } else {
+            this->smem_write_ ^= 3 * 32;
         }
         // this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
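
For context on the comment fixed in the first hunk: set_params() converts the keep-probability into a raw uint32_t threshold so the kernel can compare random integers directly instead of converting each one to float. Below is a minimal standalone C++ sketch of that scheme, not part of the patch; std::mt19937 stands in for the Philox generator the real kernel uses.

// Sketch of the dropout-threshold scheme the fixed comment describes.
// NOTE: std::mt19937 is only a stand-in; the kernel draws its random
// uint32s from a Philox generator.
#include <cstdint>
#include <cmath>
#include <iostream>
#include <random>

int main() {
    const float p_dropout = 0.1f;
    const float p_keep    = 1.f - p_dropout;   // probability of keeping an element

    // Same conversion as set_params(): scale the keep-probability to the
    // uint32 range and round down. Because the kernel keeps an element when
    // rand <= threshold (inclusive), rounding down avoids biasing the keep
    // rate high.
    const uint32_t threshold = uint32_t(std::floor(p_keep * 4294967295.0));

    std::mt19937 gen(42);
    std::uniform_int_distribution<uint32_t> rand_u32;  // defaults to [0, 2^32 - 1]

    int kept = 0;
    const int total = 1000000;
    for (int i = 0; i < total; ++i) {
        if (rand_u32(gen) <= threshold) {   // the inclusive compare the comment refers to
            ++kept;
        }
    }
    std::cout << "empirical keep rate: " << double(kept) / total
              << "  (target: " << p_keep << ")\n";
}

With the inclusive compare, the number of accepted values is threshold + 1 out of 2^32, which is why the threshold is rounded down rather than up.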