From 8f4d82cf5e9e2772c4b7cc3f391d5ab6a8ffe9c6 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 20 Jan 2024 22:30:06 -0800 Subject: [PATCH] Update cutlass to v3.4.0 --- csrc/cutlass | 2 +- csrc/flash_attn/src/flash_bwd_kernel.h | 17 +++++++++-------- csrc/flash_attn/src/flash_fwd_kernel.h | 1 - csrc/flash_attn/src/kernel_traits.h | 10 ++++------ 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/csrc/cutlass b/csrc/cutlass index a75b4ac..751eb9a 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit a75b4ac483166189a45290783cb0a18af5ff0ea5 +Subproject commit 751eb9a8859ac36bfc77551f9e4a957c31a5a8b1 diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index c8cc8fe..d2be545 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -32,11 +32,12 @@ CUTE_HOST_DEVICE auto make_tiled_copy_B_warpcontiguousN(Copy_Atom const& copy_atom, TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value; + constexpr int TileShape_K = decltype(tiled_mma.template tile_size_mnk<2>())::value; using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value; // Divide by 2 because right now we always use 2 for the ValLayout - constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2; + constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2; constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; // This gives the correct layout, idk why. // auto t = make_tile(Layout, _2>, @@ -45,7 +46,7 @@ make_tiled_copy_B_warpcontiguousN(Copy_Atom const& copy_atom, // Stride<_1, _64, _8> >{}, auto t = make_tile(Layout, Int, _2>, // (8, 2, 2) or (8, 4, 2) Stride<_1, Int, _8> >{}, // (1, 64, 8) or (1, 32, 8) - make_layout(size<2>(TileShape_MNK{}))); + make_layout(Int{})); // if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); } return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t); } @@ -59,13 +60,14 @@ CUTE_HOST_DEVICE auto make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + constexpr int TileShape_M = decltype(tiled_mma.template tile_size_mnk<0>())::value; + constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value; using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value; // Divide by 2 because right now we always use 2 for the ValLayout - constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2; + constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2; constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; - auto t = make_tile(make_layout(size<0>(TileShape_MNK{})), + auto t = make_tile(make_layout(Int{}), Layout, Int, _2>, // (8, 2, 2) or (8, 4, 2) Stride<_1, Int, _8> >{}); // (1, 64, 8) or (1, 32, 8) // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); } @@ -90,8 +92,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; - // constexpr int kNWarps = Kernel_traits::kNWarps; - constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(typename Kernel_traits::TiledMmaSdP::TiledShape_MNK{}))::value; + constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value; constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; constexpr bool Double_buffer = !Kernel_traits::No_double_buffer; diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 5ba68e9..416c551 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -41,7 +41,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; - constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; auto seed_offset = at::cuda::philox::unpack(params.philox_args); flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h index 9a2502f..cdc7608 100644 --- a/csrc/flash_attn/src/kernel_traits.h +++ b/csrc/flash_attn/src/kernel_traits.h @@ -32,10 +32,8 @@ struct Flash_kernel_traits { MMA_Atom, MMA_Atom >; - using ValLayoutMNK = Layout>; #else using MMA_Atom_Arch = MMA_Atom; - using ValLayoutMNK = Layout>; #endif #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 @@ -76,7 +74,7 @@ struct Flash_fwd_kernel_traits : public Base { using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + Tile, _16, _16>>; using SmemLayoutAtomQ = decltype( composition(Swizzle{}, @@ -197,17 +195,17 @@ struct Flash_bwd_kernel_traits : public Base { using TiledMmaSdP = TiledMMA< typename Base::MMA_Atom_Arch, Layout, Int, _1>>, - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + Tile, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>; using TiledMmadKV = TiledMMA< typename Base::MMA_Atom_Arch, Layout, Int, _1>>, - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + Tile, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>; using TiledMmadQ = TiledMMA< typename Base::MMA_Atom_Arch, Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; using SmemLayoutAtomQdO = decltype( composition(Swizzle{},