[FA3] BF16 forward

This commit is contained in:
Tri Dao 2024-07-14 23:39:46 -07:00
parent 898dd4bbf2
commit 74b0761ff7
14 changed files with 278 additions and 269 deletions

@ -1 +1 @@
Subproject commit fa4f6359069bd4dd6fabd0cda2476dd8e72b3837
Subproject commit 756c351b4994854b2f8c6dded3821ebbb580876b

View File

@ -9,6 +9,7 @@
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "named_barrier.hpp"
#include "utils.h"
namespace flash {
@ -127,7 +128,7 @@ struct CollectiveEpilogueFwd {
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// Make sure all WGs have finished reading V
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0 /*id*/);
cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,

View File

@ -66,8 +66,6 @@ struct Flash_fwd_params : public Qkv_params {
// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
cutlass::FastDivmod head_divmod, m_block_divmod;
int total_blocks;
// The scaling factors for the kernel.
float scale_softmax;

View File

@ -99,8 +99,6 @@ void set_params_fprop(Flash_fwd_params &params,
params.d = d;
params.d_rounded = d_rounded;
params.head_divmod = cutlass::FastDivmod(int(h));
// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
@ -225,12 +223,22 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
// run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
// });
if (!params.is_e4m3) {
if (params.d == 64) {
run_mha_fwd_<cutlass::half_t, 64>(params, stream);
} else if (params.d == 128) {
run_mha_fwd_<cutlass::half_t, 128>(params, stream);
if (params.is_bf16) {
if (params.d == 64) {
run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
} else if (params.d == 128) {
run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
} else {
run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
}
} else {
run_mha_fwd_<cutlass::half_t, 256>(params, stream);
if (params.d == 64) {
run_mha_fwd_<cutlass::half_t, 64>(params, stream);
} else if (params.d == 128) {
run_mha_fwd_<cutlass::half_t, 128>(params, stream);
} else {
run_mha_fwd_<cutlass::half_t, 256>(params, stream);
}
}
} else {
// run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
@ -250,9 +258,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
auto q_dtype = q.dtype();
// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
TORCH_CHECK(q_dtype == torch::kFloat16,
"FlashAttention only support fp16 data type for now");
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type for now");
// TODO: will add e4m3 later
// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
// "FlashAttention only support fp16 and bf16 data type");
@ -278,10 +285,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const int head_size_og = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be postive");
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
TORCH_CHECK(num_heads == num_heads_k, "We do not support MQA/GQA yet");
TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now");
@ -345,7 +351,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
/*window_size_left=*/-1,
/*window_size_right=*/is_causal ? 0 : -1);
auto tile_count_semaphore = is_causal ? torch::full({1}, 132, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
if (seqlen_k > 0) {

View File

@ -0,0 +1,9 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
}

View File

@ -0,0 +1,9 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
}

View File

@ -0,0 +1,9 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
}

View File

@ -26,8 +26,7 @@ using namespace cute;
template <typename Ktraits, bool Is_causal, typename TileScheduler>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
compute_attn_ws(CUTE_GRID_CONSTANT Flash_fwd_params const params,
CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal>::Params const mainloop_params,
compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal>::Params const mainloop_params,
CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits>::Params const epilogue_params,
CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params
) {
@ -101,9 +100,6 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
if (warp_group_idx == 0) { // Producer
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 24 : 32>();
// cutlass::arch::warpgroup_reg_dealloc<56>();
// StaticPersistentTileScheduler scheduler{params.m_block_divmod, params.head_divmod, params.total_blocks};
// auto work_tile_info = scheduler.get_current_work();
TileScheduler scheduler;
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
@ -112,20 +108,22 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
int work_idx = 0;
// auto get_tile_count = [&] () {
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
// return shared_storage.tile_count_semaphore;
// };
TileScheduler scheduler(&shared_storage.tile_count_semaphore);
for (auto work_tile_info = scheduler.get_initial_work();
work_tile_info.is_valid(scheduler_params);
work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(scheduler_params, work_tile_info)) {
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [m_block, bidh, bidb] = block_coord;
// while (work_tile_info.is_valid()) {
// for (int tile_count = blockIdx.x; tile_count < params.total_blocks; tile_count = get_tile_count()) {
// for (int tile_count_semaphore = blockIdx.x; tile_count_semaphore < params.total_blocks; tile_count_semaphore = __shfl_sync(0xffffffff, tile_count_semaphore, 0)) {
for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) {
int tile_count_semaphore = 0;
collective_mainloop.load(params, mainloop_params, scheduler_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
shared_storage, work_tile_info, work_idx, tile_count_semaphore);
// ++work_idx;
// work_tile_info = scheduler.fetch_next_work();
int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block);
if (Is_causal && n_block_max <= 0) {
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
scheduler.broadcast_next_work(work_tile_info);
continue;
}
collective_mainloop.load(mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx);
++work_idx;
}
collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v);
}
@ -133,44 +131,31 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 240 : 160>();
// cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 224 : 160>();
TileScheduler scheduler(&shared_storage.tile_count_semaphore);
// Initialize matmul objects.
typename Ktraits::TiledMma1 tiled_mma1;
TileScheduler scheduler{};
PipelineState smem_pipe_read_k, smem_pipe_read_v;
// We don't need separate variables smem_pip_release_k and smem_pipe_release_v
// We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
// (like in Cutlass's gemm) because the read and release pipeline states are always the same.
auto get_tile_count = [&] () {
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
return shared_storage.tile_count_semaphore;
};
collective_mainloop.mma_init();
scheduler.init_consumer();
int work_idx = 0;
CUTLASS_PRAGMA_NO_UNROLL
// for (int work_idx = 0; work_idx * gridDim.x + blockIdx.x < params.total_blocks; ++work_idx) {
// for (int tile_count_semaphore = blockIdx.x, work_idx = 0; tile_count_semaphore < params.total_blocks; tile_count_semaphore = get_tile_count()) {
for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) {
for (auto work_tile_info = scheduler.get_initial_work();
work_tile_info.is_valid(scheduler_params);
work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
// Attention output (GEMM-II) accumulator.
Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
// int m_block;
// int bidh, bidb;
// // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, work_idx * gridDim.x + blockIdx.x));
// bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore));
// cute::tuple<int32_t, int32_t, int32_t> block_coord = {m_block, bidh, bidb};
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [m_block, bidh, bidb] = block_coord;
int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block);
if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
// Need sync to avoid the case where the producer issues 2 arrives before the consumer can issue 1 wait
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/);
collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord);
continue;
}
@ -178,15 +163,14 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v,
tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage);
// tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage);
// tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, 0, shared_storage);
collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
threadIdx.x - NumCopyThreads, block_coord);
++work_idx;
// work_tile_info = scheduler.fetch_next_work();
}
collective_epilogue.store_tail();
}
}
} // namespace flash

