From b1fbbd833715505ad2f1fa131ad5e821a8f5dc57 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 29 Aug 2023 00:58:29 -0700 Subject: [PATCH] Implement splitKV attention --- csrc/flash_attn/flash_api.cpp | 67 ++- csrc/flash_attn/src/flash.h | 5 + .../src/flash_bwd_launch_template.h | 12 +- csrc/flash_attn/src/flash_fwd_kernel.h | 569 ++++++++++++++++++ .../src/flash_fwd_launch_template.h | 78 ++- .../src/flash_fwd_split_hdim128_bf16_sm80.cu | 7 + .../src/flash_fwd_split_hdim128_fp16_sm80.cu | 7 + .../src/flash_fwd_split_hdim160_bf16_sm80.cu | 7 + .../src/flash_fwd_split_hdim160_fp16_sm80.cu | 7 + .../src/flash_fwd_split_hdim192_bf16_sm80.cu | 7 + .../src/flash_fwd_split_hdim192_fp16_sm80.cu | 7 + .../src/flash_fwd_split_hdim224_bf16_sm80.cu | 7 + .../src/flash_fwd_split_hdim224_fp16_sm80.cu | 7 + .../src/flash_fwd_split_hdim256_bf16_sm80.cu | 7 + .../src/flash_fwd_split_hdim256_fp16_sm80.cu | 7 + .../src/flash_fwd_split_hdim32_bf16_sm80.cu | 7 + .../src/flash_fwd_split_hdim32_fp16_sm80.cu | 7 + .../src/flash_fwd_split_hdim64_bf16_sm80.cu | 7 + .../src/flash_fwd_split_hdim64_fp16_sm80.cu | 7 + .../src/flash_fwd_split_hdim96_bf16_sm80.cu | 7 + .../src/flash_fwd_split_hdim96_fp16_sm80.cu | 7 + csrc/flash_attn/src/generate_kernels.py | 22 +- csrc/flash_attn/src/kernel_traits.h | 14 +- setup.py | 16 + tests/test_flash_attn.py | 103 ++++ 25 files changed, 980 insertions(+), 18 deletions(-) create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu create mode 100644 csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 6dd30f0..d2d40b4 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -178,11 +178,57 @@ void set_params_dgrad(Flash_bwd_params ¶ms, void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { FP16_SWITCH(!params.is_bf16, [&] { FWD_HEADDIM_SWITCH(params.d, [&] { - run_mha_fwd_(params, stream); + if (params.num_splits <= 1) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } }); }); } +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. 
+// So we find the best efficiency, then find the smallest number of splits that gets 85% +// of the best efficiency. +inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) { + // If we have enough to almost fill the SMs, then just use 1 split + if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. + auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + efficiency.push_back(0.f); + } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { max_efficiency = eff; } + efficiency.push_back(eff); + } + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { continue; } + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + std::vector mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size @@ -294,6 +340,25 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head softmax_scale, is_causal); + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = is_sm90 || is_sm8x + ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64)) + : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64)); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. + const int num_m_blocks = (seqlen_q + 64 - 1) / 64; + params.num_splits = 1; + if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout + params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 64); + if (params.num_splits > 1) { + at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + } + } + // number of times random will be generated per thread, to offset philox counter in thc random // state // We use a custom RNG that increases the offset by batch_size * nheads * 32. 
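[Editor's note] To make the split heuristic above concrete, here is a small standalone Python sketch of the same logic (not part of the patch); it picks the smallest eligible split count whose wave efficiency is within 85% of the best:

import math

def num_splits_heuristic(batch_nheads_mblocks, num_sms, num_n_blocks, max_splits):
    # If there is already enough work to nearly fill the SMs, do not split.
    if batch_nheads_mblocks >= 0.8 * num_sms:
        return 1
    max_splits = min(max_splits, num_sms, num_n_blocks)
    ceildiv = lambda a, b: (a + b - 1) // b
    # A split count is only eligible if it actually changes the blocks-per-split.
    def eligible(s):
        return s == 1 or ceildiv(num_n_blocks, s) != ceildiv(num_n_blocks, s - 1)
    eff = {}
    for s in range(1, max_splits + 1):
        if eligible(s):
            n_waves = batch_nheads_mblocks * s / num_sms
            eff[s] = n_waves / math.ceil(n_waves)
    best = max(eff.values())
    # Smallest eligible split count within 85% of the best efficiency.
    return min(s for s, e in eff.items() if e >= 0.85 * best)

# Example from the comment above: 48 CTAs' worth of work on 108 SMs -> 2 splits.
print(num_splits_heuristic(48, 108, num_n_blocks=128, max_splits=64))  # prints 2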
diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index e65d7d5..b8ab225 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -53,6 +53,7 @@ struct Flash_fwd_params : public Qkv_params { // The O matrix (output). void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; // The stride between rows of O. index_t o_batch_stride; @@ -64,6 +65,7 @@ struct Flash_fwd_params : public Qkv_params { // The pointer to the softmax sum. void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; // The dimensions. int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; @@ -96,6 +98,8 @@ struct Flash_fwd_params : public Qkv_params { bool is_bf16; bool is_causal; + + int num_splits; // For split-KV version }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -140,5 +144,6 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure); diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 867e228..e4c36ea 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -64,7 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { + if constexpr(smem_size_dq_dk_dv >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); } @@ -75,7 +75,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, }); auto kernel_dq = &flash_bwd_convert_dq_kernel; - if (Kernel_traits::kSmemdQSize >= 48 * 1024) { + if constexpr(Kernel_traits::kSmemdQSize >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); } @@ -103,7 +103,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel; // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { + if constexpr(smem_size_dq_dk_dv >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); } @@ -114,7 +114,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, }); auto kernel_dkv = &flash_bwd_convert_dkv_kernel; - if (Kernel_traits::kSmemKVSize >= 48 * 1024) { + if constexpr(Kernel_traits::kSmemKVSize >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize)); } @@ -147,7 +147,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool con // BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { // // auto kernel = &flash_bwd_dq_dk_dv_loop_kernel; // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // if (smem_size_dq_dk_dv >= 
48 * 1024) { + // if constexpr(smem_size_dq_dk_dv >= 48 * 1024) { // C10_CUDA_CHECK(cudaFuncSetAttribute( // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); // } @@ -159,7 +159,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool con // }); // auto kernel_dq = &flash_bwd_convert_dq_kernel; - // if (Kernel_traits::kSmemdQSize >= 48 * 1024) { + // if constexpr(Kernel_traits::kSmemdQSize >= 48 * 1024) { // C10_CUDA_CHECK(cudaFuncSetAttribute( // kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); // } diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index e6aaf85..b8dce9e 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -617,6 +617,407 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = n_split_idx * n_blocks_per_split; + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + } + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. 
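[Editor's note] To summarize the per-split bookkeeping above in one place, here is an illustrative Python sketch (not the kernel code; a single seqlen_k is assumed instead of the per-batch actual lengths). Each split owns a contiguous range of K/V blocks, and a split whose range comes out empty takes the early-exit path described in the comment above:

def split_n_block_range(seqlen_k, seqlen_q, m_block, kBlockM, kBlockN,
                        n_split_idx, num_n_splits, is_causal):
    ceildiv = lambda a, b: (a + b - 1) // b
    n_blocks = ceildiv(seqlen_k, kBlockN)
    n_blocks_per_split = ceildiv(n_blocks, num_n_splits)
    n_block_min = n_split_idx * n_blocks_per_split
    n_block_max = min(n_blocks, (n_split_idx + 1) * n_blocks_per_split)
    if is_causal:
        # Blocks entirely above the causal diagonal are skipped.
        n_block_max = min(n_block_max,
                          ceildiv((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN))
    return n_block_min, n_block_max  # empty range => early-exit path below

# e.g. seqlen_k = 20000, kBlockN = 128 gives 157 K/V blocks; with 8 splits each split
# owns 20 blocks except the last, which owns the remaining 17.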
+ const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = -INFINITY; } + } + return; + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 
0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + 
binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } + // __syncthreads(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = !Is_causal + ? 1 + : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print(scores); } + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. 
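[Editor's note] In other words, the masking applied just below overwrites the out-of-bounds score columns with -inf only after S = Q K^T has been computed; a small numpy sketch of the idea (illustrative shapes only, not the kernel's register layout):

import numpy as np

def masked_scores(q_blk, k_blk, n_valid_k):
    # q_blk: (kBlockM, d); k_blk: (kBlockN, d), rows >= n_valid_k hold uninitialized data.
    s = q_blk @ k_blk.T
    s[:, n_valid_k:] = -np.inf   # masked columns contribute exp(-inf) = 0 to the softmax
    return s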
+ if (!Is_causal) { + if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } + } else { + flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + kNWarps * 16); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + Tensor lse = make_fragment_like(scores_sum); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = inv_sum; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + + // if (cute::thread0()) { print(acc_o_rowcol); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomOaccum{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(acc_o); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
+ Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; @@ -638,4 +1039,172 @@ inline __device__ void compute_attn(const Params ¶ms) { //////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline __device__ void compute_attn_splitkv(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.z / params.h; + // The block index for the head. + const int bidh = blockIdx.z - bidb * params.h; + const int n_split_idx = blockIdx.y; + const int num_n_splits = gridDim.y; + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + constexpr int kMaxSplits = 1 << Log_max_splits; + constexpr int kBlockM = 16; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer"); + static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32"); + static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads"); + + // Shared memory. + // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1]; + + // The thread and block index. + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + const index_t row_offset_lse = bidx * kBlockM; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), + Shape, Int>{}, + make_stride(params.b * params.h * params.seqlen_q, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads; + + // Read the LSE values from gmem and store them in shared memory, then tranpose them. 
+ constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM; + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + if (row < kMaxSplits) { sLSE[row][col] = lse; } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } + __syncthreads(); + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // 16 rows, so each time we load we can load 8 rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + + // Compute the logsumexp of the LSE along the split dimension. + ElementAccum lse_max = lse_accum(0); + #pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max == lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); + #pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + ElementAccum lse_logsum = logf(lse_sum) + lse_max; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; } + // Store the scales exp(lse - lse_logsum) in shared memory. 
+ #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); } + } + __syncthreads(); + + const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } + } + // Load Oaccum in then scale and accumulate to O + #pragma unroll 2 + for (int split = 0; split < params.num_splits; ++split) { + flash::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE[split][row]; + #pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { + #pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); + } + } + // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + } + // if (cute::thread0()) { print(tOrO); } + + Tensor rO = flash::convert_type(tOrO); + // Write to gO + #pragma unroll + for (int m = 0; m < size<1>(rO); ++m) { + const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.b * params.h * params.seqlen_q) { + const int batch_idx = idx / (params.h * params.seqlen_q); + const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; + // The index to the rows of Q + const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + + head_idx * params.o_head_stride + row * params.o_row_stride; + #pragma unroll + for (int k = 0; k < size<2>(rO); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), + Shape(rO))::value>>{}, Stride<_1>{}); + // TODO: Should check if this is using vectorized store, but it seems pretty fast + copy(rO(_, m, k), gO); + // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } + // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); + } + } + } + } +} + } // namespace flash diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 
c8d3af6..d036be4 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -15,6 +15,17 @@ __global__ void flash_fwd_kernel(Flash_fwd_params params) { flash::compute_attn(params); } +template +__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { + flash::compute_attn_splitkv(params); +} + +template +__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { + static_assert(Log_max_splits >= 1); + flash::combine_attn_seqk_parallel(params); +} + template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr size_t smem_size = Kernel_traits::kSmemSize; @@ -35,13 +46,13 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // Will only return softmax if dropout, to reduce compilation time. auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { + if constexpr(smem_size >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -50,6 +61,65 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } +template +void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits, params.b * params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + // TODO: do we want to guarantee that seqlen_q <= seqlen_k? That would simplify the kernel a bit. 
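[Editor's note] For reference, the launch shapes configured here and for the combine step below can be summarized with a short Python sketch (descriptive helper names, not part of the patch): the split kernel gets one block per (M-tile, split, batch*head), the combine kernel gets one block per 16 rows of the flattened (batch, head, seqlen_q) output, and its Log_max_splits template argument is the smallest power-of-two bucket covering num_splits.

import math

def splitkv_launch_shapes(b, h, seqlen_q, num_splits, kBlockM=64, kBlockM_combine=16):
    ceildiv = lambda a, n: (a + n - 1) // n
    grid_fwd = (ceildiv(seqlen_q, kBlockM), num_splits, b * h)    # flash_fwd_splitkv_kernel
    grid_combine = (ceildiv(b * h * seqlen_q, kBlockM_combine),)  # flash_fwd_splitkv_combine_kernel
    log_max_splits = max(1, math.ceil(math.log2(num_splits)))     # buckets: 2, 4, 8, 16, 32, 64
    return grid_fwd, grid_combine, log_max_splits

# e.g. b=1, h=12, seqlen_q=1, num_splits=12 -> grid_fwd=(1, 12, 12), grid_combine=(1,), bucket 16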
+ BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if constexpr(smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16); + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + // } else if (params.num_splits <= 128) { + // flash_fwd_splitkv_combine_kernel<<>>(params); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + constexpr int kBlockM = 64; // Fixed for all head dimensions + if (!is_sm8x) { // A100, H100 + // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, + // and for headdim 192 with block size 64 x 128. + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64); + run_flash_splitkv_fwd>(params, stream); + } else { // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd>(params, stream); + } +} + template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr int Headdim = 32; diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu new file mode 100644 index 0000000..477c560 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu new file mode 100644 index 0000000..914cd23 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu new file mode 100644 index 0000000..d753d59 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu new file mode 100644 index 0000000..552c25d --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu new file mode 100644 index 0000000..e6b350a --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu new file mode 100644 index 0000000..b9c1935 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu new file mode 100644 index 0000000..b6bf081 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu new file mode 100644 index 0000000..0d09606 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu new file mode 100644 index 0000000..06a9524 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu new file mode 100644 index 0000000..54fd3b8 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu new file mode 100644 index 0000000..beff74c --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu new file mode 100644 index 0000000..d97c9ea --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu new file mode 100644 index 0000000..aed05fa --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu new file mode 100644 index 0000000..3b905f6 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu new file mode 100644 index 0000000..00a5972 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu new file mode 100644 index 0000000..95a7696 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/generate_kernels.py b/csrc/flash_attn/src/generate_kernels.py index cf4d808..62e6d06 100644 --- a/csrc/flash_attn/src/generate_kernels.py +++ b/csrc/flash_attn/src/generate_kernels.py @@ -16,14 +16,21 @@ DTYPE_MAP = { SM = [80] # Sm80 kernels support up to HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256] -KERNEL_IMPL_TEMPLATE_FWD = """ +KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h" + template<> void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream); }} """ -KERNEL_IMPL_TEMPLATE_BWD = """ +KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params ¶ms, cudaStream_t stream); +""" + +KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h" + template<> void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {{ run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure); @@ -44,10 +51,14 @@ class Kernel: return KERNEL_IMPL_TEMPLATE_FWD.format( DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim ) - else: + elif self.direction == "bwd": return KERNEL_IMPL_TEMPLATE_BWD.format( DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim ) + else: + return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format( + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim + ) @property def filename(self) -> str: @@ -56,7 +67,7 @@ class Kernel: def get_all_kernels() -> List[Kernel]: for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM): - for direction in ["fwd", "bwd"]: + for direction in ["fwd", "bwd", "fwd_split"]: yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction) @@ -65,8 +76,7 @@ def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py"\n """ - include = f'#include "flash_{kernel.direction}_launch_template.h"\n' - (autogen_dir / kernel.filename).write_text(prelude + include + kernel.template) + (autogen_dir / kernel.filename).write_text(prelude + kernel.template) def main(output_dir: Optional[str]) -> None: diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h index d1a9513..d5e8a76 100644 --- a/csrc/flash_attn/src/kernel_traits.h +++ b/csrc/flash_attn/src/kernel_traits.h @@ -113,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; static constexpr int kSmemQCount = size(SmemLayoutQ{}); static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; @@ -158,6 +159,17 @@ struct Flash_fwd_kernel_traits : public Base { GmemLayoutAtomP{}, Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. diff --git a/setup.py b/setup.py index 37cc2b5..cdb6f33 100644 --- a/setup.py +++ b/setup.py @@ -173,6 +173,22 @@ if not SKIP_CUDA_BUILD: "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu", ], extra_compile_args={ "cxx": ["-O3", "-std=c++17"] + generator_flag, diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 12fcc53..8c6cff9 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1367,6 +1367,109 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 +@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 
192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("swap_sq_sk", [False, True]) +# @pytest.mark.parametrize("swap_sq_sk", [False]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (3, 1024), + (1, 339), + (3, 799), + (64, 2048), + (16, 20000), + (16, 100000), + (128, 128), + (256, 256), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): + if ( + max(seqlen_q, seqlen_k) >= 2048 + and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 + ): + pytest.skip() # Reference implementation OOM + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 1 + nheads = 12 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + out, lse, _ = flash_attn_func(q, k, v, 0.0, causal=causal, return_attn_probs=True) + out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal) + out_pt, attn_pt = attention_ref( + q, + k, + v, + None, + None, + 0.0, + None, + causal=causal, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + g = torch.randn_like(out) + do_o = (g.float() * out.float()).sum(-1) + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
+ assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 + + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 2e-4 + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4 + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4 + + # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("causal", [False, True])
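[Editor's note] As an end-to-end sanity check of the math this patch implements, the split-KV recombination can be reproduced in a few lines of numpy (a standalone sketch, independent of the kernels): each split produces a locally normalized output O_s and a local log-sum-exp L_s, and the final output is sum_s exp(L_s - L) * O_s with L = logsumexp_s(L_s), which matches single-pass softmax attention.

import numpy as np

rng = np.random.default_rng(0)
seqlen_q, seqlen_k, d, num_splits = 4, 96, 32, 3
q = rng.standard_normal((seqlen_q, d))
k = rng.standard_normal((seqlen_k, d))
v = rng.standard_normal((seqlen_k, d))

def attn_with_lse(q, k, v):
    s = q @ k.T
    m = s.max(-1, keepdims=True)
    p = np.exp(s - m)
    lse = m[:, 0] + np.log(p.sum(-1))          # per-row log-sum-exp
    return (p / p.sum(-1, keepdims=True)) @ v, lse

# Reference: one pass over the whole K/V.
out_ref, _ = attn_with_lse(q, k, v)

# Split K/V into chunks, compute per-split outputs and LSEs, then combine.
outs, lses = zip(*(attn_with_lse(q, k_s, v_s)
                   for k_s, v_s in zip(np.array_split(k, num_splits),
                                       np.array_split(v, num_splits))))
lses = np.stack(lses)                                   # (num_splits, seqlen_q)
lse = np.log(np.exp(lses - lses.max(0)).sum(0)) + lses.max(0)
scales = np.exp(lses - lse)                             # exp(L_s - L)
out = (scales[..., None] * np.stack(outs)).sum(0)

assert np.allclose(out, out_ref)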