[FA3] BF16 forward
This commit is contained in:
parent
898dd4bbf2
commit
74b0761ff7
@ -1 +1 @@
|
||||
Subproject commit fa4f6359069bd4dd6fabd0cda2476dd8e72b3837
|
||||
Subproject commit 756c351b4994854b2f8c6dded3821ebbb580876b
|
||||
@ -9,6 +9,7 @@
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "named_barrier.hpp"
|
||||
#include "utils.h"
|
||||
|
||||
namespace flash {
|
||||
@ -127,7 +128,7 @@ struct CollectiveEpilogueFwd {
|
||||
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
// Make sure all WGs have finished reading V
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0 /*id*/);
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);
|
||||
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
||||
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
|
||||
|
||||
@ -66,8 +66,6 @@ struct Flash_fwd_params : public Qkv_params {
|
||||
|
||||
// The dimensions.
|
||||
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
|
||||
cutlass::FastDivmod head_divmod, m_block_divmod;
|
||||
int total_blocks;
|
||||
|
||||
// The scaling factors for the kernel.
|
||||
float scale_softmax;
|
||||
|
||||
@ -99,8 +99,6 @@ void set_params_fprop(Flash_fwd_params ¶ms,
|
||||
params.d = d;
|
||||
params.d_rounded = d_rounded;
|
||||
|
||||
params.head_divmod = cutlass::FastDivmod(int(h));
|
||||
|
||||
// Set the different scale values.
|
||||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
|
||||
@ -225,12 +223,22 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split
|
||||
// run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
|
||||
// });
|
||||
if (!params.is_e4m3) {
|
||||
if (params.d == 64) {
|
||||
run_mha_fwd_<cutlass::half_t, 64>(params, stream);
|
||||
} else if (params.d == 128) {
|
||||
run_mha_fwd_<cutlass::half_t, 128>(params, stream);
|
||||
if (params.is_bf16) {
|
||||
if (params.d == 64) {
|
||||
run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
|
||||
} else if (params.d == 128) {
|
||||
run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
|
||||
} else {
|
||||
run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_mha_fwd_<cutlass::half_t, 256>(params, stream);
|
||||
if (params.d == 64) {
|
||||
run_mha_fwd_<cutlass::half_t, 64>(params, stream);
|
||||
} else if (params.d == 128) {
|
||||
run_mha_fwd_<cutlass::half_t, 128>(params, stream);
|
||||
} else {
|
||||
run_mha_fwd_<cutlass::half_t, 256>(params, stream);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
|
||||
@ -250,9 +258,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16,
|
||||
"FlashAttention only support fp16 data type for now");
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type for now");
|
||||
// TODO: will add e4m3 later
|
||||
// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
|
||||
// "FlashAttention only support fp16 and bf16 data type");
|
||||
@ -278,10 +285,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const int head_size_og = sizes[3];
|
||||
const int seqlen_k = k.size(1);
|
||||
const int num_heads_k = k.size(2);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be postive");
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
TORCH_CHECK(num_heads == num_heads_k, "We do not support MQA/GQA yet");
|
||||
|
||||
TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now");
|
||||
|
||||
@ -345,7 +351,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
/*window_size_left=*/-1,
|
||||
/*window_size_right=*/is_causal ? 0 : -1);
|
||||
|
||||
auto tile_count_semaphore = is_causal ? torch::full({1}, 132, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
|
||||
auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
|
||||
params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
|
||||
|
||||
if (seqlen_k > 0) {
|
||||
|
||||
9
hopper/flash_fwd_hdim128_bf16_sm90.cu
Normal file
9
hopper/flash_fwd_hdim128_bf16_sm90.cu
Normal file
@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2024, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
9
hopper/flash_fwd_hdim256_bf16_sm90.cu
Normal file
9
hopper/flash_fwd_hdim256_bf16_sm90.cu
Normal file
@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2024, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
9
hopper/flash_fwd_hdim64_bf16_sm90.cu
Normal file
9
hopper/flash_fwd_hdim64_bf16_sm90.cu
Normal file
@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2024, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
@ -26,8 +26,7 @@ using namespace cute;
|
||||
|
||||
template <typename Ktraits, bool Is_causal, typename TileScheduler>
|
||||
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
|
||||
compute_attn_ws(CUTE_GRID_CONSTANT Flash_fwd_params const params,
|
||||
CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal>::Params const mainloop_params,
|
||||
compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal>::Params const mainloop_params,
|
||||
CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits>::Params const epilogue_params,
|
||||
CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params
|
||||
) {
|
||||
@ -101,9 +100,6 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
|
||||
if (warp_group_idx == 0) { // Producer
|
||||
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 24 : 32>();
|
||||
// cutlass::arch::warpgroup_reg_dealloc<56>();
|
||||
// StaticPersistentTileScheduler scheduler{params.m_block_divmod, params.head_divmod, params.total_blocks};
|
||||
// auto work_tile_info = scheduler.get_current_work();
|
||||
TileScheduler scheduler;
|
||||
|
||||
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
|
||||
if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
|
||||
@ -112,20 +108,22 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
|
||||
|
||||
int work_idx = 0;
|
||||
|
||||
// auto get_tile_count = [&] () {
|
||||
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
|
||||
// return shared_storage.tile_count_semaphore;
|
||||
// };
|
||||
TileScheduler scheduler(&shared_storage.tile_count_semaphore);
|
||||
for (auto work_tile_info = scheduler.get_initial_work();
|
||||
work_tile_info.is_valid(scheduler_params);
|
||||
work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(scheduler_params, work_tile_info)) {
|
||||
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
|
||||
auto [m_block, bidh, bidb] = block_coord;
|
||||
|
||||
// while (work_tile_info.is_valid()) {
|
||||
// for (int tile_count = blockIdx.x; tile_count < params.total_blocks; tile_count = get_tile_count()) {
|
||||
// for (int tile_count_semaphore = blockIdx.x; tile_count_semaphore < params.total_blocks; tile_count_semaphore = __shfl_sync(0xffffffff, tile_count_semaphore, 0)) {
|
||||
for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) {
|
||||
int tile_count_semaphore = 0;
|
||||
collective_mainloop.load(params, mainloop_params, scheduler_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
|
||||
shared_storage, work_tile_info, work_idx, tile_count_semaphore);
|
||||
// ++work_idx;
|
||||
// work_tile_info = scheduler.fetch_next_work();
|
||||
int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block);
|
||||
if (Is_causal && n_block_max <= 0) {
|
||||
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
|
||||
scheduler.broadcast_next_work(work_tile_info);
|
||||
continue;
|
||||
}
|
||||
collective_mainloop.load(mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
|
||||
shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx);
|
||||
++work_idx;
|
||||
}
|
||||
collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v);
|
||||
}
|
||||
@ -133,44 +131,31 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
|
||||
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 240 : 160>();
|
||||
// cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 224 : 160>();
|
||||
|
||||
TileScheduler scheduler(&shared_storage.tile_count_semaphore);
|
||||
// Initialize matmul objects.
|
||||
typename Ktraits::TiledMma1 tiled_mma1;
|
||||
|
||||
TileScheduler scheduler{};
|
||||
|
||||
PipelineState smem_pipe_read_k, smem_pipe_read_v;
|
||||
// We don't need separate variables smem_pip_release_k and smem_pipe_release_v
|
||||
// We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
|
||||
// (like in Cutlass's gemm) because the read and release pipeline states are always the same.
|
||||
|
||||
auto get_tile_count = [&] () {
|
||||
// cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
|
||||
return shared_storage.tile_count_semaphore;
|
||||
};
|
||||
|
||||
collective_mainloop.mma_init();
|
||||
scheduler.init_consumer();
|
||||
|
||||
int work_idx = 0;
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
// for (int work_idx = 0; work_idx * gridDim.x + blockIdx.x < params.total_blocks; ++work_idx) {
|
||||
// for (int tile_count_semaphore = blockIdx.x, work_idx = 0; tile_count_semaphore < params.total_blocks; tile_count_semaphore = get_tile_count()) {
|
||||
for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) {
|
||||
for (auto work_tile_info = scheduler.get_initial_work();
|
||||
work_tile_info.is_valid(scheduler_params);
|
||||
work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
|
||||
// Attention output (GEMM-II) accumulator.
|
||||
Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
|
||||
flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
|
||||
|
||||
// int m_block;
|
||||
// int bidh, bidb;
|
||||
// // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, work_idx * gridDim.x + blockIdx.x));
|
||||
// bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore));
|
||||
// cute::tuple<int32_t, int32_t, int32_t> block_coord = {m_block, bidh, bidb};
|
||||
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
|
||||
auto [m_block, bidh, bidb] = block_coord;
|
||||
|
||||
int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block);
|
||||
if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
|
||||
// Need sync to avoid the case where the producer issues 2 arrives before the consumer can issue 1 wait
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/);
|
||||
collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord);
|
||||
continue;
|
||||
}
|
||||
@ -178,15 +163,14 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
|
||||
collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v,
|
||||
tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage);
|
||||
// tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage);
|
||||
// tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, 0, shared_storage);
|
||||
collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
|
||||
threadIdx.x - NumCopyThreads, block_coord);
|
||||
|
||||
++work_idx;
|
||||
// work_tile_info = scheduler.fetch_next_work();
|
||||
}
|
||||
collective_epilogue.store_tail();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
} // namespace flash
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
|
||||
#include "static_switch.h"
|
||||
@ -26,8 +27,10 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
|
||||
using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal>;
|
||||
using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits>;
|
||||
// using Scheduler = flash::SingleTileScheduler;
|
||||
using Scheduler = flash::StaticPersistentTileScheduler;
|
||||
using Scheduler = std::conditional_t<!Is_causal,
|
||||
flash::StaticPersistentTileScheduler,
|
||||
flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup>>;
|
||||
// flash::SingleTileScheduler>;
|
||||
typename CollectiveMainloop::Params mainloop_params =
|
||||
CollectiveMainloop::to_underlying_arguments({
|
||||
static_cast<Element const*>(params.q_ptr),
|
||||
@ -51,32 +54,35 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
|
||||
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
|
||||
typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b};
|
||||
typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b, params.tile_count_semaphore};
|
||||
typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);
|
||||
|
||||
// Get the ptr to kernel function.
|
||||
void *kernel;
|
||||
kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler>;
|
||||
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
|
||||
int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
|
||||
int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
|
||||
int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
|
||||
// int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
|
||||
// int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
|
||||
// int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
|
||||
// printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
|
||||
if (smem_size >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int multiprocessor_count;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
|
||||
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
|
||||
params.m_block_divmod = cutlass::FastDivmod(num_blocks_m);
|
||||
params.total_blocks = num_blocks_m * params.h * params.b;
|
||||
// dim3 grid_dims(num_blocks_m, params.h, params.b);
|
||||
// dim3 grid_dims(132);
|
||||
dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, 132);
|
||||
dim3 block_dims(ctaSize);
|
||||
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
|
||||
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
|
||||
cutlass::launch_kernel_on_cluster(launch_params, kernel, params, mainloop_params, epilogue_params, scheduler_params);
|
||||
// kernel<<<grid_dims, block_dims, smem_size, stream>>>(params, tma_load_Q, tma_load_K, tma_load_V, tma_store_O);
|
||||
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, epilogue_params, scheduler_params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
@ -92,7 +98,10 @@ template<typename T>
|
||||
void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 128;
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, !Is_causal ? 2 : 1, T>, Is_causal>(params, stream);
|
||||
// Only use Cluster if number of tiles along seqlen_q is even
|
||||
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@ -100,6 +109,9 @@ template<typename T>
|
||||
void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 256;
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, !Is_causal ? 2 : 1, T>, Is_causal>(params, stream);
|
||||
// Only use Cluster if number of tiles along seqlen_q is even
|
||||
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "named_barrier.hpp"
|
||||
#include "utils.h"
|
||||
|
||||
namespace flash {
|
||||
@ -108,6 +109,7 @@ struct CollectiveMainloopFwd {
|
||||
struct Params {
|
||||
ShapeQKV const shape_Q;
|
||||
ShapeQKV const shape_K;
|
||||
cutlass::FastDivmod qhead_per_khead_divmod;
|
||||
TMA_Q tma_load_Q;
|
||||
TMA_KV tma_load_K, tma_load_V;
|
||||
float const softmax_scale_log2;
|
||||
@ -137,7 +139,10 @@ struct CollectiveMainloopFwd {
|
||||
SmemLayoutV{}(_, _, _0{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
return {args.shape_Q, args.shape_K, tma_load_Q, tma_load_K, tma_load_V, args.softmax_scale_log2};
|
||||
return {args.shape_Q, args.shape_K,
|
||||
cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
|
||||
tma_load_Q, tma_load_K, tma_load_V,
|
||||
args.softmax_scale_log2};
|
||||
}
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
@ -162,46 +167,21 @@ struct CollectiveMainloopFwd {
|
||||
return n_block_max;
|
||||
}
|
||||
|
||||
template <typename FullParams, typename SchedulerParams, typename SharedStorage, typename WorkTileInfo>
|
||||
template <typename Scheduler, typename SharedStorage>
|
||||
CUTLASS_DEVICE void
|
||||
load(FullParams const& params,
|
||||
Params const& mainloop_params,
|
||||
SchedulerParams const& scheduler_params,
|
||||
load(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_k,
|
||||
MainloopPipeline pipeline_v,
|
||||
PipelineState& smem_pipe_write_k,
|
||||
PipelineState& smem_pipe_write_v,
|
||||
SharedStorage &shared_storage,
|
||||
WorkTileInfo work_tile_info,
|
||||
int& work_idx,
|
||||
int& tile_count_semaphore
|
||||
Scheduler& scheduler,
|
||||
typename Scheduler::Params const& scheduler_params,
|
||||
typename Scheduler::WorkTileInfo& work_tile_info,
|
||||
cute::tuple<int32_t, int32_t, int32_t> block_coord,
|
||||
int work_idx
|
||||
) {
|
||||
|
||||
static constexpr int kBlockM = get<0>(TileShape_MNK{});
|
||||
static constexpr int kBlockN = get<1>(TileShape_MNK{});
|
||||
|
||||
// int const m_block = work_tile_info.M_idx;
|
||||
// int const bidh = work_tile_info.H_idx;
|
||||
// int const bidb = work_tile_info.B_idx;
|
||||
// int m_block;
|
||||
// int bidh, bidb;
|
||||
// bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore));
|
||||
auto [m_block, bidh, bidb] = work_tile_info.get_block_coord(scheduler_params);
|
||||
// if (threadIdx.x == 0) { printf("producer, blockIdx.x = %d, bidb = %d, bidh = %d, m_block = %d\n", blockIdx.x, bidb, bidh, m_block); }
|
||||
|
||||
int n_block_max = get_n_block_max(mainloop_params, m_block);
|
||||
if (Is_causal && n_block_max <= 0) {
|
||||
// Need sync to avoid the case where the producer issues 2 arrives before the consumer can issue 1 wait
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/);
|
||||
// if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
|
||||
// tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1);
|
||||
// shared_storage.tile_count_semaphore = tile_count_semaphore;
|
||||
// }
|
||||
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
|
||||
return;
|
||||
}
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
|
||||
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
|
||||
@ -210,13 +190,16 @@ struct CollectiveMainloopFwd {
|
||||
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_K);
|
||||
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_K);
|
||||
|
||||
auto [m_block, bidh, bidb] = block_coord;
|
||||
int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
|
||||
|
||||
// Prepare the TMA loads
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
|
||||
Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
|
||||
Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
|
||||
Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
|
||||
Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _)
|
||||
|
||||
Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
|
||||
Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
|
||||
@ -235,6 +218,7 @@ struct CollectiveMainloopFwd {
|
||||
}
|
||||
}
|
||||
|
||||
int n_block_max = get_n_block_max(mainloop_params, m_block);
|
||||
int n_block = n_block_max - 1;
|
||||
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
@ -246,7 +230,7 @@ struct CollectiveMainloopFwd {
|
||||
}
|
||||
|
||||
// Wait for the MMA warpgroups to say that smem_q is ready
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
|
||||
|
||||
if (lane_predicate) {
|
||||
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
|
||||
@ -272,22 +256,14 @@ struct CollectiveMainloopFwd {
|
||||
++smem_pipe_write_v;
|
||||
}
|
||||
}
|
||||
if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
|
||||
// tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1);
|
||||
}
|
||||
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
|
||||
if (lane_predicate) {
|
||||
pipeline_v.producer_acquire(smem_pipe_write_v);
|
||||
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
|
||||
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
|
||||
++smem_pipe_write_v;
|
||||
}
|
||||
if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
|
||||
// printf("blockIdx.x = %d, tile_count_semaphore: %d\n", blockIdx.x, tile_count_semaphore);
|
||||
// shared_storage.tile_count_semaphore = tile_count_semaphore;
|
||||
}
|
||||
// cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
|
||||
++work_idx;
|
||||
scheduler.broadcast_next_work(work_tile_info);
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
@ -307,36 +283,36 @@ struct CollectiveMainloopFwd {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
scheduler_barrier_sync() {
|
||||
warp_scheduler_barrier_sync() {
|
||||
if constexpr (UseSchedulerBarrier) {
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 3 + cutlass::canonical_warp_group_idx() /*id*/);
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
scheduler_barrier_arrive() {
|
||||
warp_scheduler_barrier_arrive() {
|
||||
if constexpr (!UseSchedulerBarrier) { return; }
|
||||
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
|
||||
if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
|
||||
} else {
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 2 ? 3 + cutlass::canonical_warp_group_idx() + 1 : 3 + cutlass::canonical_warp_group_idx() + 1 - 3 /*id*/);
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 1 ? 3 + cutlass::canonical_warp_group_idx() + 2 : 3 + cutlass::canonical_warp_group_idx() + 2 - 3 /*id*/);
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
mma_init() {
|
||||
// Tell producer (warp 0) that smem_q is ready
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
|
||||
if constexpr (!UseSchedulerBarrier) { return; }
|
||||
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
|
||||
if (cutlass::canonical_warp_group_idx() > 1) {
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 1 /*id*/);
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
|
||||
}
|
||||
if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
|
||||
if (cutlass::canonical_warp_group_idx() > 2) {
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 2 /*id*/);
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
|
||||
}
|
||||
}
|
||||
|
||||
@ -393,9 +369,9 @@ struct CollectiveMainloopFwd {
|
||||
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
||||
consumer_wait(pipeline_k, smem_pipe_read_k);
|
||||
scheduler_barrier_sync();
|
||||
warp_scheduler_barrier_sync();
|
||||
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
|
||||
scheduler_barrier_arrive();
|
||||
warp_scheduler_barrier_arrive();
|
||||
if (work_idx != 0) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
|
||||
@ -443,12 +419,12 @@ struct CollectiveMainloopFwd {
|
||||
for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) {
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
||||
consumer_wait(pipeline_k, smem_pipe_read_k);
|
||||
scheduler_barrier_sync();
|
||||
warp_scheduler_barrier_sync();
|
||||
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
|
||||
if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); }
|
||||
consumer_wait(pipeline_v, smem_pipe_read_v);
|
||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
||||
scheduler_barrier_arrive();
|
||||
warp_scheduler_barrier_arrive();
|
||||
warpgroup_wait<1>();
|
||||
pipeline_k.consumer_release(smem_pipe_read_k); // release K
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
|
||||
@ -472,12 +448,12 @@ struct CollectiveMainloopFwd {
|
||||
for (; n_block > 0; --n_block) {
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
||||
consumer_wait(pipeline_k, smem_pipe_read_k);
|
||||
scheduler_barrier_sync();
|
||||
warp_scheduler_barrier_sync();
|
||||
flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
consumer_wait(pipeline_v, smem_pipe_read_v);
|
||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
||||
scheduler_barrier_arrive();
|
||||
warp_scheduler_barrier_arrive();
|
||||
warpgroup_wait<1>();
|
||||
pipeline_k.consumer_release(smem_pipe_read_k); // release K
|
||||
// auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
|
||||
@ -491,7 +467,7 @@ struct CollectiveMainloopFwd {
|
||||
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
|
||||
}
|
||||
// Tell warp 0 that smem_q is ready
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
consumer_wait(pipeline_v, smem_pipe_read_v);
|
||||
flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
||||
|
||||
23
hopper/named_barrier.hpp
Normal file
23
hopper/named_barrier.hpp
Normal file
@ -0,0 +1,23 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/barrier.h"
|
||||
|
||||
namespace flash {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Enumerates the reserved named barriers to avoid potential conflicts
|
||||
enum class FwdNamedBarriers {
|
||||
QueryEmpty = 0,
|
||||
ValueEmpty = 1,
|
||||
TileCountSmemEmpty = 2,
|
||||
TileCountSmemFull = 3,
|
||||
WarpSchedulerWG1 = 4,
|
||||
WarpSchedulerWG2 = 5,
|
||||
WarpSchedulerWG3 = 6,
|
||||
};
|
||||
|
||||
} // flash
|
||||
@ -111,8 +111,11 @@ if not SKIP_CUDA_BUILD:
|
||||
sources = [
|
||||
"flash_api.cpp",
|
||||
"flash_fwd_hdim64_fp16_sm90.cu",
|
||||
"flash_fwd_hdim64_bf16_sm90.cu",
|
||||
"flash_fwd_hdim128_fp16_sm90.cu",
|
||||
"flash_fwd_hdim128_bf16_sm90.cu",
|
||||
"flash_fwd_hdim256_fp16_sm90.cu",
|
||||
"flash_fwd_hdim256_bf16_sm90.cu",
|
||||
"flash_bwd_hdim64_fp16_sm90.cu",
|
||||
"flash_bwd_hdim128_fp16_sm90.cu",
|
||||
"flash_bwd_hdim256_fp16_sm90.cu",
|
||||
|
||||
@ -131,15 +131,18 @@ def attention_ref(
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||
# @pytest.mark.parametrize("mha_type", ["gqa"])
|
||||
@pytest.mark.parametrize("causal", [False, True])
|
||||
# @pytest.mark.parametrize("causal", [False])
|
||||
# @pytest.mark.parametrize("causal", [True])
|
||||
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
|
||||
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
|
||||
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
|
||||
# @pytest.mark.parametrize('d', [56, 80])
|
||||
@pytest.mark.parametrize("d", [64, 128, 256])
|
||||
# @pytest.mark.parametrize("d", [128])
|
||||
# @pytest.mark.parametrize("d", [256])
|
||||
@pytest.mark.parametrize(
|
||||
"seqlen_q,seqlen_k",
|
||||
[
|
||||
@ -151,6 +154,8 @@ def attention_ref(
|
||||
(113, 211),
|
||||
(108, 256),
|
||||
(256, 512),
|
||||
(384, 256),
|
||||
(640, 128),
|
||||
(512, 256),
|
||||
(1024, 1024),
|
||||
(1023, 1024),
|
||||
@ -160,7 +165,7 @@ def attention_ref(
|
||||
)
|
||||
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
|
||||
def test_flash_attn_output(
|
||||
seqlen_q, seqlen_k, d, causal, dtype
|
||||
seqlen_q, seqlen_k, d, causal, mha_type, dtype
|
||||
):
|
||||
device = "cuda"
|
||||
# set seed
|
||||
@ -168,16 +173,13 @@ def test_flash_attn_output(
|
||||
# batch_size = 40
|
||||
# nheads = 16
|
||||
batch_size = 9
|
||||
nheads = 4
|
||||
nheads = 6
|
||||
nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
|
||||
# batch_size = 1
|
||||
# nheads = 1
|
||||
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
||||
k = torch.randn(
|
||||
batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
v = torch.randn(
|
||||
batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
|
||||
v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
|
||||
out, lse = flash_attn_func(q, k, v, causal=causal)
|
||||
out_ref, attn_ref = attention_ref(
|
||||
q,
|
||||
@ -202,15 +204,15 @@ def test_flash_attn_output(
|
||||
# m = qk.amax(-1, keepdim=True)
|
||||
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
|
||||
# exp_sum = s_tmp.sum(-1)
|
||||
qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
|
||||
lse_ref = torch.logsumexp(qk, dim=-1)
|
||||
# qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
|
||||
# lse_ref = torch.logsumexp(qk, dim=-1)
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
|
||||
if not causal:
|
||||
print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
|
||||
# if not causal:
|
||||
# print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
|
||||
# breakpoint()
|
||||
|
||||
# if d <= 128:
|
||||
@ -248,5 +250,3 @@ def test_flash_attn_output(
|
||||
# assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
|
||||
# assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
|
||||
# assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
|
||||
|
||||
|
||||
|
||||
@ -1,112 +1,26 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/arch/barrier.h"
|
||||
|
||||
#include "named_barrier.hpp"
|
||||
|
||||
namespace flash {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class StaticPersistentTileSchedulerOld {
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
private:
|
||||
int current_work_linear_idx_;
|
||||
cutlass::FastDivmod const &m_block_divmod, &head_divmod;
|
||||
int const total_blocks;
|
||||
|
||||
public:
|
||||
struct WorkTileInfo {
|
||||
int M_idx = 0;
|
||||
int H_idx = 0;
|
||||
int B_idx = 0;
|
||||
bool is_valid_tile = false;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool
|
||||
is_valid() const {
|
||||
return is_valid_tile;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static WorkTileInfo
|
||||
invalid_work_tile() {
|
||||
return {-1, -1, -1, false};
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
CUTLASS_DEVICE explicit StaticPersistentTileSchedulerOld(cutlass::FastDivmod const &m_block_divmod_,
|
||||
cutlass::FastDivmod const &head_divmod_,
|
||||
int const total_blocks_) :
|
||||
m_block_divmod(m_block_divmod_), head_divmod(head_divmod_), total_blocks(total_blocks_) {
|
||||
|
||||
// MSVC requires protecting use of CUDA-specific nonstandard syntax,
|
||||
// like blockIdx and gridDim, with __CUDA_ARCH__.
|
||||
#if defined(__CUDA_ARCH__)
|
||||
// current_work_linear_idx_ = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
|
||||
current_work_linear_idx_ = blockIdx.x;
|
||||
#else
|
||||
CUTLASS_ASSERT(false && "This line should never be reached");
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
get_current_work() const {
|
||||
return get_current_work_for_linear_idx(current_work_linear_idx_);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
get_current_work_for_linear_idx(int linear_idx) const {
|
||||
if (linear_idx >= total_blocks) {
|
||||
return WorkTileInfo::invalid_work_tile();
|
||||
}
|
||||
|
||||
// Map worker's linear index into the CTA tiled problem shape to the corresponding MHB indices
|
||||
int M_idx, H_idx, B_idx;
|
||||
int quotient = m_block_divmod.divmod(M_idx, linear_idx);
|
||||
B_idx = head_divmod.divmod(H_idx, quotient);
|
||||
return {M_idx, H_idx, B_idx, true};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
// advance_to_next_work(int advance_count = 1) {
|
||||
advance_to_next_work() {
|
||||
// current_work_linear_idx_ += int(gridDim.x * gridDim.y * gridDim.z);
|
||||
current_work_linear_idx_ += int(gridDim.x);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
fetch_next_work() {
|
||||
WorkTileInfo new_work_tile_info;
|
||||
advance_to_next_work();
|
||||
new_work_tile_info = get_current_work();
|
||||
return new_work_tile_info;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class SingleTileScheduler {
|
||||
struct SingleTileScheduler {
|
||||
|
||||
public:
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
int const num_blocks_m, num_head, num_batch;
|
||||
int const* tile_count_semaphore = nullptr;
|
||||
int* const tile_count_semaphore = nullptr;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
@ -140,20 +54,30 @@ public:
|
||||
return {M_idx, H_idx, B_idx};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
get_next_work(Params const& params) const {
|
||||
return {-1, -1, -1, false};
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
SingleTileScheduler(int* tile_count_smem_) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
get_initial_work() const {
|
||||
return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
init_consumer() const {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
broadcast_next_work(WorkTileInfo& current_work) const {}
|
||||
|
||||
template<bool IsProducer=false>
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
get_next_work(Params const& params, WorkTileInfo const& current_work) const {
|
||||
@ -171,7 +95,7 @@ public:
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
int const num_blocks_m, num_head, num_batch;
|
||||
int const* tile_count_semaphore = nullptr;
|
||||
int* const tile_count_semaphore = nullptr;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
@ -210,12 +134,28 @@ public:
|
||||
|
||||
};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
StaticPersistentTileScheduler(int* tile_count_smem_) {};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
get_initial_work() const {
|
||||
return {int(blockIdx.x)};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
init_consumer() const {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
broadcast_next_work(WorkTileInfo& current_work) const {}
|
||||
|
||||
template<bool IsProducer=false>
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
get_next_work(Params const& params, WorkTileInfo const& current_work) const {
|
||||
@ -224,21 +164,25 @@ public:
|
||||
|
||||
};
|
||||
|
||||
template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup>
|
||||
class DynamicPersistentTileScheduler {
|
||||
|
||||
protected:
|
||||
int* const tile_count_smem;
|
||||
|
||||
public:
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
int const num_blocks_m, num_head, num_batch;
|
||||
int const* tile_count_semaphore;
|
||||
int* const tile_count_semaphore;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
int const total_blocks;
|
||||
cutlass::FastDivmod const m_block_divmod, head_divmod;
|
||||
int const* tile_count_semaphore;
|
||||
int* const tile_count_semaphore;
|
||||
};
|
||||
|
||||
static Params
|
||||
@ -253,25 +197,27 @@ public:
|
||||
return {uint32_t(num_sm)};
|
||||
}
|
||||
|
||||
using WorkTileInfo = StaticPersistentTileScheduler::WorkTileInfo;
|
||||
// struct WorkTileInfo {
|
||||
// int tile_idx;
|
||||
struct WorkTileInfo {
|
||||
int tile_idx;
|
||||
|
||||
// CUTLASS_DEVICE
|
||||
// bool
|
||||
// is_valid(Params const& params) const {
|
||||
// return tile_idx < params.total_blocks;
|
||||
// }
|
||||
CUTLASS_DEVICE
|
||||
bool
|
||||
is_valid(Params const& params) const {
|
||||
return tile_idx < params.total_blocks;
|
||||
}
|
||||
|
||||
// CUTLASS_DEVICE
|
||||
// cute::tuple<int32_t, int32_t, int32_t>
|
||||
// get_block_coord(Params const& params) const {
|
||||
// int m_block, bidh, bidb;
|
||||
// bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
|
||||
// return {m_block, bidh, bidb};
|
||||
// }
|
||||
CUTLASS_DEVICE
|
||||
cute::tuple<int32_t, int32_t, int32_t>
|
||||
get_block_coord(Params const& params) const {
|
||||
int m_block, bidh, bidb;
|
||||
bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
|
||||
return {m_block, bidh, bidb};
|
||||
}
|
||||
|
||||
// };
|
||||
};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
DynamicPersistentTileScheduler(int* tile_count_smem_) : tile_count_smem(tile_count_smem_) {};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
@ -279,12 +225,45 @@ public:
|
||||
return {int(blockIdx.x)};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
init_consumer() const {
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
|
||||
if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
|
||||
current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
broadcast_next_work(WorkTileInfo& current_work) const {
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
|
||||
if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
|
||||
*tile_count_smem = current_work.tile_idx;
|
||||
}
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
|
||||
}
|
||||
|
||||
template<bool IsProducer=false>
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
get_next_work(Params const& params, WorkTileInfo const& current_work) const {
|
||||
return {current_work.tile_idx + int(gridDim.x)};
|
||||
if constexpr (IsProducer) {
|
||||
// thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0
|
||||
return {__shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/)};
|
||||
} else {
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
|
||||
int tile_idx = *tile_count_smem;
|
||||
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
|
||||
return {tile_idx};
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // flash
|
||||
} // flash
|
||||
Loading…
Reference in New Issue
Block a user