diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index 2121241..b5f905b 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -975,9 +975,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
         // Layout p_l = tPrP.layout();
         // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l)));
-        // flash::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
+        // flash::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
         // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
-        // flash::gemm_A_in_regs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
+        // flash::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
         flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
                     smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
         // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index 8d8d2ff..8d83839 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -369,11 +369,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         if (Has_alibi) {
             flash::apply_alibi(
-                scores, 
-                n_block * kBlockN, 
+                scores,
+                n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q, 
+                binfo.actual_seqlen_q,
                 kNWarps * 16,
                 alibi_slope
             );
@@ -444,7 +444,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         }
 
         // if (cute::thread0()) { print(tOrP); }
-        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
         // if (cute::thread0()) { print(scores); }
 
         // This check is at the end of the loop since we always have at least 1 iteration
@@ -483,19 +483,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        
+
         if (Has_alibi) {
             flash::apply_alibi(
-                scores, 
-                n_block * kBlockN, 
+                scores,
+                n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q, 
+                binfo.actual_seqlen_q,
                 kNWarps * 16,
                 alibi_slope
             );
         }
-        
+
         if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
             flash::apply_mask_local(
                 scores, n_block * kBlockN, binfo.actual_seqlen_k,
@@ -528,7 +528,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                                  block_row_idx, block_col_idx, kNWarps);
         }
 
-        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
 
     // Epilogue
@@ -977,11 +977,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
         if (Has_alibi) {
             flash::apply_alibi(
-                scores, 
-                n_block * kBlockN, 
+                scores,
+                n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q, 
+                binfo.actual_seqlen_q,
                 kNWarps * 16,
                 alibi_slope
             );
@@ -1027,7 +1027,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
         Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
 
-        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
         // if (cute::thread0()) { print(scores); }
 
         // This check is at the end of the loop since we always have at least 1 iteration
@@ -1069,11 +1069,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
         if (Has_alibi) {
            flash::apply_alibi(
-                scores, 
-                n_block * kBlockN, 
+                scores,
+                n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q, 
+                binfo.actual_seqlen_q,
                 kNWarps * 16,
                 alibi_slope
             );
@@ -1094,7 +1094,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
         Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
 
-        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
 
     // Epilogue
diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h
index f000ff2..ac425d9 100644
--- a/csrc/flash_attn/src/kernel_traits.h
+++ b/csrc/flash_attn/src/kernel_traits.h
@@ -91,20 +91,10 @@ struct Flash_fwd_kernel_traits : public Base {
         SmemLayoutAtomQ{},
         Shape<Int<kBlockN>, Int<kHeadDim>>{}));
 
-    // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
-    using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
-                                                      Stride<_1, Int<kBlockKSmem>>>;
-    using SmemLayoutAtomVtransposed = decltype(
-        composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
-    using SmemLayoutVtransposed = decltype(tile_to_shape(
-        SmemLayoutAtomVtransposed{},
-        Shape<Int<kHeadDim>, Int<kBlockN>>{}));
-    // Maybe the VtransposeNoSwizzle just needs to have the right shape
-    // And the strides don't matter?
-    using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
-        SmemLayoutAtomVtransposedNoSwizzle{},
-        Shape<Int<kHeadDim>, Int<kBlockN>>{}));
-    // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
+    // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
+    using SmemLayoutVtransposed = decltype(
+        composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
+    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
 
     using SmemLayoutAtomO = decltype(
         composition(Swizzle<kSwizzle, 3, 3>{},
@@ -247,19 +237,9 @@ struct Flash_bwd_kernel_traits : public Base {
         SmemLayoutAtomKV{},
         make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
 
-    using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
-                                                      Stride<_1, Int<kBlockKSmem>>>;
-    using SmemLayoutAtomKtransposed = decltype(
-        composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
-    using SmemLayoutKtransposed = decltype(tile_to_shape(
-        SmemLayoutAtomKtransposed{},
-        make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
-    // Maybe the KtransposeNoSwizzle just needs to have the right shape
-    // And the strides don't matter?
-    using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
-        SmemLayoutAtomKtransposedNoSwizzle{},
-        make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
-    // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
+    using SmemLayoutKtransposed = decltype(
+        composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
+    using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));
 
     // TODO: generalize to other values of kBlockN
     // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
@@ -277,30 +257,15 @@ struct Flash_bwd_kernel_traits : public Base {
     using SmemLayoutPdS = decltype(tile_to_shape(
         SmemLayoutAtomPdS{},
         make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
-    using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
-                                                        Stride<_1, Int<kPBlockN>>>;
-    using SmemLayoutAtomPdStransposed = decltype(
-        composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
-    using SmemLayoutPdStransposed = decltype(tile_to_shape(
-        SmemLayoutAtomPdStransposed{},
-        make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
-    using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
-        SmemLayoutAtomPdStransposedNoSwizzle{},
-        make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
-    // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
+    using SmemLayoutPdStransposed = decltype(
+        composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
+    using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
+
     using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
 
-    using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
-                                                        Stride<_1, Int<kBlockKSmem>>>;
-    using SmemLayoutAtomQdOtransposed = decltype(
-        composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
-    using SmemLayoutQdOtransposed = decltype(tile_to_shape(
-        SmemLayoutAtomQdOtransposed{},
-        make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
-    using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
-        SmemLayoutAtomQdOtransposedNoSwizzle{},
-        make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
-    // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
+    using SmemLayoutQdOtransposed = decltype(
+        composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
+    using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));
 
     using SmemLayoutAtomdKV = decltype(
         composition(Swizzle<kSwizzle, 3, 3>{},
diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h
index edf6a60..31023c5 100644
--- a/csrc/flash_attn/src/utils.h
+++ b/csrc/flash_attn/src/utils.h
@@ -162,9 +162,9 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
 
 template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
          typename TiledMma, typename TiledCopy, typename ThrCopy>
-inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
-                                      TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
-                                      ThrCopy smem_thr_copy_B) {
+inline __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
+                               TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
+                               ThrCopy smem_thr_copy_B) {
     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
     CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
     CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                     // MMA_K
@@ -188,10 +188,7 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
     static_assert(decltype(size<0>(acc_layout))::value == 4);
     static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
-    // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
-    // "int_tuple.hpp(74): error: conversion to inaccessible base class"
-    // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
-    return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
+    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -207,13 +204,9 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
     static_assert(mma_shape_K == 8 || mma_shape_K == 16);
     constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
     auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
-    // TD [2023-08-13]: Same error as above on Cutlass 3.2
-    // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
-    //                    get<0, 1>(l),
-    //                    get<1, 1, 1>(l));
-    return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
-                       get<1>(get<0>(l)),
-                       get<1>(get<1>(get<1>(l))));
+    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
+                       get<0, 1>(l),
+                       get<1, 1, 1>(l));
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
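
Note (not part of the patch): a minimal host-side sketch of the composition + get_nonswizzle_portion pattern that the new kernel_traits.h typedefs rely on. The tile sizes below (kBlockN, kHeadDim, kBlockKSmem) and the swizzle parameters are illustrative stand-ins rather than the real trait values, and the sketch assumes CUTLASS >= 3.3 headers are on the include path (compile with nvcc).

// transposed_layout_sketch.cu -- illustrative only.
#include <cstdio>
#include <cute/tensor.hpp>

using namespace cute;

int main() {
    // Illustrative tile sizes; the real kernels take these from the kernel traits.
    constexpr int kBlockN = 64, kHeadDim = 64, kBlockKSmem = 64, kSwizzle = 3;

    // Swizzled row-major (kBlockN, kHeadDim) smem layout, analogous to SmemLayoutKV above.
    using SmemLayoutAtom = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutKV = decltype(tile_to_shape(SmemLayoutAtom{},
                                                Shape<Int<kBlockN>, Int<kHeadDim>>{}));

    // Transposed view: composing with a row-major (kHeadDim, kBlockN) layout makes
    // coordinate (d, n) address the same smem element as (n, d) in SmemLayoutKV,
    // so no hand-written transposed atom plus tile_to_shape is needed.
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
    // Same layout with the swizzle stripped off, replacing the old NoSwizzle variant.
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));

    print(SmemLayoutVtransposed{});          printf("\n");
    print(SmemLayoutVtransposedNoSwizzle{}); printf("\n");
    return 0;
}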