diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index ff86d01..3051eab 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -2,7 +2,9 @@ * Copyright (c) 2023, Tri Dao. ******************************************************************************/ -#include +// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. +#include +#include #include #include diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index ff49cb8..0dbf928 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -13,8 +13,7 @@ #include #endif -#include - +#include // For at::cuda::philox::unpack constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index fc5724c..6bece9b 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -5,18 +5,15 @@ #pragma once #include -#include #include #include #include -#include #include "block_info.h" #include "kernel_traits.h" #include "utils.h" #include "softmax.h" -#include "philox.cuh" namespace flash { diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index c0e3df5..3697d21 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -4,20 +4,16 @@ #pragma once -#include #include -#include #include #include #include -#include #include "block_info.h" #include "kernel_traits.h" #include "utils.h" #include "softmax.h" -#include "philox.cuh" namespace flash { @@ -25,49 +21,6 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(Layout, Int>, - Stride<_1, Int> >{}, - make_layout(size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE -auto -make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(Layout, Int>, - Stride<_1, Int> >{}, - // TODO: Shouldn't this be size<1>? - make_layout(size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, Tensor2 &acc_o, float softmax_scale_log2) { @@ -256,7 +209,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); - // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); // if (cute::thread0()) {smem_thr_copy_Q.print_all();} Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} @@ -558,7 +510,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Partition sO to match the accumulator partitioning auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); - // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 71645fa..1712453 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -76,7 +76,6 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.num_splits > 1, Split, [&] { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index f72313a..987f5ef 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -8,8 +8,7 @@ #include -#include -#include +#include #include "philox.cuh" #include "utils.h" diff --git a/setup.py b/setup.py index 21cdc89..5b4beff 100644 --- a/setup.py +++ b/setup.py @@ -189,7 +189,7 @@ if not SKIP_CUDA_BUILD: "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", - "--ptxas-options=-v", + # "--ptxas-options=-v", # "--ptxas-options=-O2", "-lineinfo" ]