View File

@ -8,6 +8,7 @@
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/cluster_launch.hpp"
#include "static_switch.h"
@ -26,8 +27,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal>;
using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits>;
// using Scheduler = flash::SingleTileScheduler;
using Scheduler = flash::StaticPersistentTileScheduler;
using Scheduler = std::conditional_t<!Is_causal,
flash::StaticPersistentTileScheduler,
flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup>>;
// flash::SingleTileScheduler>;
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({
static_cast<Element const*>(params.q_ptr),
@ -51,32 +54,35 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b};
typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b, params.tile_count_semaphore};
typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);
// Get the ptr to kernel function.
void *kernel;
kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
// int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
// int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
// int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
// printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
int device;
cudaGetDevice(&device);
int multiprocessor_count;
cudaError status_ = cudaDeviceGetAttribute(
&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
params.m_block_divmod = cutlass::FastDivmod(num_blocks_m);
params.total_blocks = num_blocks_m * params.h * params.b;
// dim3 grid_dims(num_blocks_m, params.h, params.b);
// dim3 grid_dims(132);
dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, 132);
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(launch_params, kernel, params, mainloop_params, epilogue_params, scheduler_params);
// kernel<<<grid_dims, block_dims, smem_size, stream>>>(params, tma_load_Q, tma_load_K, tma_load_V, tma_store_O);
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, epilogue_params, scheduler_params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
@ -92,7 +98,10 @@ template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, !Is_causal ? 2 : 1, T>, Is_causal>(params, stream);
// Only use Cluster if number of tiles along seqlen_q is even
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
});
});
}
@ -100,6 +109,9 @@ template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, !Is_causal ? 2 : 1, T>, Is_causal>(params, stream);
// Only use Cluster if number of tiles along seqlen_q is even
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
});
});
}

