diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index 3697d21..773b1c3 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -657,10 +657,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
     const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
         + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
-    const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
-        + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
-    const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
-        + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
 
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -672,18 +668,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.v_row_stride, _1{}));
-    // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
-    // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
-    // This maps to accessing the first 64 rows of knew_ptr.
-    Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
-                                             + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
-                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                               make_stride(params.knew_row_stride, _1{}));
-    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
-    Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
-                                             + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
-                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                               make_stride(params.vnew_row_stride, _1{}));
 
     Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutQ{});
@@ -698,10 +682,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
     Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
-    Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
-    Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
 
     typename Kernel_traits::TiledMma tiled_mma;
@@ -762,6 +744,49 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     // Prologue
 
+    if constexpr (Append_KV) {
+        // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
+        // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
+        // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
+        const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+            + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
+        const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+            + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
+        // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
+        // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
+        // This maps to accessing the first 64 rows of knew_ptr.
+        Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
+                                                 + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
+                                   Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                   make_stride(params.knew_row_stride, _1{}));
+        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
+        Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
+                                                 + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
+                                   Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                   make_stride(params.vnew_row_stride, _1{}));
+        Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
+        Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
+
+        const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
+        for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
+            flash::copy_w_min_idx<Is_even_K>(
+                tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+            );
+            flash::copy_w_min_idx<Is_even_K>(
+                tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+            );
+            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+            tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
+            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+            tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
+        }
+        __syncthreads();
+        if (n_block_max > n_block_copy_min) {
+            tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;
+            tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride;
+        }
+    }
+
     Tensor tQrQ = make_fragment_like(tQgQ);
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
     flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
@@ -769,10 +794,8 @@
 
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K>(
-        gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV,
-        binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-    );
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+                                       binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();
 
     // flash::cp_async_wait<0>();
@@ -800,32 +823,14 @@
         flash::cp_async_wait<0>();
         __syncthreads();
 
-        if constexpr (Append_KV) {
-            // if (cute::thread0()) { print(tKgK); }
-            // if (cute::thread0()) { print(tKsK); }
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
-            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
-                flash::copy_w_min_idx<Is_even_K>(
-                    tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-                );
-            }
-            // __syncthreads();
-            // if (cute::thread0()) { print(tKgK); }
-            // __syncthreads();
-        }
-
         // Advance gV
         if (masking_step > 0) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-            if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
-            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
-                gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
-            );
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
-            flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-                gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV,
-                binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
             );
         }
         cute::cp_async_fence();
@@ -856,26 +861,10 @@
         // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
         // __syncthreads();
 
-        // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); }
-        if constexpr (Append_KV) {
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
-            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
-                flash::copy_w_min_idx<Is_even_K>(
-                    tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-                );
-            }
-        }
-
         if (n_block > n_block_min) {
             // Advance gK
-            // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); }
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
-            // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); }
-            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
-                gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
-                binfo.seqlen_k_cache - (n_block - 1) * kBlockN
-            );
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -909,20 +898,9 @@
         clear(acc_s);
         flash::cp_async_wait<0>();
         __syncthreads();
-        if constexpr (Append_KV) {
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
-            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
-                flash::copy_w_min_idx<Is_even_K>(
-                    tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-                );
-            }
-        }
         // Advance gV
         tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-        if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
-        flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
-            gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
-        );
+        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();
 
         flash::gemm(
@@ -932,22 +910,10 @@
 
         flash::cp_async_wait<0>();
         __syncthreads();
-        if constexpr (Append_KV) {
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
-            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
-                flash::copy_w_min_idx<Is_even_K>(
-                    tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-                );
-            }
-        }
         if (n_block > n_block_min) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
-            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
-                gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
-                binfo.seqlen_k_cache - (n_block - 1) * kBlockN
-            );
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 04298f4..6b60ac9 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -149,8 +149,9 @@ def generate_qkv(
     )
 
 
-def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None,
-                          device=None):
+def construct_causal_mask(
+    seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None
+):
     row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
     col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
     sk = (
@@ -364,12 +365,18 @@
         causal_mask = construct_causal_mask(
             seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device
         )
-        causal_mask = F.pad(causal_mask, (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), value=True)
+        causal_mask = F.pad(
+            causal_mask,
+            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
+            value=True,
+        )
         S_converted.masked_fill_(causal_mask, 0.0)
 
     # Need to zero out things not in attention_mask in case S was initialized with random values
     # and some of those values aren't overwritten.
-    seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
+    seqlen_q_og = (
+        query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
+    )
     if query_padding_mask is not None:
         query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
         S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
@@ -623,7 +630,14 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
         out = output_pad_fn(out_unpad)
         if dropout_p > 0.0:
             S_dmask_converted = convert_flash_attn_S_to_softmax(
-                S_dmask, seqlen, seqlen, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
+                S_dmask,
+                seqlen,
+                seqlen,
+                key_padding_mask,
+                key_padding_mask,
+                d,
+                dropout_p > 0.0,
+                causal=causal,
             )
             dropout_mask = S_dmask_converted >= 0
             attn_unnorm = S_dmask_converted.abs()
@@ -996,7 +1010,14 @@ def test_flash_attn_varlen_output(
         out = output_pad_fn(out_unpad)
         if dropout_p > 0.0:
             S_dmask_converted = convert_flash_attn_S_to_softmax(
-                S_dmask, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
+                S_dmask,
+                seqlen_q,
+                seqlen_k,
+                query_padding_mask,
+                key_padding_mask,
+                d,
+                dropout_p > 0.0,
+                causal=causal,
             )
             dropout_mask = S_dmask_converted >= 0
             attn_unnorm = S_dmask_converted.abs()
@@ -1466,16 +1487,18 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
     assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4
 
+
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("num_splits", [1, 0])
 # @pytest.mark.parametrize("num_splits", [0])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-# @pytest.mark.parametrize("mha_type", ["mqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
 @pytest.mark.parametrize("new_kv", [False, True])
-# @pytest.mark.parametrize("new_kv", [False])
+# @pytest.mark.parametrize("new_kv", [True])
 @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [True])
+# @pytest.mark.parametrize("causal", [False])
+@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
@@ -1499,7 +1522,9 @@
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
-def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num_splits, dtype):
+def test_flash_attn_kvcache(
+    seqlen_q, seqlen_k, d, seqlen_new_eq_seqlen_q, causal, new_kv, mha_type, num_splits, dtype
+):
     if seqlen_q > seqlen_k and new_kv:
         pytest.skip()
     device = "cuda"
@@ -1510,14 +1535,21 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
+    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
     if new_kv:
-        k = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
-        v = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
+        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
+        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
     else:
         k, v = None, None
     k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
     v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
-    cache_seqlens = torch.randint(0, (seqlen_k - seqlen_q + 1) if new_kv else (seqlen_k + 1), (batch_size, ), dtype=torch.int32, device=device)
+    cache_seqlens = torch.randint(
+        0,
+        (seqlen_k - seqlen_new + 1) if new_kv else (seqlen_k + 1),
+        (batch_size,),
+        dtype=torch.int32,
+        device=device,
+    )
     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
     # k_cache[:, 64:] = -1
     k_cache_ref = k_cache.clone()
@@ -1525,12 +1557,16 @@
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     if new_kv:
-        update_mask = torch.logical_and(cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_q)
+        update_mask = torch.logical_and(
+            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
+        )
         k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
         v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
-    out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits)
+    out = flash_attn_with_kvcache(
+        q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits
+    )
     # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
     # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
     # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
@@ -1539,10 +1575,22 @@
    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
     # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
     # probs = torch.softmax(qk, dim=-1)
-    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_q if new_kv else 0)
-    out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal)
-    out_pt, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal,
-                              upcast=False, reorder_ops=True)
+    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
+    out_ref, _ = attention_ref(
+        q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
+    )
+    out_pt, _ = attention_ref(
+        q,
+        k_cache_rep,
+        v_cache_rep,
+        None,
+        key_padding_mask,
+        0.0,
+        None,
+        causal=causal,
+        upcast=False,
+        reorder_ops=True,
+    )
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
@@ -1583,7 +1631,7 @@
         (1024, 1024),
     ],
 )
-@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
+@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize("dropout_p", [0.0])
 def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
     device = "cuda"
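For reference, the append semantics that test_flash_attn_kvcache verifies can be written out in plain PyTorch. The sketch below mirrors the update_mask / key_padding_mask logic from the test above; the helper name append_kv_reference is illustrative only (it is not part of this patch), and flash_attn_with_kvcache performs the same update in place inside the kernel via the Append_KV prologue added to compute_attn_1rowblock_splitkv, rather than materializing a mask.

# Reference semantics for the KV-cache append path, mirroring the
# update_mask / key_padding_mask logic in test_flash_attn_kvcache.
# Illustrative sketch only, not the kernel implementation.
import torch
from einops import rearrange

def append_kv_reference(k_cache, v_cache, k_new, v_new, cache_seqlens):
    # k_cache/v_cache: (batch, seqlen_k, nheads_k, d)
    # k_new/v_new:     (batch, seqlen_new, nheads_k, d), cache_seqlens: (batch,)
    seqlen_k = k_cache.shape[1]
    seqlen_new = k_new.shape[1]
    arange = rearrange(torch.arange(seqlen_k, device=k_cache.device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    # Each batch element writes its seqlen_new rows right after its current cache length.
    update_mask = torch.logical_and(
        cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
    )
    k_cache[update_mask] = rearrange(k_new, "b s ... -> (b s) ...")
    v_cache[update_mask] = rearrange(v_new, "b s ... -> (b s) ...")
    # Positions at or beyond cache_seqlens + seqlen_new are padding for attention.
    key_padding_mask = arange < cache_seqlens_expanded + seqlen_new
    return key_padding_mask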