Simplify the implementation of KVcache attn by appending KV first

Tri Dao 2023-09-13 15:55:48 -07:00
parent d0032700d1
commit 56b7fc6ee0
2 changed files with 119 additions and 105 deletions

@@ -657,10 +657,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+ ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -672,18 +668,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{}));
// Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
// e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKnew[128:128 + 64].
// This maps to accessing the first 64 rows of knew_ptr.
Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
+ row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.knew_row_stride, _1{}));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
+ row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.vnew_row_stride, _1{}));
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
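
The base-pointer adjustment removed here (and re-added inside the Append_KV prologue below) is worth spelling out: shifting knew_ptr back by seqlen_k_cache rows puts gKnew in the same global row coordinates as gK, so indexing row r of the combined sequence through gKnew lands on row r - seqlen_k_cache of the new keys. A minimal Python sketch of that address arithmetic, with hypothetical sizes and a flat one-element row stride:

    # Flat-address sketch of the gK/gKnew "line up" trick.
    # Hypothetical sizes: 128 cached rows, 64 appended rows, row stride of 1.
    seqlen_k_cache, seqlen_new, row_stride = 128, 64, 1

    knew_base = 0                                         # start of knew_ptr
    gknew_base = knew_base - seqlen_k_cache * row_stride  # shifted base

    def gknew_addr(row):
        # Address of global row `row` seen through the shifted gKnew view.
        return gknew_base + row * row_stride

    # Global rows 128..191 map to rows 0..63 of the new keys.
    assert gknew_addr(128) == knew_base
    assert gknew_addr(191) == knew_base + 63 * row_stride
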
@@ -698,10 +682,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
typename Kernel_traits::TiledMma tiled_mma;
@@ -762,6 +744,49 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Prologue
if constexpr (Append_KV) {
// Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
// gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
// We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+ ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
// Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
// e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKnew[128:128 + 64].
// This maps to accessing the first 64 rows of knew_ptr.
Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
+ row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.knew_row_stride, _1{}));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
+ row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.vnew_row_stride, _1{}));
Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
flash::copy_w_min_idx<Is_even_K>(
tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
flash::copy_w_min_idx<Is_even_K>(
tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
}
__syncthreads();
if (n_block_max > n_block_copy_min) {
tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;
tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride;
}
}
Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
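
The core of the simplification is this new prologue: before the attention loop starts, each threadblock walks the KV blocks from n_block_max - 1 down to the first block overlapping the append region and copies knew/vnew gmem-to-gmem into the cache. A Python sketch of which cache rows get written, assuming copy_w_min_idx copies exactly the rows r with min_idx <= r < max_idx inside a kBlockN tile (hypothetical sizes; n_block_min taken as 0):

    # Which cache rows the Append_KV prologue writes, assuming
    # copy_w_min_idx copies rows r with min_idx <= r < max_idx in a tile.
    kBlockN = 64
    seqlen_k_cache = 100    # hypothetical rows already in the cache
    actual_seqlen_k = 160   # cache rows + appended rows
    n_block_max = (actual_seqlen_k + kBlockN - 1) // kBlockN
    n_block_copy_min = seqlen_k_cache // kBlockN  # n_block_min taken as 0

    written = []
    for n_block in range(n_block_max - 1, n_block_copy_min - 1, -1):
        max_idx = actual_seqlen_k - n_block * kBlockN
        min_idx = seqlen_k_cache - n_block * kBlockN
        written += [n_block * kBlockN + r for r in range(kBlockN)
                    if min_idx <= r < max_idx]

    # Exactly the appended rows [seqlen_k_cache, actual_seqlen_k) are touched.
    assert sorted(written) == list(range(seqlen_k_cache, actual_seqlen_k))
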
@@ -769,10 +794,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K>(
gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
// flash::cp_async_wait<0>();
@@ -800,32 +823,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::cp_async_wait<0>();
__syncthreads();
if constexpr (Append_KV) {
// if (cute::thread0()) { print(tKgK); }
// if (cute::thread0()) { print(tKsK); }
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
flash::copy_w_min_idx<Is_even_K>(
tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
}
// __syncthreads();
// if (cute::thread0()) { print(tKgK); }
// __syncthreads();
}
// Advance gV
if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
cute::cp_async_fence();
@@ -856,26 +861,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
// __syncthreads();
// if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); }
if constexpr (Append_KV) {
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
flash::copy_w_min_idx<Is_even_K>(
tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
}
}
if (n_block > n_block_min) {
// Advance gK
// if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); }
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
// if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); }
flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
binfo.seqlen_k_cache - (n_block - 1) * kBlockN
);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
@@ -909,20 +898,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
if constexpr (Append_KV) {
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
flash::copy_w_min_idx<Is_even_K>(
tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
}
}
// Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();
flash::gemm(
@@ -932,22 +910,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::cp_async_wait<0>();
__syncthreads();
if constexpr (Append_KV) {
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
flash::copy_w_min_idx<Is_even_K>(
tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
}
}
if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
binfo.seqlen_k_cache - (n_block - 1) * kBlockN
);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
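
The net effect across this file: the old path had to keep the gKnew/gVnew partitions alive through the whole main loop, pick a source per row with copy_2_sources, and write the freshly loaded rows back out to the cache after each cp_async_wait; with KV appended first, every load in the loop is a plain single-source copy from the cache. Schematically, per row (a sketch with illustrative helper names, not the kernel's actual copy machinery):

    # Per-row K source selection, schematically.
    def load_k_old(row, k_cache, k_new, seqlen_k_cache):
        # Old path: pick cache or new tensor on the fly (copy_2_sources).
        if row < seqlen_k_cache:
            return k_cache[row]
        return k_new[row - seqlen_k_cache]

    def load_k_new_path(row, k_cache):
        # New path: knew was appended into the cache in the prologue,
        # so the main loop only ever reads the cache (plain flash::copy).
        return k_cache[row]
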

@@ -149,8 +149,9 @@ def generate_qkv(
)
def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None,
device=None):
def construct_causal_mask(
seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None
):
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
sk = (
@@ -364,12 +365,18 @@ def convert_flash_attn_S_to_softmax(
causal_mask = construct_causal_mask(
seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device
)
causal_mask = F.pad(causal_mask, (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), value=True)
causal_mask = F.pad(
causal_mask,
(0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
value=True,
)
S_converted.masked_fill_(causal_mask, 0.0)
# Need to zero out things not in attention_mask in case S was initialized with random values
# and some of those values aren't overwritten.
seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
seqlen_q_og = (
query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
)
if query_padding_mask is not None:
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
@@ -623,7 +630,14 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
out = output_pad_fn(out_unpad)
if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask, seqlen, seqlen, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
S_dmask,
seqlen,
seqlen,
key_padding_mask,
key_padding_mask,
d,
dropout_p > 0.0,
causal=causal,
)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
@@ -996,7 +1010,14 @@ def test_flash_attn_varlen_output(
out = output_pad_fn(out_unpad)
if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
S_dmask,
seqlen_q,
seqlen_k,
query_padding_mask,
key_padding_mask,
d,
dropout_p > 0.0,
causal=causal,
)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
@@ -1466,16 +1487,18 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [0])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
# @pytest.mark.parametrize("new_kv", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
@@ -1499,7 +1522,9 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num_splits, dtype):
def test_flash_attn_kvcache(
seqlen_q, seqlen_k, d, seqlen_new_eq_seqlen_q, causal, new_kv, mha_type, num_splits, dtype
):
if seqlen_q > seqlen_k and new_kv:
pytest.skip()
device = "cuda"
@@ -1510,14 +1535,21 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
if new_kv:
k = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
else:
k, v = None, None
k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
cache_seqlens = torch.randint(0, (seqlen_k - seqlen_q + 1) if new_kv else (seqlen_k + 1), (batch_size, ), dtype=torch.int32, device=device)
cache_seqlens = torch.randint(
0,
(seqlen_k - seqlen_new + 1) if new_kv else (seqlen_k + 1),
(batch_size,),
dtype=torch.int32,
device=device,
)
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
# k_cache[:, 64:] = -1
k_cache_ref = k_cache.clone()
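
The new upper bound on cache_seqlens keeps the append in range: torch.randint's high end is exclusive, so with new_kv the largest draw is seqlen_k - seqlen_new, and cache_seqlens + seqlen_new never exceeds seqlen_k. A quick sanity check of that bound (hypothetical sizes):

    import torch

    seqlen_k, seqlen_new, batch_size = 128, 17, 9
    cache_seqlens = torch.randint(0, seqlen_k - seqlen_new + 1, (batch_size,))
    # randint's high end is exclusive, so the largest draw is
    # seqlen_k - seqlen_new and the append never overruns the cache.
    assert int((cache_seqlens + seqlen_new).max()) <= seqlen_k
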
@@ -1525,12 +1557,16 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
if new_kv:
update_mask = torch.logical_and(cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_q)
update_mask = torch.logical_and(
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
)
k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits)
out = flash_attn_with_kvcache(
q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits
)
# out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
# qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
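
update_mask picks out, per batch element, the seqlen_new cache slots starting at that element's fill level, and the rearranged k/v are scattered into exactly those slots; this is the reference for the kernel's in-place append. A tiny worked example with hypothetical sizes:

    import torch
    from einops import rearrange

    seqlen_k, seqlen_new = 5, 2
    cache_seqlens = torch.tensor([2, 0])   # per-batch fill levels
    arange = rearrange(torch.arange(seqlen_k), "s -> 1 s")
    cache_exp = rearrange(cache_seqlens, "b -> b 1")
    update_mask = (cache_exp <= arange) & (arange < cache_exp + seqlen_new)
    # Batch 0 appends into slots 2..3, batch 1 into slots 0..1.
    assert update_mask.tolist() == [
        [False, False, True, True, False],
        [True, True, False, False, False],
    ]
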
@@ -1539,10 +1575,22 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
# o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
# probs = torch.softmax(qk, dim=-1)
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_q if new_kv else 0)
out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal)
out_pt, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal,
upcast=False, reorder_ops=True)
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
out_ref, _ = attention_ref(
q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
)
out_pt, _ = attention_ref(
q,
k_cache_rep,
v_cache_rep,
None,
key_padding_mask,
0.0,
None,
causal=causal,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
@@ -1583,7 +1631,7 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
(1024, 1024),
],
)
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
device = "cuda"