From 3e2c827d9a1bcf5d1e49e3bda3b470815cf78e04 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 20 Jan 2024 17:41:44 -0800 Subject: [PATCH] Remove unused kernel_traits file --- csrc/flash_attn/src/kernel_traits_sm90.h | 159 ----------------------- 1 file changed, 159 deletions(-) delete mode 100644 csrc/flash_attn/src/kernel_traits_sm90.h diff --git a/csrc/flash_attn/src/kernel_traits_sm90.h b/csrc/flash_attn/src/kernel_traits_sm90.h deleted file mode 100644 index e07f383..0000000 --- a/csrc/flash_attn/src/kernel_traits_sm90.h +++ /dev/null @@ -1,159 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "cute/algorithm/copy.hpp" - -#include "cutlass/cutlass.h" -#include "cutlass/layout/layout.h" -#include - -using namespace cute; - -template -struct Flash_kernel_traits_sm90 { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using Element = elem_type; - static constexpr bool Has_cp_async = true; -#else - using Element = cutlass::half_t; - static constexpr bool Has_cp_async = false; -#endif - - using ElementAccum = float; - using index_t = uint32_t; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using MMA_Atom_Arch = std::conditional_t< - std::is_same_v, - MMA_Atom, - MMA_Atom - >; - using ValLayoutMNK = Layout>; -#else - using MMA_Atom_Arch = MMA_Atom; - using ValLayoutMNK = Layout>; -#endif - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#else - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#endif -}; - -template > -struct Flash_fwd_kernel_traits : public Base { - using Element = typename Base::Element; - using ElementAccum = typename Base::ElementAccum; - using index_t = typename Base::index_t; - static constexpr bool Has_cp_async = Base::Has_cp_async; - using SmemCopyAtom = typename Base::SmemCopyAtom; - using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; - - static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; - static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; - - // The number of threads. - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * 32; - - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); - static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; - - using TiledMma = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM - - using SmemLayoutAtomQ = decltype( - composition(Swizzle{}, - // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 - Layout>, - Stride, _1>>{})); - using SmemLayoutQ = decltype(tile_to_shape( - SmemLayoutAtomQ{}, - Shape, Int>{})); - - using SmemLayoutKV = decltype(tile_to_shape( - SmemLayoutAtomQ{}, - Shape, Int>{})); - - using SmemLayoutAtomVtransposed = decltype( - composition(Swizzle{}, - // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 - Layout, Int>, - Stride<_1, Int>>{})); - using SmemLayoutVtransposed = decltype(tile_to_shape( - SmemLayoutAtomVtransposed{}, - Shape, Int>{})); - // Maybe the VtransposeNoSwizzle just needs to have the right shape - // And the strides don't matter? - using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); - - using SmemLayoutAtomO = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutO = decltype(tile_to_shape( - SmemLayoutAtomO{}, - Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; - - static constexpr int kSmemQCount = size(SmemLayoutQ{}); - static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; - static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); - static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); - static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); - // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. - // For example, for d=128, smem is split into 2 "pages", each page takes care of columns - // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, - // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, - // to the same banks. - static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - - // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading - // from the same address by the same threadblock. This is slightly faster. - using Gmem_copy_struct = std::conditional_t< - Has_cp_async, - SM80_CP_ASYNC_CACHEGLOBAL, - DefaultCopy - >; - using GmemTiledCopyQKV = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read - using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); - using GmemLayoutAtomP = Layout, Int>, - Stride, _1>>; - - using GmemTiledCopyP = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtomP{}, - Layout>{})); // Val layout, 8 vals per store - -}; - -////////////////////////////////////////////////////////////////////////////////////////////////////