Implement rotary embedding in flash_attn_with_kvcache
parent 5400fdc4ac
commit ccbb14f38e
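This commit threads rotary embedding through the kvcache path: the Python entry point flash_attn_with_kvcache gains rotary_cos / rotary_sin / rotary_interleaved arguments (see the interface hunks further down). A minimal usage sketch, assuming the shapes documented in the new docstring (cos/sin of shape (seqlen_ro, rotary_dim / 2), seqlen_ro at least the KV-cache length, same dtype as q); the tensor values and sizes here are illustrative, not from the diff:

```python
# Hypothetical single decoding step using the updated signature (sketch, not part of the diff).
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, rotary_dim, max_seqlen = 2, 8, 128, 64, 4096
device, dtype = "cuda", torch.float16

q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.tensor([10, 37], dtype=torch.int32, device=device)

# cos/sin tables: (seqlen_ro, rotary_dim / 2), seqlen_ro >= cache length, same dtype as q.
pos = torch.arange(max_seqlen, device=device)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2, device=device) / rotary_dim))
angles = torch.outer(pos, inv_freq)
rotary_cos, rotary_sin = angles.cos().to(dtype), angles.sin().to(dtype)

# The new k/v are rotated at positions cache_seqlens, cache_seqlens + 1, ... and appended in place.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k, v,
    rotary_cos=rotary_cos, rotary_sin=rotary_sin,
    cache_seqlens=cache_seqlens, causal=True, rotary_interleaved=True,
)
```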
@@ -13,7 +13,9 @@
#include "flash.h"
#include "static_switch.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")


void set_params_fprop(Flash_fwd_params &params,
@@ -260,9 +262,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -299,7 +299,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
@@ -426,17 +426,15 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device");
TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device");
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(cu_seqlens_q);
CHECK_DEVICE(cu_seqlens_k);

TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);

const auto sizes = q.sizes();

@@ -471,7 +469,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
@@ -610,12 +608,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device");
TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device");
TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device");
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -657,7 +651,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
if (dq_.has_value()) {
dq = dq_.value();
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device");
CHECK_DEVICE(dq);
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
} else {
@@ -666,7 +660,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
if (dk_.has_value()) {
dk = dk_.value();
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device");
CHECK_DEVICE(dk);
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
} else {
@@ -675,7 +669,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
if (dv_.has_value()) {
dv = dv_.value();
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device");
CHECK_DEVICE(dv);
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
} else {
@@ -820,22 +814,17 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device");
TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device");
TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device");
TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device");
TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device");
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);

TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);

const auto sizes = q.sizes();

@@ -873,7 +862,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
if (dq_.has_value()) {
dq = dq_.value();
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device");
CHECK_DEVICE(dq);
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
CHECK_SHAPE(dq, total_q, num_heads, head_size);
} else {
@@ -882,7 +871,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
if (dk_.has_value()) {
dk = dk_.value();
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device");
CHECK_DEVICE(dk);
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
} else {
@@ -891,7 +880,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
if (dv_.has_value()) {
dv = dv_.value();
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device");
CHECK_DEVICE(dv);
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
} else {
@@ -1000,9 +989,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits
) {

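The is_rotary_interleaved flag added to the binding above selects which feature indices get rotated together. A small illustrative sketch of the two pairings for a vector of length rotary_dim (hypothetical helper, not part of the diff):

```python
# Which index pairs are rotated together, per the comment on is_rotary_interleaved above.
def rotary_pairs(rotary_dim: int, interleaved: bool):
    """Return the (i, j) index pairs that are rotated as one 2-D point."""
    if interleaved:
        # combine indices 0 & 1, 2 & 3, ...
        return [(2 * i, 2 * i + 1) for i in range(rotary_dim // 2)]
    # combine indices 0 & rotary_dim/2, 1 & rotary_dim/2 + 1, ... (GPT-NeoX style)
    half = rotary_dim // 2
    return [(i, i + half) for i in range(half)]

print(rotary_pairs(8, True))   # [(0, 1), (2, 3), (4, 5), (6, 7)]
print(rotary_pairs(8, False))  # [(0, 4), (1, 5), (2, 6), (3, 7)]
```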
@@ -1023,9 +1015,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");

TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(kcache.is_cuda(), "Input tensor must be on CUDA device");
TORCH_CHECK(vcache.is_cuda(), "Input tensor must be on CUDA device");
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);

TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -1071,7 +1061,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
@@ -1118,8 +1108,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
v = v_.value();
TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
TORCH_CHECK(k.is_cuda(), "Key tensor must be on CUDA device");
TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
CHECK_DEVICE(k); CHECK_DEVICE(v);
TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
int seqlen_knew = k.size(1);
@@ -1147,13 +1136,40 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
if (seqlens_k_.has_value()) {
auto seqlens_k = seqlens_k_.value();
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
TORCH_CHECK(seqlens_k.is_cuda(), "seqlens_k must be on CUDA device");
TORCH_CHECK(seqlens_k.is_contiguous(), "seqlens_k must be contiguous");
CHECK_DEVICE(seqlens_k);
CHECK_CONTIGUOUS(seqlens_k);
CHECK_SHAPE(seqlens_k, batch_size);
params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
}
params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());

if (rotary_cos_.has_value()) {
TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
auto rotary_cos = rotary_cos_.value();
CHECK_DEVICE(rotary_cos);
params.rotary_dim = rotary_cos.size(1) * 2;
TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
const int seqlen_ro = rotary_cos.size(0);
TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
CHECK_CONTIGUOUS(rotary_cos);
TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");

TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
auto rotary_sin = rotary_sin_.value();
CHECK_DEVICE(rotary_sin);
CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
CHECK_CONTIGUOUS(rotary_sin);
TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
params.rotary_cos_ptr = rotary_cos.data_ptr();
params.rotary_sin_ptr = rotary_sin.data_ptr();
params.is_rotary_interleaved = is_rotary_interleaved;
} else {
params.rotary_dim = 0;
}


// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = is_sm90 || is_sm8x
? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))

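The checks above pin down the expected host-side layout for the tables: shape (seqlen_ro, rotary_dim / 2), rotary_dim a multiple of 16 and at most headdim, seqlen_ro at least the KV-cache length, contiguous, and in the query dtype. A hedged sketch of tables that satisfy them (sizes and the random angles are illustrative, not from the diff):

```python
# Sketch of cos/sin tables matching the validation above (illustrative values).
import math
import torch

head_size, rotary_dim, seqlen_k_cache = 128, 64, 4096  # rotary_dim % 16 == 0 and <= head_size
seqlen_ro = seqlen_k_cache                              # must cover every cached position
angle = torch.rand(seqlen_ro, rotary_dim // 2, device="cuda") * 2 * math.pi
rotary_cos = torch.cos(angle).to(torch.float16)         # same dtype as q
rotary_sin = torch.sin(angle).to(torch.float16)
assert rotary_cos.shape == (seqlen_ro, rotary_dim // 2) and rotary_cos.is_contiguous()
```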
@@ -67,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ softmax_lseaccum_ptr;

// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;

// The scaling factors for the kernel.
float scale_softmax;
@@ -91,6 +91,10 @@ struct Flash_fwd_params : public Qkv_params {
index_t knew_head_stride;
index_t vnew_head_stride;

// The cos and sin matrices for rotary embedding.
void * __restrict__ rotary_cos_ptr;
void * __restrict__ rotary_sin_ptr;

// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
@@ -114,6 +118,8 @@ struct Flash_fwd_params : public Qkv_params {
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative;

bool is_rotary_interleaved;

int num_splits; // For split-KV version
};


@@ -744,10 +744,36 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

// Prologue

// Copy from Knew to K, optionally apply rotary embedding.
typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
if constexpr (Append_KV) {
// Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
// gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
// We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
// if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
// if (cute::thread(8, 0)) { print_tensor(gCos); }
// if (cute::thread(0, 0)) { print_tensor(tRgCos); }

const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+ ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
@@ -769,17 +795,39 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
flash::copy_w_min_idx<Is_even_K>(
tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
flash::copy_w_min_idx<Is_even_K>(
tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
if (params.rotary_dim == 0) {
flash::copy_w_min_idx<Is_even_K>(
tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
);
} else {
if (params.is_rotary_interleaved) {
// Don't clear OOB_K because we're writing to global memory
flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
);
tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
} else {
// Don't clear OOB_K because we're writing to global memory
flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
);
tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));

}
}
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
}
// Need this before we can read in K again, so that we'll see the updated K values.
__syncthreads();
if (n_block_max > n_block_copy_min) {
tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;
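The loop above walks the key/value blocks of the new tokens from the back, writing Knew/Vnew into the cache and, when rotary_dim > 0, rotating the new keys with the cos/sin rows for their absolute positions. A tensor-level PyTorch sketch of the intended effect for one batch element (mirroring the reference update in the test further down; apply_rotary is a stand-in for flash_attn.layers.rotary.apply_rotary_emb, and the shapes are assumptions):

```python
# Tensor-level sketch of what the Append_KV path does; not the kernel itself.
def append_kv_reference(k_cache, v_cache, k_new, v_new, cos, sin, cache_seqlen, apply_rotary):
    # k_new, v_new: (seqlen_new, nheads_k, headdim); caches: (max_seqlen, nheads_k, headdim)
    seqlen_new = k_new.shape[0]
    # New keys are rotated at absolute positions cache_seqlen, cache_seqlen + 1, ...
    k_rot = apply_rotary(k_new[None], cos, sin, seqlen_offsets=cache_seqlen)[0]
    k_cache[cache_seqlen:cache_seqlen + seqlen_new] = k_rot
    # Values are copied into the cache unrotated.
    v_cache[cache_seqlen:cache_seqlen + seqlen_new] = v_new
    return k_cache, v_cache
```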
@@ -787,10 +835,44 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
}
}

// Read Q from gmem to smem, optionally apply rotary embedding.
Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
if (!Append_KV || params.rotary_dim == 0) {
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
} else {
const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
// If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
// We do this by setting the row stride of gCos / gSin to 0.
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
if (params.is_rotary_interleaved) {
flash::copy_rotary_interleaved<Is_even_K>(
tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
0, params.d, params.rotary_dim
);
} else {
flash::copy_rotary_contiguous<Is_even_K>(
tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
0, params.d, params.rotary_dim
);
}
}

int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.

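The non-causal path above reuses one cos/sin row for every query by giving gCos/gSin a row stride of 0, so each of the kBlockM rows reads the entry at position seqlen_k_cache. The same trick expressed in PyTorch terms (illustrative sketch only, values assumed):

```python
# Stride-0 broadcast of a single cos row across kBlockM query rows (illustrative).
import torch

rotary_dim, kBlockM, seqlen_k_cache = 64, 4, 10
cos = torch.randn(4096, rotary_dim // 2)            # contiguous (seqlen_ro, rotary_dim / 2) table
row = cos[seqlen_k_cache]                           # the one row every query should see
g_cos = torch.as_strided(cos, size=(kBlockM, rotary_dim // 2),
                         stride=(0, 1),
                         storage_offset=seqlen_k_cache * (rotary_dim // 2))
assert torch.equal(g_cos, row.expand(kBlockM, -1))  # every row aliases the same memory
```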
@@ -142,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base {
DefaultCopy
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
@@ -155,7 +155,7 @@ struct Flash_fwd_kernel_traits : public Base {
Stride<Int<kGmemThreadsPerRowP>, _1>>;

using GmemTiledCopyP = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store

@@ -170,6 +170,15 @@ struct Flash_fwd_kernel_traits : public Base {
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
using GmemTiledCopyRotcossin = decltype(
make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinCont = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
};

// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.

@@ -355,43 +355,6 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_2_sources=false, bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S0,
Tensor<Engine0, Layout0> const &S1,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K,
const int max_MN=0, const int row_idx_switch=0) {
CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); }
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); }
#pragma unroll
for (int m = 0; m < size<1>(S0); ++m) {
auto &S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1;
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S0); ++k) {
if (Is_even_K || predicate_K(k)) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
} else if (Clear_OOB_MN) {
cute::clear(D(_, m, _));
}
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
@@ -422,4 +385,137 @@ inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash
template <bool Is_even_K=true, bool Clear_OOB_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &Cos,
Tensor<Engine2, Layout2> const &Sin,
Tensor<Engine3, Layout3> const &identity_MN,
const int max_MN, const int min_MN,
const int dim, const int rotary_dim) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K
static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
Tensor rCos = make_fragment_like(Cos);
Tensor rSin = make_fragment_like(Sin);
Tensor rS = make_fragment_like(S);
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
cute::copy(S(_, m, k), rS(_, m, k));
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
cute::copy(Cos(_, m, k), rCos(_, m, k));
cute::copy(Sin(_, m, k), rSin(_, m, k));
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
#pragma unroll
for (int i = 0; i < size<0>(rS) / 2; ++i) {
float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
S_fp32(2 * i) = real;
S_fp32(2 * i + 1) = imag;
}
// Idk but I need to copy for the convert_type to work
Tensor S_fp32_copy = make_fragment_like(S_fp32);
cute::copy(S_fp32, S_fp32_copy);
using T = typename Engine0::value_type;
Tensor S_og_type = convert_type<T>(S_fp32_copy);
cute::copy(S_og_type, rS(_, m, k));
}
cute::copy(rS(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
}
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

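For reference, the inner loop of copy_rotary_interleaved above treats consecutive elements (2i, 2i+1) as a complex number and multiplies it by cos + i*sin. The same arithmetic on a whole row in PyTorch (a sketch with assumed 1-D shapes, not the kernel):

```python
# Reference for the interleaved rotation performed above (float32 sketch, single row).
import torch

def rotate_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (rotary_dim,), cos/sin: (rotary_dim // 2,)
    out = x.clone()
    x0, x1 = x[0::2], x[1::2]
    out[0::2] = x0 * cos - x1 * sin   # "real" part, matches the kernel's `real`
    out[1::2] = x0 * sin + x1 * cos   # "imag" part, matches the kernel's `imag`
    return out
```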
template <bool Is_even_K=true, bool Clear_OOB_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &Cos,
Tensor<Engine2, Layout2> const &Sin,
Tensor<Engine3, Layout3> const &identity_MN,
const int max_MN, const int min_MN,
const int dim, const int rotary_dim) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
Tensor rCos = make_fragment_like(Cos);
Tensor rSin = make_fragment_like(Sin);
Tensor rS = make_fragment_like(S);
Tensor rS_other = make_fragment_like(rS(_, 0, 0));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
cute::copy(S(_, m, k), rS(_, m, k));
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
cute::copy(gS_other, rS_other);
// if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
cute::copy(gCos, rCos(_, m, k));
cute::copy(gSin, rSin(_, m, k));
// if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
Tensor S_other_fp32 = convert_type<float>(rS_other);
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
#pragma unroll
for (int i = 0; i < size<0>(rS); ++i) {
S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
}
// Idk but I need to copy for the convert_type to work
Tensor S_fp32_copy = make_fragment_like(S_fp32);
cute::copy(S_fp32, S_fp32_copy);
using T = typename Engine0::value_type;
Tensor S_og_type = convert_type<T>(S_fp32_copy);
cute::copy(S_og_type, rS(_, m, k));
// if (cute::thread0()) { print_tensor(rS(_, m, k)); }
}
cute::copy(rS(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
}
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash

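copy_rotary_contiguous handles the GPT-NeoX layout: an element in the left half is paired with its partner rotary_dim/2 positions to the right and vice versa, and the partner enters with sign -sin for the left half and +sin for the right half (the is_left selection above). An equivalent whole-row sketch in PyTorch (assumed 1-D shapes):

```python
# Reference for the contiguous (non-interleaved) rotation above (float32 sketch, single row).
import torch

def rotate_contiguous(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (rotary_dim,), cos/sin: (rotary_dim // 2,)
    half = x.shape[0] // 2
    x_left, x_right = x[:half], x[half:]
    return torch.cat([x_left * cos - x_right * sin,    # left half: partner at +half, sign -sin
                      x_right * cos + x_left * sin])   # right half: partner at -half, sign +sin
```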
@@ -175,7 +175,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
TORCH_CHECK(rotary_sin_.has_value());
auto rotary_sin = rotary_sin_.value();
CHECK_DEVICE(rotary_sin);
CHECK_SHAPE(rotary_cos, batch_size, rotary_embedding_dim / 2);
CHECK_SHAPE(rotary_sin, batch_size, rotary_embedding_dim / 2);
CHECK_CONTIGUOUS(rotary_sin);
TORCH_CHECK(rotary_sin.scalar_type() == input_type);
}

@@ -800,9 +800,12 @@ def flash_attn_with_kvcache(
v_cache,
k=None,
v=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
softmax_scale=None,
causal=False,
rotary_interleaved=True,
num_splits=0,
):
"""
@@ -815,7 +818,13 @@ def flash_attn_with_kvcache(
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

Does not support backward pass.
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be rotated
by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal, the query @q will be rotated by rotary_cos and rotary_sin at indices cache_seqlens,
cache_seqlens + 1, etc. If not causal, the query @q will be rotated by rotary_cos and rotary_sin
at indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

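A short sketch of the rotation positions the preceding docstring paragraph describes, for a per-batch cache length c and seqlen_new appended tokens (illustrative helper, not part of the diff):

```python
# Positions whose cos/sin rows are used, per the docstring above (illustrative).
def rotary_positions(c: int, seqlen_new: int, seqlen_q: int, causal: bool):
    k_pos = list(range(c, c + seqlen_new))                               # new keys: c, c + 1, ...
    q_pos = list(range(c, c + seqlen_q)) if causal else [c] * seqlen_q   # non-causal: all at c
    return k_pos, q_pos

print(rotary_positions(c=10, seqlen_new=3, seqlen_q=3, causal=False))
# ([10, 11, 12], [10, 10, 10])
```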
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
@@ -834,6 +843,8 @@ def flash_attn_with_kvcache(
1 1
If the row of the mask is all zero, the output will be zero.

Note: Does not support backward pass.

Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
@@ -841,11 +852,18 @@ def flash_attn_with_kvcache(
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style).
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.
@@ -865,6 +883,18 @@ def flash_attn_with_kvcache(
(k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
)
out, softmax_lse = flash_attn_cuda.fwd_kvcache(
q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits
q,
k_cache,
v_cache,
k,
v,
cache_seqlens,
rotary_cos,
rotary_sin,
None,
softmax_scale,
causal,
rotary_interleaved,
num_splits,
)
return out

@@ -280,7 +280,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):

@pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
@pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
@pytest.mark.parametrize("rotary", [None, "interleaved", "contiguous"])
# @pytest.mark.parametrize('rotary', [None])
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
# @pytest.mark.parametrize("fused_ft_kernel", [False])

@@ -15,6 +15,7 @@ from flash_attn import (
)
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size
from flash_attn.layers.rotary import apply_rotary_emb

MAX_HEADDIM_SM8x = 192

@@ -1497,12 +1498,16 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [1.0])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
@@ -1523,15 +1528,29 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
seqlen_q, seqlen_k, d, seqlen_new_eq_seqlen_q, causal, new_kv, mha_type, num_splits, dtype
seqlen_q,
seqlen_k,
d,
rotary_fraction,
rotary_interleaved,
seqlen_new_eq_seqlen_q,
causal,
new_kv,
mha_type,
num_splits,
dtype,
):
if seqlen_q > seqlen_k and new_kv:
pytest.skip()
if not new_kv and rotary_fraction > 0.0:
pytest.skip()
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 2
nheads = 6
# rotary_dim must be a multiple of 16, and must be <= d
rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
@@ -1545,12 +1564,42 @@ def test_flash_attn_kvcache(
v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
cache_seqlens = torch.randint(
0,
(seqlen_k - seqlen_new + 1) if new_kv else (seqlen_k + 1),
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
(seqlen_k - (seqlen_q if causal and rotary_dim > 1 else seqlen_new) + 1)
if new_kv
else (seqlen_k + 1),
(batch_size,),
dtype=torch.int32,
device=device,
)
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
if rotary_dim > 0:
angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if causal:
q_ro = apply_rotary_emb(
q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
)
else:
q_ro = rearrange(
apply_rotary_emb(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
seqlen_offsets=cache_seqlens,
interleaved=rotary_interleaved,
),
"b 1 (s h) d -> b s h d",
s=seqlen_q,
)
# q_ro = q
k_ro = apply_rotary_emb(
k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
)
else:
cos, sin = None, None
q_ro, k_ro = q, k
# k_cache[:, 64:] = -1
k_cache_ref = k_cache.clone()
v_cache_ref = v_cache.clone()
@@ -1560,12 +1609,22 @@ def test_flash_attn_kvcache(
update_mask = torch.logical_and(
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
)
k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
out = flash_attn_with_kvcache(
q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits
q,
k_cache,
v_cache,
k,
v,
cos,
sin,
cache_seqlens,
causal=causal,
rotary_interleaved=rotary_interleaved,
num_splits=num_splits,
)
# out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
@@ -1577,10 +1636,10 @@ def test_flash_attn_kvcache(
# probs = torch.softmax(qk, dim=-1)
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
out_ref, _ = attention_ref(
q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
)
out_pt, _ = attention_ref(
q,
q_ro,
k_cache_rep,
v_cache_rep,
None,
@@ -1598,10 +1657,10 @@ def test_flash_attn_kvcache(

# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
if new_kv:
assert torch.equal(k_cache, k_cache_ref)
assert torch.allclose(k_cache, k_cache_ref, rtol=1e-3, atol=1e-3)
assert torch.equal(v_cache, v_cache_ref)
assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
