From ccbb14f38ee1e51a6f65bccbef9e4765acba6d79 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 16 Sep 2023 01:20:16 -0700 Subject: [PATCH] Implement rotary embedding in flash_attn_with_kvcache --- csrc/flash_attn/flash_api.cpp | 100 ++++++++------ csrc/flash_attn/src/flash.h | 8 +- csrc/flash_attn/src/flash_fwd_kernel.h | 98 ++++++++++++-- csrc/flash_attn/src/kernel_traits.h | 15 ++- csrc/flash_attn/src/utils.h | 172 +++++++++++++++++++------ csrc/ft_attention/ft_attention.cpp | 2 +- flash_attn/flash_attn_interface.py | 34 ++++- tests/models/test_gpt.py | 2 +- tests/test_flash_attn.py | 79 ++++++++++-- 9 files changed, 404 insertions(+), 106 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 3051eab..140ba02 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -13,7 +13,9 @@ #include "flash.h" #include "static_switch.h" +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") void set_params_fprop(Flash_fwd_params ¶ms, @@ -260,9 +262,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); @@ -299,7 +299,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head if (out_.has_value()) { out = out_.value(); TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); - TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device"); + CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } @@ -426,17 +426,15 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); - TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device"); - TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device"); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous"); - 
TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); const auto sizes = q.sizes(); @@ -471,7 +469,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q if (out_.has_value()) { out = out_.value(); TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); - TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device"); + CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, total_q, num_heads, head_size_og); if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } @@ -610,12 +608,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); - TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device"); - TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device"); - TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device"); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); @@ -657,7 +651,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si if (dq_.has_value()) { dq = dq_.value(); TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); - TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device"); + CHECK_DEVICE(dq); TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); } else { @@ -666,7 +660,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si if (dk_.has_value()) { dk = dk_.value(); TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); - TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device"); + CHECK_DEVICE(dk); TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); } else { @@ -675,7 +669,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si if (dv_.has_value()) { dv = dv_.value(); TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); - TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device"); + CHECK_DEVICE(dv); TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); } else { @@ -820,22 +814,17 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); - TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device"); - TORCH_CHECK(dout.is_cuda(), 
"dout tensor must be on CUDA device"); - TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device"); - TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device"); - TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device"); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); - TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous"); - TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); const auto sizes = q.sizes(); @@ -873,7 +862,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size if (dq_.has_value()) { dq = dq_.value(); TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); - TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device"); + CHECK_DEVICE(dq); TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); CHECK_SHAPE(dq, total_q, num_heads, head_size); } else { @@ -882,7 +871,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size if (dk_.has_value()) { dk = dk_.value(); TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); - TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device"); + CHECK_DEVICE(dk); TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); CHECK_SHAPE(dk, total_k, num_heads_k, head_size); } else { @@ -891,7 +880,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size if (dv_.has_value()) { dv = dv_.value(); TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); - TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device"); + CHECK_DEVICE(dv); TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); CHECK_SHAPE(dv, total_k, num_heads_k, head_size); } else { @@ -1000,9 +989,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he c10::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size c10::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size c10::optional &seqlens_k_, // batch_size + c10::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) + c10::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size const float softmax_scale, bool is_causal, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits ) { @@ -1023,9 +1015,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(kcache.is_cuda(), "Input tensor must be on CUDA device"); - TORCH_CHECK(vcache.is_cuda(), "Input tensor must be on CUDA device"); + CHECK_DEVICE(q); CHECK_DEVICE(kcache); 
CHECK_DEVICE(vcache);
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -1071,7 +1061,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     if (out_.has_value()) {
         out = out_.value();
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
-        TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
+        CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
@@ -1118,8 +1108,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         v = v_.value();
         TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
         TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
-        TORCH_CHECK(k.is_cuda(), "Key tensor must be on CUDA device");
-        TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
+        CHECK_DEVICE(k); CHECK_DEVICE(v);
         TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
         TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
         int seqlen_knew = k.size(1);
@@ -1147,13 +1136,40 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     if (seqlens_k_.has_value()) {
         auto seqlens_k = seqlens_k_.value();
         TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
-        TORCH_CHECK(seqlens_k.is_cuda(), "seqlens_k must be on CUDA device");
-        TORCH_CHECK(seqlens_k.is_contiguous(), "seqlens_k must be contiguous");
+        CHECK_DEVICE(seqlens_k);
+        CHECK_CONTIGUOUS(seqlens_k);
         CHECK_SHAPE(seqlens_k, batch_size);
         params.cu_seqlens_k = static_cast(seqlens_k.data_ptr());
     }
     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
+    if (rotary_cos_.has_value()) {
+        TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
+        auto rotary_cos = rotary_cos_.value();
+        CHECK_DEVICE(rotary_cos);
+        params.rotary_dim = rotary_cos.size(1) * 2;
+        TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
+        TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
+        const int seqlen_ro = rotary_cos.size(0);
+        TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
+        CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
+        CHECK_CONTIGUOUS(rotary_cos);
+        TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
+
+        TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
+        auto rotary_sin = rotary_sin_.value();
+        CHECK_DEVICE(rotary_sin);
+        CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
+        CHECK_CONTIGUOUS(rotary_sin);
+        TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_sin must have the same dtype as query");
+        params.rotary_cos_ptr = rotary_cos.data_ptr();
+        params.rotary_sin_ptr = rotary_sin.data_ptr();
+        params.is_rotary_interleaved = is_rotary_interleaved;
+    } else {
+        params.rotary_dim = 0;
+    }
+
+    // This needs to match with run_mha_fwd_splitkv_dispatch
     const int block_n = is_sm90 || is_sm8x ? (head_size <= 64 ? 256 : (head_size <= 160 ?
128 : 64)) diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 0dbf928..e04507f 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -67,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params { void * __restrict__ softmax_lseaccum_ptr; // The dimensions. - int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; // The scaling factors for the kernel. float scale_softmax; @@ -91,6 +91,10 @@ struct Flash_fwd_params : public Qkv_params { index_t knew_head_stride; index_t vnew_head_stride; + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + // The dropout probability (probability of keeping an activation). float p_dropout; // uint32_t p_dropout_in_uint; @@ -114,6 +118,8 @@ struct Flash_fwd_params : public Qkv_params { // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. bool is_seqlens_k_cumulative; + bool is_rotary_interleaved; + int num_splits; // For split-KV version }; diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 773b1c3..d3736be 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -744,10 +744,36 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Prologue + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); if constexpr (Append_KV) { // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. 
+ const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) @@ -769,17 +795,39 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { - flash::copy_w_min_idx( - tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN - ); flash::copy_w_min_idx( tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN ); - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 
2));
+
+                }
+            }
+            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+            tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
         }
+        // Need this before we can read in K again, so that we'll see the updated K values.
         __syncthreads();
         if (n_block_max > n_block_copy_min) {
             tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;
@@ -787,10 +835,44 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         }
     }

+    // Read Q from gmem to smem, optionally apply rotary embedding.
     Tensor tQrQ = make_fragment_like(tQgQ);
-    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-    flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
-                binfo.actual_seqlen_q - m_block * kBlockM);
+    if (!Append_KV || params.rotary_dim == 0) {
+        // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+        flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+                    binfo.actual_seqlen_q - m_block * kBlockM);
+    } else {
+        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+        // If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
+        // We do this by setting the row stride of gCos / gSin to 0.
+        Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
+                                  Shape, Int>{},
+                                  make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+        Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
+                                  Shape, Int>{},
+                                  make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+        Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
+                                      Shape, Int>{},
+                                      make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+        Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
+                                      Shape, Int>{},
+                                      make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
+        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
+        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+        if (params.is_rotary_interleaved) {
+            flash::copy_rotary_interleaved(
+                tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
+                0, params.d, params.rotary_dim
+            );
+        } else {
+            flash::copy_rotary_contiguous(
+                tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
+                0, params.d, params.rotary_dim
+            );
+        }
+    }

     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
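For orientation before the kernel-traits and utils changes below: the kernel rotates only the first rotary_dim components of each new K row and each Q row, indexing the cos/sin tables by absolute position (cache_seqlens + i for new K rows, cache_seqlens + i for Q when causal, and cache_seqlens for every Q row otherwise, via the zero row stride above). A minimal PyTorch sketch of the per-row math implemented by copy_rotary_interleaved / copy_rotary_contiguous is given here; it is not part of the patch, and rotary_ref and its arguments are illustrative names only.

```python
import torch

def rotary_ref(x, cos, sin, positions, rotary_dim, interleaved):
    """Reference rotation for x of shape (seqlen, headdim).

    cos/sin: (seqlen_ro, rotary_dim // 2); positions: (seqlen,) absolute token indices.
    Dimensions >= rotary_dim are passed through unchanged, matching the kernel.
    """
    out = x.clone()
    c = cos[positions].to(torch.float32)  # (seqlen, rotary_dim // 2)
    s = sin[positions].to(torch.float32)
    xr = x[..., :rotary_dim].to(torch.float32)
    if interleaved:
        # Pairs (0, 1), (2, 3), ...: (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos)
        x0, x1 = xr[..., 0::2], xr[..., 1::2]
        out[..., 0:rotary_dim:2] = (x0 * c - x1 * s).to(x.dtype)
        out[..., 1:rotary_dim:2] = (x0 * s + x1 * c).to(x.dtype)
    else:
        # GPT-NeoX style: pairs (i, i + rotary_dim // 2), same rotation per pair
        x0, x1 = xr[..., : rotary_dim // 2], xr[..., rotary_dim // 2 :]
        out[..., : rotary_dim // 2] = (x0 * c - x1 * s).to(x.dtype)
        out[..., rotary_dim // 2 : rotary_dim] = (x1 * c + x0 * s).to(x.dtype)
    return out
```

This mirrors what the updated test compares against via flash_attn.layers.rotary.apply_rotary_emb with interleaved=True/False.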
diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h index e7d605f..f000ff2 100644 --- a/csrc/flash_attn/src/kernel_traits.h +++ b/csrc/flash_attn/src/kernel_traits.h @@ -142,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base { DefaultCopy >; using GmemTiledCopyQKV = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; @@ -155,7 +155,7 @@ struct Flash_fwd_kernel_traits : public Base { Stride, _1>>; using GmemTiledCopyP = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom{}, GmemLayoutAtomP{}, Layout>{})); // Val layout, 8 vals per store @@ -170,6 +170,15 @@ struct Flash_fwd_kernel_traits : public Base { make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 492090d..edf6a60 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -355,43 +355,6 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor const //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor const &S0, - Tensor const &S1, - Tensor &D, Tensor const &identity_MN, - Tensor const &predicate_K, - const int max_MN=0, const int row_idx_switch=0) { - CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K - // There's no case where !Clear_OOB_K && Clear_OOB_MN - static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); - // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); } - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); } - #pragma unroll - for (int m = 0; m < size<1>(S0); ++m) { - auto &S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? 
S0 : S1; - if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S0); ++k) { - if (Is_even_K || predicate_K(k)) { - cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, _)); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template @@ -422,4 +385,137 @@ inline __device__ void copy_w_min_idx(Tensor const &S, //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace flash \ No newline at end of file +template +inline __device__ void copy_rotary_interleaved(Tensor const &S, + Tensor &D, + Tensor const &Cos, + Tensor const &Sin, + Tensor const &identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_rotary_contiguous(Tensor const &S, + Tensor &D, + Tensor const &Cos, + Tensor const &Sin, + Tensor const &identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + 
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? 
-sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/csrc/ft_attention/ft_attention.cpp b/csrc/ft_attention/ft_attention.cpp index 38f5a1c..b307cff 100644 --- a/csrc/ft_attention/ft_attention.cpp +++ b/csrc/ft_attention/ft_attention.cpp @@ -175,7 +175,7 @@ torch::Tensor single_query_attention(const torch::Tensor q, TORCH_CHECK(rotary_sin_.has_value()); auto rotary_sin = rotary_sin_.value(); CHECK_DEVICE(rotary_sin); - CHECK_SHAPE(rotary_cos, batch_size, rotary_embedding_dim / 2); + CHECK_SHAPE(rotary_sin, batch_size, rotary_embedding_dim / 2); CHECK_CONTIGUOUS(rotary_sin); TORCH_CHECK(rotary_sin.scalar_type() == input_type); } diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index d3c1305..1fe3b89 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -800,9 +800,12 @@ def flash_attn_with_kvcache( v_cache, k=None, v=None, + rotary_cos=None, + rotary_sin=None, cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, softmax_scale=None, causal=False, + rotary_interleaved=True, num_splits=0, ): """ @@ -815,7 +818,13 @@ def flash_attn_with_kvcache( For example, the KV cache could be pre-allocated with the max sequence length, and you can use cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. - Does not support backward pass. + Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be rotated + by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If causal, the query @q will be rotated by rotary_cos and rotary_sin at indices cache_seqlens, + cache_seqlens + 1, etc. If not causal, the query @q will be rotated by rotary_cos and rotary_sin + at indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). + + See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. @@ -834,6 +843,8 @@ def flash_attn_with_kvcache( 1 1 If the row of the mask is all zero, the output will be zero. + Note: Does not support backward pass. + Arguments: q: (batch_size, seqlen, nheads, headdim) k_cache: (batch_size, seqlen_cache, nheads_k, headdim) @@ -841,11 +852,18 @@ def flash_attn_with_kvcache( k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate k with k_cache, starting at the indices specified by cache_seqlens. v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. + rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding + to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. + rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. cache_seqlens: int, or (batch_size,), dtype torch.int32. 
The sequence lengths of the KV cache. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. + If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). num_splits: int. If > 1, split the key/value into this many chunks along the sequence. If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic to automatically determine the number of splits. @@ -865,6 +883,18 @@ def flash_attn_with_kvcache( (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) out, softmax_lse = flash_attn_cuda.fwd_kvcache( - q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits + q, + k_cache, + v_cache, + k, + v, + cache_seqlens, + rotary_cos, + rotary_sin, + None, + softmax_scale, + causal, + rotary_interleaved, + num_splits, ) return out diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py index e72a9b9..2aa85e3 100644 --- a/tests/models/test_gpt.py +++ b/tests/models/test_gpt.py @@ -280,7 +280,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs): @pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)]) # @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)]) -@pytest.mark.parametrize("rotary", [None, "interleaved", "block"]) +@pytest.mark.parametrize("rotary", [None, "interleaved", "contiguous"]) # @pytest.mark.parametrize('rotary', [None]) @pytest.mark.parametrize("fused_ft_kernel", [False, True]) # @pytest.mark.parametrize("fused_ft_kernel", [False]) diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 6b60ac9..e4c9843 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -15,6 +15,7 @@ from flash_attn import ( ) from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import _get_block_size +from flash_attn.layers.rotary import apply_rotary_emb MAX_HEADDIM_SM8x = 192 @@ -1497,12 +1498,16 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): @pytest.mark.parametrize("new_kv", [False, True]) # @pytest.mark.parametrize("new_kv", [True]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +@pytest.mark.parametrize("rotary_interleaved", [False, True]) +# @pytest.mark.parametrize("rotary_interleaved", [False]) +@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) +# @pytest.mark.parametrize("rotary_fraction", [1.0]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) -# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( @@ -1523,15 +1528,29 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) def 
test_flash_attn_kvcache( - seqlen_q, seqlen_k, d, seqlen_new_eq_seqlen_q, causal, new_kv, mha_type, num_splits, dtype + seqlen_q, + seqlen_k, + d, + rotary_fraction, + rotary_interleaved, + seqlen_new_eq_seqlen_q, + causal, + new_kv, + mha_type, + num_splits, + dtype, ): if seqlen_q > seqlen_k and new_kv: pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() device = "cuda" # set seed torch.random.manual_seed(0) batch_size = 2 nheads = 6 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) @@ -1545,12 +1564,42 @@ def test_flash_attn_kvcache( v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype) cache_seqlens = torch.randint( 0, - (seqlen_k - seqlen_new + 1) if new_kv else (seqlen_k + 1), + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + (seqlen_k - (seqlen_q if causal and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1), (batch_size,), dtype=torch.int32, device=device, ) # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + if rotary_dim > 0: + angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi + cos = torch.cos(angle).to(dtype=dtype) + sin = torch.sin(angle).to(dtype=dtype) + if causal: + q_ro = apply_rotary_emb( + q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=cache_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k # k_cache[:, 64:] = -1 k_cache_ref = k_cache.clone() v_cache_ref = v_cache.clone() @@ -1560,12 +1609,22 @@ def test_flash_attn_kvcache( update_mask = torch.logical_and( cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new ) - k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(v, "b s ... 
-> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) out = flash_attn_with_kvcache( - q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits + q, + k_cache, + v_cache, + k, + v, + cos, + sin, + cache_seqlens, + causal=causal, + rotary_interleaved=rotary_interleaved, + num_splits=num_splits, ) # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal) # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal) @@ -1577,10 +1636,10 @@ def test_flash_attn_kvcache( # probs = torch.softmax(qk, dim=-1) key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0) out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal ) out_pt, _ = attention_ref( - q, + q_ro, k_cache_rep, v_cache_rep, None, @@ -1598,10 +1657,10 @@ def test_flash_attn_kvcache( # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5 if new_kv: - assert torch.equal(k_cache, k_cache_ref) + assert torch.allclose(k_cache, k_cache_ref, rtol=1e-3, atol=1e-3) assert torch.equal(v_cache, v_cache_ref) + assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
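To complement the test above, here is a minimal usage sketch of the extended flash_attn_with_kvcache API for incremental decoding with rotary applied on the fly. It is not part of the patch; batch/head sizes and the theta base are illustrative, the cos/sin tables are built the standard RoPE way, and the package-level import assumes flash_attn exports flash_attn_with_kvcache as the test file does.

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, nheads_k, headdim, rotary_dim = 2, 8, 8, 128, 128
max_seqlen, seqlen_new = 4096, 1
device, dtype = "cuda", torch.float16

q = torch.randn(batch, seqlen_new, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen_new, nheads_k, headdim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen_new, nheads_k, headdim, device=device, dtype=dtype)
k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)
# Current lengths of each sequence already in the cache.
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device=device)

# cos/sin tables of shape (seqlen_ro, rotary_dim / 2); seqlen_ro must cover the
# full cache length, and rotary_dim must be a multiple of 16 and <= headdim.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, device=device).float() / rotary_dim))
t = torch.arange(max_seqlen, device=device, dtype=torch.float32)
angles = torch.outer(t, inv_freq)
rotary_cos, rotary_sin = angles.cos().to(dtype), angles.sin().to(dtype)

# k is rotated at positions cache_seqlens, cache_seqlens + 1, ... and appended
# (with v) into k_cache / v_cache in place; q is rotated per the causal rule
# described in the docstring, then attention is computed against the cache.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k, v=v,
    rotary_cos=rotary_cos, rotary_sin=rotary_sin,
    cache_seqlens=cache_seqlens, causal=True, rotary_interleaved=True,
)
```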