View File

@ -14,6 +14,7 @@
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "named_barrier.hpp"
#include "utils.h"
namespace flash {
@ -108,6 +109,7 @@ struct CollectiveMainloopFwd {
struct Params {
ShapeQKV const shape_Q;
ShapeQKV const shape_K;
cutlass::FastDivmod qhead_per_khead_divmod;
TMA_Q tma_load_Q;
TMA_KV tma_load_K, tma_load_V;
float const softmax_scale_log2;
@ -137,7 +139,10 @@ struct CollectiveMainloopFwd {
SmemLayoutV{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {args.shape_Q, args.shape_K, tma_load_Q, tma_load_K, tma_load_V, args.softmax_scale_log2};
return {args.shape_Q, args.shape_K,
cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
tma_load_Q, tma_load_K, tma_load_V,
args.softmax_scale_log2};
}
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@ -162,46 +167,21 @@ struct CollectiveMainloopFwd {
return n_block_max;
}
template <typename FullParams, typename SchedulerParams, typename SharedStorage, typename WorkTileInfo>
template <typename Scheduler, typename SharedStorage>
CUTLASS_DEVICE void
load(FullParams const& params,
Params const& mainloop_params,
SchedulerParams const& scheduler_params,
load(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_write_k,
PipelineState& smem_pipe_write_v,
SharedStorage &shared_storage,
WorkTileInfo work_tile_info,
int& work_idx,
int& tile_count_semaphore
Scheduler& scheduler,
typename Scheduler::Params const& scheduler_params,
typename Scheduler::WorkTileInfo& work_tile_info,
cute::tuple<int32_t, int32_t, int32_t> block_coord,
int work_idx
) {
static constexpr int kBlockM = get<0>(TileShape_MNK{});
static constexpr int kBlockN = get<1>(TileShape_MNK{});
// int const m_block = work_tile_info.M_idx;
// int const bidh = work_tile_info.H_idx;
// int const bidb = work_tile_info.B_idx;
// int m_block;
// int bidh, bidb;
// bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore));
auto [m_block, bidh, bidb] = work_tile_info.get_block_coord(scheduler_params);
// if (threadIdx.x == 0) { printf("producer, blockIdx.x = %d, bidb = %d, bidh = %d, m_block = %d\n", blockIdx.x, bidb, bidh, m_block); }
int n_block_max = get_n_block_max(mainloop_params, m_block);
if (Is_causal && n_block_max <= 0) {
// Need sync to avoid the case where the producer issues 2 arrives before the consumer can issue 1 wait
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/);
// if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
// tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1);
// shared_storage.tile_count_semaphore = tile_count_semaphore;
// }
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
return;
}
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
@ -210,13 +190,16 @@ struct CollectiveMainloopFwd {
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_K);
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_K);
auto [m_block, bidh, bidb] = block_coord;
int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
// Prepare the TMA loads
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
@ -235,6 +218,7 @@ struct CollectiveMainloopFwd {
}
}
int n_block_max = get_n_block_max(mainloop_params, m_block);
int n_block = n_block_max - 1;
int lane_predicate = cute::elect_one_sync();
@ -246,7 +230,7 @@ struct CollectiveMainloopFwd {
}
// Wait for the MMA warpgroups to say that smem_q is ready
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
if (lane_predicate) {
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
@ -272,22 +256,14 @@ struct CollectiveMainloopFwd {
++smem_pipe_write_v;
}
}
if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
// tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1);
}
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
if (lane_predicate) {
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
// printf("blockIdx.x = %d, tile_count_semaphore: %d\n", blockIdx.x, tile_count_semaphore);
// shared_storage.tile_count_semaphore = tile_count_semaphore;
}
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
++work_idx;
scheduler.broadcast_next_work(work_tile_info);
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
@ -307,36 +283,36 @@ struct CollectiveMainloopFwd {
}
CUTLASS_DEVICE void
scheduler_barrier_sync() {
warp_scheduler_barrier_sync() {
if constexpr (UseSchedulerBarrier) {
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 3 + cutlass::canonical_warp_group_idx() /*id*/);
cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
}
}
CUTLASS_DEVICE void
scheduler_barrier_arrive() {
warp_scheduler_barrier_arrive() {
if constexpr (!UseSchedulerBarrier) { return; }
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
} else {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 2 ? 3 + cutlass::canonical_warp_group_idx() + 1 : 3 + cutlass::canonical_warp_group_idx() + 1 - 3 /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 1 ? 3 + cutlass::canonical_warp_group_idx() + 2 : 3 + cutlass::canonical_warp_group_idx() + 2 - 3 /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
}
}
CUTLASS_DEVICE void
mma_init() {
// Tell producer (warp 0) that smem_q is ready
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
if constexpr (!UseSchedulerBarrier) { return; }
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
if (cutlass::canonical_warp_group_idx() > 1) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 1 /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
}
if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
if (cutlass::canonical_warp_group_idx() > 2) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 2 /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
}
}
@ -393,9 +369,9 @@ struct CollectiveMainloopFwd {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
scheduler_barrier_sync();
warp_scheduler_barrier_sync();
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
scheduler_barrier_arrive();
warp_scheduler_barrier_arrive();
if (work_idx != 0) {
int lane_predicate = cute::elect_one_sync();
if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
@ -443,12 +419,12 @@ struct CollectiveMainloopFwd {
for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
scheduler_barrier_sync();
warp_scheduler_barrier_sync();
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); }
consumer_wait(pipeline_v, smem_pipe_read_v);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
scheduler_barrier_arrive();
warp_scheduler_barrier_arrive();
warpgroup_wait<1>();
pipeline_k.consumer_release(smem_pipe_read_k); // release K
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
@ -472,12 +448,12 @@ struct CollectiveMainloopFwd {
for (; n_block > 0; --n_block) {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
scheduler_barrier_sync();
warp_scheduler_barrier_sync();
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
scheduler_barrier_arrive();
warp_scheduler_barrier_arrive();
warpgroup_wait<1>();
pipeline_k.consumer_release(smem_pipe_read_k); // release K
// auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
@ -491,7 +467,7 @@ struct CollectiveMainloopFwd {
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
}
// Tell warp 0 that smem_q is ready
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);

23
hopper/named_barrier.hpp Normal file
View File

@ -0,0 +1,23 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cutlass/arch/barrier.h"
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Enumerates the reserved named barriers to avoid potential conflicts
enum class FwdNamedBarriers {
QueryEmpty = 0,
ValueEmpty = 1,
TileCountSmemEmpty = 2,
TileCountSmemFull = 3,
WarpSchedulerWG1 = 4,
WarpSchedulerWG2 = 5,
WarpSchedulerWG3 = 6,
};
} // flash

View File

@ -111,8 +111,11 @@ if not SKIP_CUDA_BUILD:
sources = [
"flash_api.cpp",
"flash_fwd_hdim64_fp16_sm90.cu",
"flash_fwd_hdim64_bf16_sm90.cu",
"flash_fwd_hdim128_fp16_sm90.cu",
"flash_fwd_hdim128_bf16_sm90.cu",
"flash_fwd_hdim256_fp16_sm90.cu",
"flash_fwd_hdim256_bf16_sm90.cu",
"flash_bwd_hdim64_fp16_sm90.cu",
"flash_bwd_hdim128_fp16_sm90.cu",
"flash_bwd_hdim256_fp16_sm90.cu",

View File

@ -131,15 +131,18 @@ def attention_ref(
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["gqa"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
@pytest.mark.parametrize("d", [64, 128, 256])
# @pytest.mark.parametrize("d", [128])
# @pytest.mark.parametrize("d", [256])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
@ -151,6 +154,8 @@ def attention_ref(
(113, 211),
(108, 256),
(256, 512),
(384, 256),
(640, 128),
(512, 256),
(1024, 1024),
(1023, 1024),
@ -160,7 +165,7 @@ def attention_ref(
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
def test_flash_attn_output(
seqlen_q, seqlen_k, d, causal, dtype
seqlen_q, seqlen_k, d, causal, mha_type, dtype
):
device = "cuda"
# set seed
@ -168,16 +173,13 @@ def test_flash_attn_output(
# batch_size = 40
# nheads = 16
batch_size = 9
nheads = 4
nheads = 6
nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
# batch_size = 1
# nheads = 1
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(
batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
)
v = torch.randn(
batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
)
k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
out, lse = flash_attn_func(q, k, v, causal=causal)
out_ref, attn_ref = attention_ref(
q,
@ -202,15 +204,15 @@ def test_flash_attn_output(
# m = qk.amax(-1, keepdim=True)
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
# exp_sum = s_tmp.sum(-1)
qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
lse_ref = torch.logsumexp(qk, dim=-1)
# qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
# lse_ref = torch.logsumexp(qk, dim=-1)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if not causal:
print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
# if not causal:
# print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
# breakpoint()
# if d <= 128:
@ -248,5 +250,3 @@ def test_flash_attn_output(
# assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
# assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
# assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()

View File

@ -1,112 +1,26 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cutlass/fast_math.h"
#include "cutlass/arch/barrier.h"
#include "named_barrier.hpp"
namespace flash {
///////////////////////////////////////////////////////////////////////////////
class StaticPersistentTileSchedulerOld {
//
// Data members
//
private:
int current_work_linear_idx_;
cutlass::FastDivmod const &m_block_divmod, &head_divmod;
int const total_blocks;
public:
struct WorkTileInfo {
int M_idx = 0;
int H_idx = 0;
int B_idx = 0;
bool is_valid_tile = false;
CUTLASS_HOST_DEVICE
bool
is_valid() const {
return is_valid_tile;
}
CUTLASS_HOST_DEVICE
static WorkTileInfo
invalid_work_tile() {
return {-1, -1, -1, false};
}
};
public:
CUTLASS_DEVICE explicit StaticPersistentTileSchedulerOld(cutlass::FastDivmod const &m_block_divmod_,
cutlass::FastDivmod const &head_divmod_,
int const total_blocks_) :
m_block_divmod(m_block_divmod_), head_divmod(head_divmod_), total_blocks(total_blocks_) {
// MSVC requires protecting use of CUDA-specific nonstandard syntax,
// like blockIdx and gridDim, with __CUDA_ARCH__.
#if defined(__CUDA_ARCH__)
// current_work_linear_idx_ = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
current_work_linear_idx_ = blockIdx.x;
#else
CUTLASS_ASSERT(false && "This line should never be reached");
#endif
}
CUTLASS_DEVICE
WorkTileInfo
get_current_work() const {
return get_current_work_for_linear_idx(current_work_linear_idx_);
}
CUTLASS_DEVICE
WorkTileInfo
get_current_work_for_linear_idx(int linear_idx) const {
if (linear_idx >= total_blocks) {
return WorkTileInfo::invalid_work_tile();
}
// Map worker's linear index into the CTA tiled problem shape to the corresponding MHB indices
int M_idx, H_idx, B_idx;
int quotient = m_block_divmod.divmod(M_idx, linear_idx);
B_idx = head_divmod.divmod(H_idx, quotient);
return {M_idx, H_idx, B_idx, true};
}
CUTLASS_DEVICE
void
// advance_to_next_work(int advance_count = 1) {
advance_to_next_work() {
// current_work_linear_idx_ += int(gridDim.x * gridDim.y * gridDim.z);
current_work_linear_idx_ += int(gridDim.x);
}
CUTLASS_DEVICE
WorkTileInfo
fetch_next_work() {
WorkTileInfo new_work_tile_info;
advance_to_next_work();
new_work_tile_info = get_current_work();
return new_work_tile_info;
}
};
///////////////////////////////////////////////////////////////////////////////
class SingleTileScheduler {
struct SingleTileScheduler {
public:
// Host side kernel arguments
struct Arguments {
int const num_blocks_m, num_head, num_batch;
int const* tile_count_semaphore = nullptr;
int* const tile_count_semaphore = nullptr;
};
// Device side kernel params
@ -140,20 +54,30 @@ public:
return {M_idx, H_idx, B_idx};
}
CUTLASS_DEVICE
WorkTileInfo
get_next_work(Params const& params) const {
return {-1, -1, -1, false};
}
};
CUTLASS_DEVICE
SingleTileScheduler(int* tile_count_smem_) { }
CUTLASS_DEVICE
WorkTileInfo
get_initial_work() const {
return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
}
CUTLASS_DEVICE
void
init_consumer() const {}
CUTLASS_DEVICE
void
prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
CUTLASS_DEVICE
void
broadcast_next_work(WorkTileInfo& current_work) const {}
template<bool IsProducer=false>
CUTLASS_DEVICE
WorkTileInfo
get_next_work(Params const& params, WorkTileInfo const& current_work) const {
@ -171,7 +95,7 @@ public:
// Host side kernel arguments
struct Arguments {
int const num_blocks_m, num_head, num_batch;
int const* tile_count_semaphore = nullptr;
int* const tile_count_semaphore = nullptr;
};
// Device side kernel params
@ -210,12 +134,28 @@ public:
};
CUTLASS_DEVICE
StaticPersistentTileScheduler(int* tile_count_smem_) {};
CUTLASS_DEVICE
WorkTileInfo
get_initial_work() const {
return {int(blockIdx.x)};
}
CUTLASS_DEVICE
void
init_consumer() const {}
CUTLASS_DEVICE
void
prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
CUTLASS_DEVICE
void
broadcast_next_work(WorkTileInfo& current_work) const {}
template<bool IsProducer=false>
CUTLASS_DEVICE
WorkTileInfo
get_next_work(Params const& params, WorkTileInfo const& current_work) const {
@ -224,21 +164,25 @@ public:
};
template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup>
class DynamicPersistentTileScheduler {
protected:
int* const tile_count_smem;
public:
// Host side kernel arguments
struct Arguments {
int const num_blocks_m, num_head, num_batch;
int const* tile_count_semaphore;
int* const tile_count_semaphore;
};
// Device side kernel params
struct Params {
int const total_blocks;
cutlass::FastDivmod const m_block_divmod, head_divmod;
int const* tile_count_semaphore;
int* const tile_count_semaphore;
};
static Params
@ -253,25 +197,27 @@ public:
return {uint32_t(num_sm)};
}
using WorkTileInfo = StaticPersistentTileScheduler::WorkTileInfo;
// struct WorkTileInfo {
// int tile_idx;
struct WorkTileInfo {
int tile_idx;
// CUTLASS_DEVICE
// bool
// is_valid(Params const& params) const {
// return tile_idx < params.total_blocks;
// }
CUTLASS_DEVICE
bool
is_valid(Params const& params) const {
return tile_idx < params.total_blocks;
}
// CUTLASS_DEVICE
// cute::tuple<int32_t, int32_t, int32_t>
// get_block_coord(Params const& params) const {
// int m_block, bidh, bidb;
// bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
// return {m_block, bidh, bidb};
// }
CUTLASS_DEVICE
cute::tuple<int32_t, int32_t, int32_t>
get_block_coord(Params const& params) const {
int m_block, bidh, bidb;
bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
return {m_block, bidh, bidb};
}
// };
};
CUTLASS_DEVICE
DynamicPersistentTileScheduler(int* tile_count_smem_) : tile_count_smem(tile_count_smem_) {};
CUTLASS_DEVICE
WorkTileInfo
@ -279,12 +225,45 @@ public:
return {int(blockIdx.x)};
}
CUTLASS_DEVICE
void
init_consumer() const {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
}
CUTLASS_DEVICE
void
prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
}
}
CUTLASS_DEVICE
void
broadcast_next_work(WorkTileInfo& current_work) const {
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
*tile_count_smem = current_work.tile_idx;
}
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
}
template<bool IsProducer=false>
CUTLASS_DEVICE
WorkTileInfo
get_next_work(Params const& params, WorkTileInfo const& current_work) const {
return {current_work.tile_idx + int(gridDim.x)};
if constexpr (IsProducer) {
// thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0
return {__shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/)};
} else {
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
int tile_idx = *tile_count_smem;
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
return {tile_idx};
}
}
};
} // flash
} // flash