Simplify the implementation of KVcache attn by appending KV first
This commit is contained in:
parent
d0032700d1
commit
56b7fc6ee0
@ -657,10 +657,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
|
||||
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
|
||||
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
|
||||
const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
|
||||
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
|
||||
const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
|
||||
+ ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
|
||||
|
||||
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
@ -672,18 +668,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.v_row_stride, _1{}));
|
||||
// Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
|
||||
// e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
|
||||
// This maps to accessing the first 64 rows of knew_ptr.
|
||||
Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
|
||||
+ row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.knew_row_stride, _1{}));
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
|
||||
Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
|
||||
+ row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.vnew_row_stride, _1{}));
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
|
||||
typename Kernel_traits::SmemLayoutQ{});
|
||||
@ -698,10 +682,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
|
||||
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
|
||||
Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
|
||||
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
|
||||
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
|
||||
Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
|
||||
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
|
||||
|
||||
typename Kernel_traits::TiledMma tiled_mma;
|
||||
@ -762,6 +744,49 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
|
||||
// Prologue
|
||||
|
||||
if constexpr (Append_KV) {
|
||||
// Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
|
||||
// gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
|
||||
// We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
|
||||
const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
|
||||
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
|
||||
const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
|
||||
+ ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
|
||||
// Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
|
||||
// e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
|
||||
// This maps to accessing the first 64 rows of knew_ptr.
|
||||
Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
|
||||
+ row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.knew_row_stride, _1{}));
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
|
||||
Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
|
||||
+ row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.vnew_row_stride, _1{}));
|
||||
Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
|
||||
Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
|
||||
|
||||
const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
|
||||
for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
|
||||
flash::copy_w_min_idx<Is_even_K>(
|
||||
tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
|
||||
);
|
||||
flash::copy_w_min_idx<Is_even_K>(
|
||||
tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
|
||||
);
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
|
||||
}
|
||||
__syncthreads();
|
||||
if (n_block_max > n_block_copy_min) {
|
||||
tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;
|
||||
tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor tQrQ = make_fragment_like(tQgQ);
|
||||
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
||||
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
@ -769,10 +794,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
|
||||
int n_block = n_block_max - 1;
|
||||
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
|
||||
flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K>(
|
||||
gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV,
|
||||
binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
|
||||
);
|
||||
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
|
||||
binfo.actual_seqlen_k - n_block * kBlockN);
|
||||
cute::cp_async_fence();
|
||||
|
||||
// flash::cp_async_wait<0>();
|
||||
@ -800,32 +823,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
if constexpr (Append_KV) {
|
||||
// if (cute::thread0()) { print(tKgK); }
|
||||
// if (cute::thread0()) { print(tKsK); }
|
||||
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
|
||||
if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
|
||||
flash::copy_w_min_idx<Is_even_K>(
|
||||
tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
|
||||
);
|
||||
}
|
||||
// __syncthreads();
|
||||
// if (cute::thread0()) { print(tKgK); }
|
||||
// __syncthreads();
|
||||
}
|
||||
|
||||
// Advance gV
|
||||
if (masking_step > 0) {
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
|
||||
flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
|
||||
gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
|
||||
);
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
} else {
|
||||
// Clear the smem tiles to account for predicated off loads
|
||||
flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV,
|
||||
binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
|
||||
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
}
|
||||
cute::cp_async_fence();
|
||||
@ -856,26 +861,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
// if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
|
||||
// __syncthreads();
|
||||
|
||||
// if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); }
|
||||
if constexpr (Append_KV) {
|
||||
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
|
||||
if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
|
||||
flash::copy_w_min_idx<Is_even_K>(
|
||||
tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (n_block > n_block_min) {
|
||||
// Advance gK
|
||||
// if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); }
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
|
||||
// if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); }
|
||||
flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
|
||||
gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
|
||||
binfo.seqlen_k_cache - (n_block - 1) * kBlockN
|
||||
);
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||
// isn't right and we get race conditions.
|
||||
cute::cp_async_fence();
|
||||
@ -909,20 +898,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
clear(acc_s);
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if constexpr (Append_KV) {
|
||||
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
|
||||
if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
|
||||
flash::copy_w_min_idx<Is_even_K>(
|
||||
tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
|
||||
);
|
||||
}
|
||||
}
|
||||
// Advance gV
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
|
||||
flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
|
||||
gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
|
||||
);
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
cute::cp_async_fence();
|
||||
|
||||
flash::gemm(
|
||||
@ -932,22 +910,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if constexpr (Append_KV) {
|
||||
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
|
||||
if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
|
||||
flash::copy_w_min_idx<Is_even_K>(
|
||||
tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
|
||||
);
|
||||
}
|
||||
}
|
||||
if (n_block > n_block_min) {
|
||||
// Advance gK
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
|
||||
flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
|
||||
gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
|
||||
binfo.seqlen_k_cache - (n_block - 1) * kBlockN
|
||||
);
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||
// isn't right and we get race conditions.
|
||||
cute::cp_async_fence();
|
||||
|
||||
@ -149,8 +149,9 @@ def generate_qkv(
|
||||
)
|
||||
|
||||
|
||||
def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None,
|
||||
device=None):
|
||||
def construct_causal_mask(
|
||||
seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None
|
||||
):
|
||||
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
|
||||
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
|
||||
sk = (
|
||||
@ -364,12 +365,18 @@ def convert_flash_attn_S_to_softmax(
|
||||
causal_mask = construct_causal_mask(
|
||||
seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device
|
||||
)
|
||||
causal_mask = F.pad(causal_mask, (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), value=True)
|
||||
causal_mask = F.pad(
|
||||
causal_mask,
|
||||
(0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
|
||||
value=True,
|
||||
)
|
||||
S_converted.masked_fill_(causal_mask, 0.0)
|
||||
|
||||
# Need to zero out things not in attention_mask in case S was initialized with random values
|
||||
# and some of those values aren't overwritten.
|
||||
seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
|
||||
seqlen_q_og = (
|
||||
query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
|
||||
)
|
||||
if query_padding_mask is not None:
|
||||
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
|
||||
S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
|
||||
@ -623,7 +630,14 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
|
||||
out = output_pad_fn(out_unpad)
|
||||
if dropout_p > 0.0:
|
||||
S_dmask_converted = convert_flash_attn_S_to_softmax(
|
||||
S_dmask, seqlen, seqlen, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
|
||||
S_dmask,
|
||||
seqlen,
|
||||
seqlen,
|
||||
key_padding_mask,
|
||||
key_padding_mask,
|
||||
d,
|
||||
dropout_p > 0.0,
|
||||
causal=causal,
|
||||
)
|
||||
dropout_mask = S_dmask_converted >= 0
|
||||
attn_unnorm = S_dmask_converted.abs()
|
||||
@ -996,7 +1010,14 @@ def test_flash_attn_varlen_output(
|
||||
out = output_pad_fn(out_unpad)
|
||||
if dropout_p > 0.0:
|
||||
S_dmask_converted = convert_flash_attn_S_to_softmax(
|
||||
S_dmask, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
|
||||
S_dmask,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
query_padding_mask,
|
||||
key_padding_mask,
|
||||
d,
|
||||
dropout_p > 0.0,
|
||||
causal=causal,
|
||||
)
|
||||
dropout_mask = S_dmask_converted >= 0
|
||||
attn_unnorm = S_dmask_converted.abs()
|
||||
@ -1466,16 +1487,18 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
|
||||
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
|
||||
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
||||
# @pytest.mark.parametrize("dtype", [torch.float16])
|
||||
@pytest.mark.parametrize("num_splits", [1, 0])
|
||||
# @pytest.mark.parametrize("num_splits", [0])
|
||||
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||
# @pytest.mark.parametrize("mha_type", ["mqa"])
|
||||
# @pytest.mark.parametrize("mha_type", ["mha"])
|
||||
@pytest.mark.parametrize("new_kv", [False, True])
|
||||
# @pytest.mark.parametrize("new_kv", [False])
|
||||
# @pytest.mark.parametrize("new_kv", [True])
|
||||
@pytest.mark.parametrize("causal", [False, True])
|
||||
# @pytest.mark.parametrize("causal", [True])
|
||||
# @pytest.mark.parametrize("causal", [False])
|
||||
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
|
||||
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
|
||||
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
|
||||
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
|
||||
@ -1499,7 +1522,9 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
|
||||
],
|
||||
)
|
||||
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
|
||||
def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num_splits, dtype):
|
||||
def test_flash_attn_kvcache(
|
||||
seqlen_q, seqlen_k, d, seqlen_new_eq_seqlen_q, causal, new_kv, mha_type, num_splits, dtype
|
||||
):
|
||||
if seqlen_q > seqlen_k and new_kv:
|
||||
pytest.skip()
|
||||
device = "cuda"
|
||||
@ -1510,14 +1535,21 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
|
||||
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
|
||||
assert nheads % nheads_k == 0
|
||||
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
|
||||
seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
|
||||
if new_kv:
|
||||
k = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
|
||||
v = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
|
||||
k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
|
||||
v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
|
||||
else:
|
||||
k, v = None, None
|
||||
k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
|
||||
v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
|
||||
cache_seqlens = torch.randint(0, (seqlen_k - seqlen_q + 1) if new_kv else (seqlen_k + 1), (batch_size, ), dtype=torch.int32, device=device)
|
||||
cache_seqlens = torch.randint(
|
||||
0,
|
||||
(seqlen_k - seqlen_new + 1) if new_kv else (seqlen_k + 1),
|
||||
(batch_size,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
|
||||
# k_cache[:, 64:] = -1
|
||||
k_cache_ref = k_cache.clone()
|
||||
@ -1525,12 +1557,16 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
|
||||
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
|
||||
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
|
||||
if new_kv:
|
||||
update_mask = torch.logical_and(cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_q)
|
||||
update_mask = torch.logical_and(
|
||||
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
|
||||
)
|
||||
k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
|
||||
v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
|
||||
k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
|
||||
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
|
||||
out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits)
|
||||
out = flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits
|
||||
)
|
||||
# out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
|
||||
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
|
||||
# qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
|
||||
@ -1539,10 +1575,22 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
|
||||
# o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
|
||||
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
|
||||
# probs = torch.softmax(qk, dim=-1)
|
||||
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_q if new_kv else 0)
|
||||
out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal)
|
||||
out_pt, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal,
|
||||
upcast=False, reorder_ops=True)
|
||||
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
|
||||
out_ref, _ = attention_ref(
|
||||
q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
|
||||
)
|
||||
out_pt, _ = attention_ref(
|
||||
q,
|
||||
k_cache_rep,
|
||||
v_cache_rep,
|
||||
None,
|
||||
key_padding_mask,
|
||||
0.0,
|
||||
None,
|
||||
causal=causal,
|
||||
upcast=False,
|
||||
reorder_ops=True,
|
||||
)
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||
@ -1583,7 +1631,7 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
|
||||
(1024, 1024),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
|
||||
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
|
||||
# @pytest.mark.parametrize("dropout_p", [0.0])
|
||||
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
|
||||
device = "cuda"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user