Remove some unused headers
parent 08c295c043
commit bb9beb3645
@@ -2,7 +2,9 @@
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#include <torch/extension.h>
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
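Note on this hunk: torch/extension.h is an umbrella header that pulls in all of ATen plus the Python-binding machinery, so a translation unit that only needs the binding layer and torch::nn::functional can include the two narrower headers and save compile time. A minimal sketch of a file built this way (pad_last_dim is an illustrative helper, not part of this commit):

#include <torch/python.h>         // Python bindings without the full umbrella header
#include <torch/nn/functional.h>  // torch::nn::functional, e.g. F::pad

namespace F = torch::nn::functional;

// Hypothetical helper: zero-pad the last dimension up to a multiple of `mult`.
at::Tensor pad_last_dim(const at::Tensor &x, int64_t mult) {
    const int64_t rem = x.size(-1) % mult;
    if (rem == 0) { return x; }
    return F::pad(x, F::PadFuncOptions({0, mult - rem}));
}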
@@ -13,8 +13,7 @@
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/CUDAGraphsUtils.cuh>

#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack

constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
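Note: the surviving include supplies at::cuda::philox::unpack, which turns the at::PhiloxCudaState captured on the host into the concrete (seed, offset) pair a kernel feeds to its Philox RNG; routing through PhiloxCudaState keeps the kernel usable under CUDA graph capture, where the values live in device memory. A minimal sketch of the device-side call, assuming nvcc's --expt-relaxed-constexpr for std::get in device code (rng_args_kernel is illustrative, not from this commit):

#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack

// Hypothetical kernel: unpack the Philox state passed from the host and
// report the (seed, offset) pair it would use to seed its RNG.
__global__ void rng_args_kernel(at::PhiloxCudaState philox_args,
                                uint64_t *seed_out, uint64_t *offset_out) {
    // unpack() reads either the immediate values or, under graph capture,
    // the values recorded in device memory.
    auto seeds = at::cuda::philox::unpack(philox_args);
    if (threadIdx.x == 0 && blockIdx.x == 0) {
        *seed_out = std::get<0>(seeds);
        *offset_out = std::get<1>(seeds);
    }
}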
@@ -5,18 +5,15 @@
#pragma once

#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>

#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "philox.cuh"

namespace flash {
@@ -4,20 +4,16 @@

#pragma once

#include <cmath>
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>

#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "philox.cuh"

namespace flash {
@@ -25,49 +21,6 @@ using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int MMA_M,
          class... Args,
          class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
                                  TiledMMA           const& tiled_mma) {
    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
    constexpr int MMAStride_M = MMA_M * AtomShape_M;
    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
                              Stride<_1, Int<MMAStride_M>> >{},
                       make_layout(size<2>(TileShape_MNK{})));
    // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int MMA_M,
          class... Args,
          class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
                                  TiledMMA           const& tiled_mma) {
    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
    constexpr int MMAStride_M = MMA_M * AtomShape_M;
    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
                              Stride<_1, Int<MMAStride_M>> >{},
                       // TODO: Shouldn't this be size<1>?
                       make_layout(size<2>(TileShape_MNK{})));
    // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
                                         Tensor2 &acc_o, float softmax_scale_log2) {
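For readers of the softmax_rescale_o signature above: it implements the online-softmax update, where a new block of scores may raise the running max, and everything accumulated so far must then be scaled by exp(old_max - new_max). A scalar sketch of the same idea, assuming plain floats instead of cute tensors (all names here are illustrative):

#include <cmath>

// Scalar model of online softmax with a rescaled output accumulator.
struct OnlineSoftmax {
    float m = -INFINITY;  // running max of scores seen so far
    float s = 0.0f;       // running sum of exp(score - m)
    float acc = 0.0f;     // running softmax-weighted sum of values

    void update(float score, float value) {
        float m_new = fmaxf(m, score);
        float rescale = expf(m - m_new);  // shrink prior terms to the new max
        float p = expf(score - m_new);
        s = s * rescale + p;
        acc = acc * rescale + p * value;
        m = m_new;
    }

    float result() const { return acc / s; }  // the softmax-weighted average
};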
@@ -256,7 +209,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
    // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
    // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
@@ -558,7 +510,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    // Partition sO to match the accumulator partitioning
    auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
    auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
    // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
    Tensor taccOrO = smem_thr_copy_O.retile_S(rO);     // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsO = smem_thr_copy_O.partition_D(sO);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
@@ -76,7 +76,6 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    BOOL_SWITCH(params.num_splits > 1, Split, [&] {
        BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
            // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
            // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
            auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV>;
            // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
            // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
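Note: BOOL_SWITCH in this hunk is the repo's runtime-to-compile-time dispatch. Each call branches on a runtime condition and re-enters its body with that condition bound as a constexpr bool, so the nested switches materialize one kernel instantiation per (Split, Append_KV) combination. A minimal standalone sketch of the pattern, not the repo's actual static_switch.h:

#include <cstdio>

// Sketch of a BOOL_SWITCH-style macro: expose a runtime bool to the
// body as a constexpr constant usable in template arguments.
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
    [&] {                                       \
        if (COND) {                             \
            constexpr bool CONST_NAME = true;   \
            return __VA_ARGS__();               \
        } else {                                \
            constexpr bool CONST_NAME = false;  \
            return __VA_ARGS__();               \
        }                                       \
    }()

template <bool Split, bool Append_KV>
void launch() { std::printf("Split=%d Append_KV=%d\n", Split, Append_KV); }

void dispatch(bool split, bool append_kv) {
    BOOL_SWITCH(split, Split, [&] {
        BOOL_SWITCH(append_kv, Append_KV, [&] {
            launch<Split, Append_KV>();  // one instantiation per combination
        });
    });
}

The variadic __VA_ARGS__ matters: commas inside the lambda body (e.g. in template argument lists) would otherwise split the macro arguments.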
@@ -8,8 +8,7 @@

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

#include "philox.cuh"
#include "utils.h"