From 74b0761ff7efc7b90d4e5aeb529c1b2a09a7458c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 14 Jul 2024 23:39:46 -0700 Subject: [PATCH] [FA3] BF16 forward --- csrc/cutlass | 2 +- hopper/epilogue_fwd_sm90_tma.hpp | 3 +- hopper/flash.h | 2 - hopper/flash_api.cpp | 32 ++-- hopper/flash_fwd_hdim128_bf16_sm90.cu | 9 + hopper/flash_fwd_hdim256_bf16_sm90.cu | 9 + hopper/flash_fwd_hdim64_bf16_sm90.cu | 9 + hopper/flash_fwd_kernel.h | 62 +++---- hopper/flash_fwd_launch_template.h | 42 +++-- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 100 ++++------- hopper/named_barrier.hpp | 23 +++ hopper/setup.py | 3 + hopper/test_flash_attn.py | 34 ++-- hopper/tile_scheduler.hpp | 217 ++++++++++------------- 14 files changed, 278 insertions(+), 269 deletions(-) create mode 100644 hopper/flash_fwd_hdim128_bf16_sm90.cu create mode 100644 hopper/flash_fwd_hdim256_bf16_sm90.cu create mode 100644 hopper/flash_fwd_hdim64_bf16_sm90.cu create mode 100644 hopper/named_barrier.hpp diff --git a/csrc/cutlass b/csrc/cutlass index fa4f635..756c351 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit fa4f6359069bd4dd6fabd0cda2476dd8e72b3837 +Subproject commit 756c351b4994854b2f8c6dded3821ebbb580876b diff --git a/hopper/epilogue_fwd_sm90_tma.hpp b/hopper/epilogue_fwd_sm90_tma.hpp index c63c42a..2d5c33e 100644 --- a/hopper/epilogue_fwd_sm90_tma.hpp +++ b/hopper/epilogue_fwd_sm90_tma.hpp @@ -9,6 +9,7 @@ #include "cutlass/gemm/collective/collective_builder.hpp" +#include "named_barrier.hpp" #include "utils.h" namespace flash { @@ -127,7 +128,7 @@ struct CollectiveEpilogueFwd { Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) // Make sure all WGs have finished reading V - cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0 /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::ValueEmpty) /*id*/); cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, diff --git a/hopper/flash.h b/hopper/flash.h index c61ffaa..0418f0b 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -66,8 +66,6 @@ struct Flash_fwd_params : public Qkv_params { // The dimensions. int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; - cutlass::FastDivmod head_divmod, m_block_divmod; - int total_blocks; // The scaling factors for the kernel. float scale_softmax; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 3dca4bf..f21d2d1 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -99,8 +99,6 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.d = d; params.d_rounded = d_rounded; - params.head_divmod = cutlass::FastDivmod(int(h)); - // Set the different scale values. 
params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; @@ -225,12 +223,22 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split // run_mha_fwd_(params, stream); // }); if (!params.is_e4m3) { - if (params.d == 64) { - run_mha_fwd_(params, stream); - } else if (params.d == 128) { - run_mha_fwd_(params, stream); + if (params.is_bf16) { + if (params.d == 64) { + run_mha_fwd_(params, stream); + } else if (params.d == 128) { + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_(params, stream); + } } else { - run_mha_fwd_(params, stream); + if (params.d == 64) { + run_mha_fwd_(params, stream); + } else if (params.d == 128) { + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_(params, stream); + } } } else { // run_mha_fwd_(params, stream); @@ -250,9 +258,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer."); auto q_dtype = q.dtype(); - // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, - TORCH_CHECK(q_dtype == torch::kFloat16, - "FlashAttention only support fp16 data type for now"); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type for now"); // TODO: will add e4m3 later // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn, // "FlashAttention only support fp16 and bf16 data type"); @@ -278,10 +285,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const int head_size_og = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); - TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - TORCH_CHECK(num_heads == num_heads_k, "We do not support MQA/GQA yet"); TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now"); @@ -345,7 +351,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size /*window_size_left=*/-1, /*window_size_right=*/is_causal ? 0 : -1); - auto tile_count_semaphore = is_causal ? torch::full({1}, 132, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); + auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); params.tile_count_semaphore = tile_count_semaphore.data_ptr(); if (seqlen_k > 0) { diff --git a/hopper/flash_fwd_hdim128_bf16_sm90.cu b/hopper/flash_fwd_hdim128_bf16_sm90.cu new file mode 100644 index 0000000..11bb9dd --- /dev/null +++ b/hopper/flash_fwd_hdim128_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_bf16_sm90.cu b/hopper/flash_fwd_hdim256_bf16_sm90.cu new file mode 100644 index 0000000..06d0df6 --- /dev/null +++ b/hopper/flash_fwd_hdim256_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. 
+// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_bf16_sm90.cu b/hopper/flash_fwd_hdim64_bf16_sm90.cu new file mode 100644 index 0000000..d383989 --- /dev/null +++ b/hopper/flash_fwd_hdim64_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/hopper/flash_fwd_kernel.h b/hopper/flash_fwd_kernel.h index 70871fc..b97250d 100644 --- a/hopper/flash_fwd_kernel.h +++ b/hopper/flash_fwd_kernel.h @@ -26,8 +26,7 @@ using namespace cute; template __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) - compute_attn_ws(CUTE_GRID_CONSTANT Flash_fwd_params const params, - CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, + compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params ) { @@ -101,9 +100,6 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, if (warp_group_idx == 0) { // Producer cutlass::arch::warpgroup_reg_dealloc(); // cutlass::arch::warpgroup_reg_dealloc<56>(); - // StaticPersistentTileScheduler scheduler{params.m_block_divmod, params.head_divmod, params.total_blocks}; - // auto work_tile_info = scheduler.get_current_work(); - TileScheduler scheduler; int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); if (warp_idx_in_warpgroup == 0) { // Load Q, K, V @@ -112,20 +108,22 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, int work_idx = 0; - // auto get_tile_count = [&] () { - // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/); - // return shared_storage.tile_count_semaphore; - // }; + TileScheduler scheduler(&shared_storage.tile_count_semaphore); + for (auto work_tile_info = scheduler.get_initial_work(); + work_tile_info.is_valid(scheduler_params); + work_tile_info = scheduler.template get_next_work(scheduler_params, work_tile_info)) { + auto block_coord = work_tile_info.get_block_coord(scheduler_params); + auto [m_block, bidh, bidb] = block_coord; - // while (work_tile_info.is_valid()) { - // for (int tile_count = blockIdx.x; tile_count < params.total_blocks; tile_count = get_tile_count()) { - // for (int tile_count_semaphore = blockIdx.x; tile_count_semaphore < params.total_blocks; tile_count_semaphore = __shfl_sync(0xffffffff, tile_count_semaphore, 0)) { - for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) { - int tile_count_semaphore = 0; - collective_mainloop.load(params, mainloop_params, scheduler_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v, - shared_storage, work_tile_info, work_idx, tile_count_semaphore); - // ++work_idx; - // work_tile_info = scheduler.fetch_next_work(); + int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block); 
+ if (Is_causal && n_block_max <= 0) { + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + scheduler.broadcast_next_work(work_tile_info); + continue; + } + collective_mainloop.load(mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v, + shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx); + ++work_idx; } collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v); } @@ -133,44 +131,31 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, cutlass::arch::warpgroup_reg_alloc(); // cutlass::arch::warpgroup_reg_alloc(); + TileScheduler scheduler(&shared_storage.tile_count_semaphore); // Initialize matmul objects. typename Ktraits::TiledMma1 tiled_mma1; - TileScheduler scheduler{}; - PipelineState smem_pipe_read_k, smem_pipe_read_v; - // We don't need separate variables smem_pip_release_k and smem_pipe_release_v + // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v // (like in Cutlass's gemm) because the read and release pipeline states are always the same. - auto get_tile_count = [&] () { - // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/); - cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/); - return shared_storage.tile_count_semaphore; - }; - collective_mainloop.mma_init(); + scheduler.init_consumer(); int work_idx = 0; CUTLASS_PRAGMA_NO_UNROLL - // for (int work_idx = 0; work_idx * gridDim.x + blockIdx.x < params.total_blocks; ++work_idx) { - // for (int tile_count_semaphore = blockIdx.x, work_idx = 0; tile_count_semaphore < params.total_blocks; tile_count_semaphore = get_tile_count()) { - for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) { + for (auto work_tile_info = scheduler.get_initial_work(); + work_tile_info.is_valid(scheduler_params); + work_tile_info = scheduler.template get_next_work(scheduler_params, work_tile_info)) { // Attention output (GEMM-II) accumulator. Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{})); flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax; - // int m_block; - // int bidh, bidb; - // // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, work_idx * gridDim.x + blockIdx.x)); - // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore)); - // cute::tuple block_coord = {m_block, bidh, bidb}; auto block_coord = work_tile_info.get_block_coord(scheduler_params); auto [m_block, bidh, bidb] = block_coord; int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block); if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE. 
- // Need sync to avoid the case where the producer issues 2 arrives before the consumer can issue 1 wait - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/); collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord); continue; } @@ -178,15 +163,14 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v, tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage); // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage); - // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, 0, shared_storage); collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, threadIdx.x - NumCopyThreads, block_coord); ++work_idx; - // work_tile_info = scheduler.fetch_next_work(); } collective_epilogue.store_tail(); } + } } // namespace flash diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index ff67db5..e0b40ce 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -8,6 +8,7 @@ #include "cute/tensor.hpp" +#include "cutlass/cutlass.h" #include "cutlass/cluster_launch.hpp" #include "static_switch.h" @@ -26,8 +27,10 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{}); using CollectiveMainloop = flash::CollectiveMainloopFwd; using CollectiveEpilogue = flash::CollectiveEpilogueFwd; - // using Scheduler = flash::SingleTileScheduler; - using Scheduler = flash::StaticPersistentTileScheduler; + using Scheduler = std::conditional_t>; + // flash::SingleTileScheduler>; typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({ static_cast(params.q_ptr), @@ -51,32 +54,35 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{}); - typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b}; + typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b, params.tile_count_semaphore}; typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); // Get the ptr to kernel function. 
void *kernel; kernel = (void *)flash::compute_attn_ws; int smem_size = sizeof(typename Kernel_traits::SharedStorage); - int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q)); - int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k)); - int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v)); + // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q)); + // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k)); + // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v)); // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v); if (smem_size >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } + int device; + cudaGetDevice(&device); + int multiprocessor_count; + cudaError status_ = cudaDeviceGetAttribute( + &multiprocessor_count, cudaDevAttrMultiProcessorCount, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); static constexpr int ctaSize = Kernel_traits::kNWarps * 32; - params.m_block_divmod = cutlass::FastDivmod(num_blocks_m); - params.total_blocks = num_blocks_m * params.h * params.b; - // dim3 grid_dims(num_blocks_m, params.h, params.b); - // dim3 grid_dims(132); - dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, 132); dim3 block_dims(ctaSize); dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; - cutlass::launch_kernel_on_cluster(launch_params, kernel, params, mainloop_params, epilogue_params, scheduler_params); - // kernel<<>>(params, tma_load_Q, tma_load_K, tma_load_V, tma_store_O); + cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, epilogue_params, scheduler_params); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -92,7 +98,10 @@ template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_flash_fwd, Is_causal>(params, stream); + // Only use Cluster if number of tiles along seqlen_q is even + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] { + run_flash_fwd, Is_causal>(params, stream); + }); }); } @@ -100,6 +109,9 @@ template void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 256; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_flash_fwd, Is_causal>(params, stream); + // Only use Cluster if number of tiles along seqlen_q is even + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] { + run_flash_fwd, Is_causal>(params, stream); + }); }); } diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 9a8521e..f9dc94a 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -14,6 +14,7 @@ #include "cutlass/gemm/collective/collective_builder.hpp" +#include "named_barrier.hpp" #include "utils.h" namespace flash { @@ -108,6 +109,7 @@ struct CollectiveMainloopFwd { struct Params { ShapeQKV const shape_Q; ShapeQKV const shape_K; + cutlass::FastDivmod qhead_per_khead_divmod; TMA_Q tma_load_Q; TMA_KV tma_load_K, tma_load_V; float const 
softmax_scale_log2; @@ -137,7 +139,10 @@ struct CollectiveMainloopFwd { SmemLayoutV{}(_, _, _0{}), select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - return {args.shape_Q, args.shape_K, tma_load_Q, tma_load_K, tma_load_V, args.softmax_scale_log2}; + return {args.shape_Q, args.shape_K, + cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), + tma_load_Q, tma_load_K, tma_load_V, + args.softmax_scale_log2}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -162,46 +167,21 @@ struct CollectiveMainloopFwd { return n_block_max; } - template + template CUTLASS_DEVICE void - load(FullParams const& params, - Params const& mainloop_params, - SchedulerParams const& scheduler_params, + load(Params const& mainloop_params, MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, PipelineState& smem_pipe_write_k, PipelineState& smem_pipe_write_v, SharedStorage &shared_storage, - WorkTileInfo work_tile_info, - int& work_idx, - int& tile_count_semaphore + Scheduler& scheduler, + typename Scheduler::Params const& scheduler_params, + typename Scheduler::WorkTileInfo& work_tile_info, + cute::tuple block_coord, + int work_idx ) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - - // int const m_block = work_tile_info.M_idx; - // int const bidh = work_tile_info.H_idx; - // int const bidb = work_tile_info.B_idx; - // int m_block; - // int bidh, bidb; - // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore)); - auto [m_block, bidh, bidb] = work_tile_info.get_block_coord(scheduler_params); - // if (threadIdx.x == 0) { printf("producer, blockIdx.x = %d, bidb = %d, bidh = %d, m_block = %d\n", blockIdx.x, bidb, bidh, m_block); } - - int n_block_max = get_n_block_max(mainloop_params, m_block); - if (Is_causal && n_block_max <= 0) { - // Need sync to avoid the case where the producer issues 2 arrives before the consumer can issue 1 wait - cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/); - // if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { - // tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1); - // shared_storage.tile_count_semaphore = tile_count_semaphore; - // } - // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/); - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/); - return; - } - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); @@ -210,13 +190,16 @@ struct CollectiveMainloopFwd { Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_K); Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_K); + auto [m_block, bidh, bidb] = block_coord; + int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); + // Prepare the TMA loads uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - Tensor 
gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); @@ -235,6 +218,7 @@ struct CollectiveMainloopFwd { } } + int n_block_max = get_n_block_max(mainloop_params, m_block); int n_block = n_block_max - 1; int lane_predicate = cute::elect_one_sync(); @@ -246,7 +230,7 @@ struct CollectiveMainloopFwd { } // Wait for the MMA warpgroups to say that smem_q is ready - cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); if (lane_predicate) { shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); @@ -272,22 +256,14 @@ struct CollectiveMainloopFwd { ++smem_pipe_write_v; } } - if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { - // tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1); - } + scheduler.prefetch_next_work(scheduler_params, work_tile_info); if (lane_predicate) { pipeline_v.producer_acquire(smem_pipe_write_v); copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index())); ++smem_pipe_write_v; } - if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { - // printf("blockIdx.x = %d, tile_count_semaphore: %d\n", blockIdx.x, tile_count_semaphore); - // shared_storage.tile_count_semaphore = tile_count_semaphore; - } - // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/); - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/); - ++work_idx; + scheduler.broadcast_next_work(work_tile_info); } /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster @@ -307,36 +283,36 @@ struct CollectiveMainloopFwd { } CUTLASS_DEVICE void - scheduler_barrier_sync() { + warp_scheduler_barrier_sync() { if constexpr (UseSchedulerBarrier) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads, 3 + cutlass::canonical_warp_group_idx() /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/); } } CUTLASS_DEVICE void - scheduler_barrier_arrive() { + warp_scheduler_barrier_arrive() { if constexpr (!UseSchedulerBarrier) { return; } static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + (3 - cutlass::canonical_warp_group_idx()) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/); } else { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 2 ? 
3 + cutlass::canonical_warp_group_idx() + 1 : 3 + cutlass::canonical_warp_group_idx() + 1 - 3 /*id*/); - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 1 ? 3 + cutlass::canonical_warp_group_idx() + 2 : 3 + cutlass::canonical_warp_group_idx() + 2 - 3 /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/); } } CUTLASS_DEVICE void mma_init() { // Tell producer (warp 0) that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); if constexpr (!UseSchedulerBarrier) { return; } static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); if (cutlass::canonical_warp_group_idx() > 1) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 1 /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/); } if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) { if (cutlass::canonical_warp_group_idx() > 2) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 2 /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/); } } @@ -393,9 +369,9 @@ struct CollectiveMainloopFwd { Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read_k); - scheduler_barrier_sync(); + warp_scheduler_barrier_sync(); flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); - scheduler_barrier_arrive(); + warp_scheduler_barrier_arrive(); if (work_idx != 0) { int lane_predicate = cute::elect_one_sync(); if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) { @@ -443,12 +419,12 @@ struct CollectiveMainloopFwd { for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) { Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read_k); - scheduler_barrier_sync(); + warp_scheduler_barrier_sync(); flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); } consumer_wait(pipeline_v, smem_pipe_read_v); flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - scheduler_barrier_arrive(); + warp_scheduler_barrier_arrive(); warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read_k); // release K Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); @@ -472,12 +448,12 @@ struct CollectiveMainloopFwd { for (; n_block > 0; --n_block) { Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read_k); - scheduler_barrier_sync(); + warp_scheduler_barrier_sync(); flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); softmax.rescale_o(tOrO, 
scores_scale); consumer_wait(pipeline_v, smem_pipe_read_v); flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - scheduler_barrier_arrive(); + warp_scheduler_barrier_arrive(); warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read_k); // release K // auto scores_scale = softmax.template max(tSrS); @@ -491,7 +467,7 @@ struct CollectiveMainloopFwd { cute::copy(make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())), tOrP); } // Tell warp 0 that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); softmax.rescale_o(tOrO, scores_scale); consumer_wait(pipeline_v, smem_pipe_read_v); flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); diff --git a/hopper/named_barrier.hpp b/hopper/named_barrier.hpp new file mode 100644 index 0000000..ac5f616 --- /dev/null +++ b/hopper/named_barrier.hpp @@ -0,0 +1,23 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cutlass/arch/barrier.h" + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts +enum class FwdNamedBarriers { + QueryEmpty = 0, + ValueEmpty = 1, + TileCountSmemEmpty = 2, + TileCountSmemFull = 3, + WarpSchedulerWG1 = 4, + WarpSchedulerWG2 = 5, + WarpSchedulerWG3 = 6, +}; + +} // flash \ No newline at end of file diff --git a/hopper/setup.py b/hopper/setup.py index 6d7a4c3..35a074a 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -111,8 +111,11 @@ if not SKIP_CUDA_BUILD: sources = [ "flash_api.cpp", "flash_fwd_hdim64_fp16_sm90.cu", + "flash_fwd_hdim64_bf16_sm90.cu", "flash_fwd_hdim128_fp16_sm90.cu", + "flash_fwd_hdim128_bf16_sm90.cu", "flash_fwd_hdim256_fp16_sm90.cu", + "flash_fwd_hdim256_bf16_sm90.cu", "flash_bwd_hdim64_fp16_sm90.cu", "flash_bwd_hdim128_fp16_sm90.cu", "flash_bwd_hdim256_fp16_sm90.cu", diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 026d346..97852d4 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -131,15 +131,18 @@ def attention_ref( -@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["gqa"]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) @pytest.mark.parametrize("d", [64, 128, 256]) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [256]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -151,6 +154,8 @@ def attention_ref( (113, 211), (108, 256), (256, 512), + (384, 256), + (640, 128), (512, 256), (1024, 1024), (1023, 1024), @@ -160,7 +165,7 @@ def 
attention_ref( ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, dtype + seqlen_q, seqlen_k, d, causal, mha_type, dtype ): device = "cuda" # set seed @@ -168,16 +173,13 @@ def test_flash_attn_output( # batch_size = 40 # nheads = 16 batch_size = 9 - nheads = 4 + nheads = 6 + nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1) # batch_size = 1 # nheads = 1 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) - k = torch.randn( - batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True - ) - v = torch.randn( - batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True - ) + k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True) out, lse = flash_attn_func(q, k, v, causal=causal) out_ref, attn_ref = attention_ref( q, @@ -202,15 +204,15 @@ def test_flash_attn_output( # m = qk.amax(-1, keepdim=True) # s_tmp = torch.exp((qk - m) / math.sqrt(d)) # exp_sum = s_tmp.sum(-1) - qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float()) - lse_ref = torch.logsumexp(qk, dim=-1) + # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - if not causal: - print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") # breakpoint() # if d <= 128: @@ -248,5 +250,3 @@ def test_flash_attn_output( # assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() # assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() # assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() - - diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index d20f214..b27f179 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -1,112 +1,26 @@ /****************************************************************************** - * Copyright (c) 2024, Tri Dao. + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
******************************************************************************/ #pragma once #include "cutlass/fast_math.h" +#include "cutlass/arch/barrier.h" + +#include "named_barrier.hpp" namespace flash { /////////////////////////////////////////////////////////////////////////////// -class StaticPersistentTileSchedulerOld { - // - // Data members - // - -private: - int current_work_linear_idx_; - cutlass::FastDivmod const &m_block_divmod, &head_divmod; - int const total_blocks; - -public: - struct WorkTileInfo { - int M_idx = 0; - int H_idx = 0; - int B_idx = 0; - bool is_valid_tile = false; - - CUTLASS_HOST_DEVICE - bool - is_valid() const { - return is_valid_tile; - } - - CUTLASS_HOST_DEVICE - static WorkTileInfo - invalid_work_tile() { - return {-1, -1, -1, false}; - } - - }; - -public: - - CUTLASS_DEVICE explicit StaticPersistentTileSchedulerOld(cutlass::FastDivmod const &m_block_divmod_, - cutlass::FastDivmod const &head_divmod_, - int const total_blocks_) : - m_block_divmod(m_block_divmod_), head_divmod(head_divmod_), total_blocks(total_blocks_) { - - // MSVC requires protecting use of CUDA-specific nonstandard syntax, - // like blockIdx and gridDim, with __CUDA_ARCH__. -#if defined(__CUDA_ARCH__) - // current_work_linear_idx_ = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y; - current_work_linear_idx_ = blockIdx.x; -#else - CUTLASS_ASSERT(false && "This line should never be reached"); -#endif - } - - CUTLASS_DEVICE - WorkTileInfo - get_current_work() const { - return get_current_work_for_linear_idx(current_work_linear_idx_); - } - - CUTLASS_DEVICE - WorkTileInfo - get_current_work_for_linear_idx(int linear_idx) const { - if (linear_idx >= total_blocks) { - return WorkTileInfo::invalid_work_tile(); - } - - // Map worker's linear index into the CTA tiled problem shape to the corresponding MHB indices - int M_idx, H_idx, B_idx; - int quotient = m_block_divmod.divmod(M_idx, linear_idx); - B_idx = head_divmod.divmod(H_idx, quotient); - return {M_idx, H_idx, B_idx, true}; - } - - CUTLASS_DEVICE - void - // advance_to_next_work(int advance_count = 1) { - advance_to_next_work() { - // current_work_linear_idx_ += int(gridDim.x * gridDim.y * gridDim.z); - current_work_linear_idx_ += int(gridDim.x); - } - - CUTLASS_DEVICE - WorkTileInfo - fetch_next_work() { - WorkTileInfo new_work_tile_info; - advance_to_next_work(); - new_work_tile_info = get_current_work(); - return new_work_tile_info; - } - -}; - -/////////////////////////////////////////////////////////////////////////////// - -class SingleTileScheduler { +struct SingleTileScheduler { public: // Host side kernel arguments struct Arguments { int const num_blocks_m, num_head, num_batch; - int const* tile_count_semaphore = nullptr; + int* const tile_count_semaphore = nullptr; }; // Device side kernel params @@ -140,20 +54,30 @@ public: return {M_idx, H_idx, B_idx}; } - CUTLASS_DEVICE - WorkTileInfo - get_next_work(Params const& params) const { - return {-1, -1, -1, false}; - } - }; + CUTLASS_DEVICE + SingleTileScheduler(int* tile_count_smem_) { } + CUTLASS_DEVICE WorkTileInfo get_initial_work() const { return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true}; } + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + CUTLASS_DEVICE + void + broadcast_next_work(WorkTileInfo& current_work) const {} + + template CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, WorkTileInfo const& 
current_work) const { @@ -171,7 +95,7 @@ public: // Host side kernel arguments struct Arguments { int const num_blocks_m, num_head, num_batch; - int const* tile_count_semaphore = nullptr; + int* const tile_count_semaphore = nullptr; }; // Device side kernel params @@ -210,12 +134,28 @@ public: }; + CUTLASS_DEVICE + StaticPersistentTileScheduler(int* tile_count_smem_) {}; + CUTLASS_DEVICE WorkTileInfo get_initial_work() const { return {int(blockIdx.x)}; } + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + CUTLASS_DEVICE + void + broadcast_next_work(WorkTileInfo& current_work) const {} + + template CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, WorkTileInfo const& current_work) const { @@ -224,21 +164,25 @@ public: }; +template class DynamicPersistentTileScheduler { +protected: + int* const tile_count_smem; + public: // Host side kernel arguments struct Arguments { int const num_blocks_m, num_head, num_batch; - int const* tile_count_semaphore; + int* const tile_count_semaphore; }; // Device side kernel params struct Params { int const total_blocks; cutlass::FastDivmod const m_block_divmod, head_divmod; - int const* tile_count_semaphore; + int* const tile_count_semaphore; }; static Params @@ -253,25 +197,27 @@ public: return {uint32_t(num_sm)}; } - using WorkTileInfo = StaticPersistentTileScheduler::WorkTileInfo; - // struct WorkTileInfo { - // int tile_idx; + struct WorkTileInfo { + int tile_idx; - // CUTLASS_DEVICE - // bool - // is_valid(Params const& params) const { - // return tile_idx < params.total_blocks; - // } + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } - // CUTLASS_DEVICE - // cute::tuple - // get_block_coord(Params const& params) const { - // int m_block, bidh, bidb; - // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx)); - // return {m_block, bidh, bidb}; - // } + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int m_block, bidh, bidb; + bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx)); + return {m_block, bidh, bidb}; + } - // }; + }; + + CUTLASS_DEVICE + DynamicPersistentTileScheduler(int* tile_count_smem_) : tile_count_smem(tile_count_smem_) {}; CUTLASS_DEVICE WorkTileInfo @@ -279,12 +225,45 @@ public: return {int(blockIdx.x)}; } + CUTLASS_DEVICE + void + init_consumer() const { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + } + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { + if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { + current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); + } + } + + CUTLASS_DEVICE + void + broadcast_next_work(WorkTileInfo& current_work) const { + cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { + *tile_count_smem = current_work.tile_idx; + } + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + } + + template CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, WorkTileInfo const& current_work) const { - return {current_work.tile_idx + 
int(gridDim.x)}; + if constexpr (IsProducer) { + // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0 + return {__shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/)}; + } else { + cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + int tile_idx = *tile_count_smem; + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + return {tile_idx}; + } } }; -} // flash +} // flash \ No newline at end of file
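
The biggest structural change in this patch is the scheduler rework in hopper/tile_scheduler.hpp. With DynamicPersistentTileScheduler the grid is sized to the SM count (queried via cudaDevAttrMultiProcessorCount in run_flash_fwd), every CTA starts on tile blockIdx.x, and further tiles are pulled from the zero-initialized tile_count_semaphore with atomicAdd(...) + gridDim.x; the producer warp then broadcasts the fetched index to the MMA warp groups through shared memory and the TileCountSmemEmpty/TileCountSmemFull named barriers. The standalone CUDA sketch below shows only the work-distribution part of that scheme under those assumptions; the kernel name, the owner array, and the plain __syncthreads() broadcast are illustrative stand-ins, not code from the patch.

// Standalone sketch of the dynamic persistent-CTA scheduling (work distribution
// only). The real scheduler broadcasts the fetched index via smem + named
// barriers; a __syncthreads() broadcast is used here to keep the sketch simple.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void persistent_sched_demo(int* tile_counter, int* owner, int total_tiles) {
    __shared__ int next_tile;
    for (int tile_idx = blockIdx.x; tile_idx < total_tiles; ) {
        if (threadIdx.x == 0) {
            owner[tile_idx] = blockIdx.x;                        // stand-in for "process this tile"
            next_tile = atomicAdd(tile_counter, 1) + gridDim.x;  // same formula as prefetch_next_work
        }
        __syncthreads();   // every thread sees next_tile
        tile_idx = next_tile;
        __syncthreads();   // keep thread 0 from overwriting next_tile before all threads have read it
    }
}

int main() {
    int device = 0, num_sms = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device);  // grid size, as in run_flash_fwd

    const int total_tiles = 1000;
    int *tile_counter, *owner;
    cudaMalloc(&tile_counter, sizeof(int));
    cudaMalloc(&owner, total_tiles * sizeof(int));
    cudaMemset(tile_counter, 0, sizeof(int));  // semaphore starts at 0, like torch::zeros in mha_fwd

    persistent_sched_demo<<<num_sms, 128>>>(tile_counter, owner, total_tiles);
    cudaDeviceSynchronize();
    cudaFree(tile_counter); cudaFree(owner);
    return 0;
}

Each tile index is handed out exactly once: the CTAs take 0 .. gridDim.x - 1 up front, and the counter (starting at 0) then hands out gridDim.x, gridDim.x + 1, ... in acquisition order, so a CTA that finishes early simply grabs more work.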
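
get_block_coord in the persistent schedulers turns that flat tile index back into (m_block, bidh, bidb) with two FastDivmod operations. Below is a plain-C++ sketch of the same decode, assuming the row-major layout implied by those two divmods; ordinary / and % stand in for cutlass::FastDivmod, which precomputes the divisors on the host in to_underlying_arguments.

// Decode of a flat tile index into (m_block, bidh, bidb), mirroring
// DynamicPersistentTileScheduler::WorkTileInfo::get_block_coord. Assumed layout:
//   tile_idx = (bidb * num_head + bidh) * num_blocks_m + m_block.
#include <tuple>

std::tuple<int, int, int> get_block_coord(int tile_idx, int num_blocks_m, int num_head) {
    int m_block = tile_idx % num_blocks_m;   // remainder of m_block_divmod
    int rest    = tile_idx / num_blocks_m;   // quotient of m_block_divmod
    int bidh    = rest % num_head;           // remainder of head_divmod
    int bidb    = rest / num_head;           // quotient of head_divmod
    return {m_block, bidh, bidb};
}
// e.g. num_blocks_m = 4, num_head = 6: tile_idx 27 -> (m_block 3, bidh 0, bidb 1).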
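
MQA/GQA support comes from the new qhead_per_khead_divmod in the mainloop params: the K/V head used for the TMA loads is bidh_kv = divide(bidh), with the divisor ceil_div(num_q_heads, num_kv_heads) (exact division in practice, given the num_heads % num_heads_k == 0 check in mha_fwd). A minimal sketch of that index math with plain integer division; the function name and the printed example are illustrative only.

// Which K/V head a query head reads from, as computed by
// qhead_per_khead_divmod.divide(bidh) in the mainloop's load().
#include <cassert>
#include <cstdio>

int kv_head_for_q_head(int q_head, int num_q_heads, int num_kv_heads) {
    assert(num_q_heads % num_kv_heads == 0);   // enforced by the TORCH_CHECK in mha_fwd
    int qheads_per_khead = num_q_heads / num_kv_heads;
    return q_head / qheads_per_khead;
}

int main() {
    // GQA shape from the test: nheads = 6, nheads_kv = 2
    // -> query heads {0,1,2} read KV head 0, {3,4,5} read KV head 1.
    for (int h = 0; h < 6; ++h) {
        printf("q head %d -> kv head %d\n", h, kv_head_for_q_head(h, 6, 2));
    }
    return 0;
}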
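
run_mha_fwd_hdim128/hdim256 now gate the thread-block cluster on cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, i.e. the cluster is only enabled when the number of query tiles is even (the new (384, 256) and (640, 128) test shapes exercise the odd case). A small sketch of that predicate; treating 128 as the M tile size for every head dimension is an assumption here, the patch only shows the literal 128 inside the BOOL_SWITCH.

// The even-tile predicate used to decide whether to launch with a cluster.
bool use_cluster(int seqlen_q, int kBlockM = 128) {
    int num_m_tiles = (seqlen_q + kBlockM - 1) / kBlockM;   // cutlass::ceil_div
    return num_m_tiles % 2 == 0;
}
// use_cluster(512) -> true (4 tiles); use_cluster(640) -> false (5 tiles).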