/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
|
||
|
|
|
||
|
|
#pragma once
|
||
|
|
|
||
|
|
#include "cutlass/fast_math.h"
|
||
|
|
|
||
|
|
namespace flash {
|
||
|
|
|
||
|
|
///////////////////////////////////////////////////////////////////////////////
|
||
|
|
|
||
|
|
class StaticPersistentTileSchedulerOld {
|
||
|
|
//
|
||
|
|
// Data members
|
||
|
|
//
|
||
|
|
|
||
|
|
private:
|
||
|
|
int current_work_linear_idx_;
|
||
|
|
cutlass::FastDivmod const &m_block_divmod, &head_divmod;
|
||
|
|
int const total_blocks;
|
||
|
|
|
||
|
|
public:
|
||
|
|
struct WorkTileInfo {
|
||
|
|
int M_idx = 0;
|
||
|
|
int H_idx = 0;
|
||
|
|
int B_idx = 0;
|
||
|
|
bool is_valid_tile = false;
|
||
|
|
|
||
|
|
CUTLASS_HOST_DEVICE
|
||
|
|
bool
|
||
|
|
is_valid() const {
|
||
|
|
return is_valid_tile;
|
||
|
|
}
|
||
|
|
|
||
|
|
CUTLASS_HOST_DEVICE
|
||
|
|
static WorkTileInfo
|
||
|
|
invalid_work_tile() {
|
||
|
|
return {-1, -1, -1, false};
|
||
|
|
}
|
||
|
|
|
||
|
|
};
|
||
|
|
|
||
|
|
public:
|
||
|
|
|
||
|
|
CUTLASS_DEVICE explicit StaticPersistentTileSchedulerOld(cutlass::FastDivmod const &m_block_divmod_,
|
||
|
|
cutlass::FastDivmod const &head_divmod_,
|
||
|
|
int const total_blocks_) :
|
||
|
|
m_block_divmod(m_block_divmod_), head_divmod(head_divmod_), total_blocks(total_blocks_) {
|
||
|
|
|
||
|
|
// MSVC requires protecting use of CUDA-specific nonstandard syntax,
|
||
|
|
// like blockIdx and gridDim, with __CUDA_ARCH__.
|
||
|
|
#if defined(__CUDA_ARCH__)
|
||
|
|
// current_work_linear_idx_ = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
|
||
|
|
current_work_linear_idx_ = blockIdx.x;
|
||
|
|
#else
|
||
|
|
CUTLASS_ASSERT(false && "This line should never be reached");
|
||
|
|
#endif
|
||
|
|
}
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
WorkTileInfo
|
||
|
|
get_current_work() const {
|
||
|
|
return get_current_work_for_linear_idx(current_work_linear_idx_);
|
||
|
|
}
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
WorkTileInfo
|
||
|
|
get_current_work_for_linear_idx(int linear_idx) const {
|
||
|
|
if (linear_idx >= total_blocks) {
|
||
|
|
return WorkTileInfo::invalid_work_tile();
|
||
|
|
}
|
||
|
|
|
||
|
|
// Map worker's linear index into the CTA tiled problem shape to the corresponding MHB indices
|
||
|
|
int M_idx, H_idx, B_idx;
|
||
|
|
int quotient = m_block_divmod.divmod(M_idx, linear_idx);
|
||
|
|
B_idx = head_divmod.divmod(H_idx, quotient);
|
||
|
|
return {M_idx, H_idx, B_idx, true};
|
||
|
|
}
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
void
|
||
|
|
// advance_to_next_work(int advance_count = 1) {
|
||
|
|
advance_to_next_work() {
|
||
|
|
// current_work_linear_idx_ += int(gridDim.x * gridDim.y * gridDim.z);
|
||
|
|
current_work_linear_idx_ += int(gridDim.x);
|
||
|
|
}
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
WorkTileInfo
|
||
|
|
fetch_next_work() {
|
||
|
|
WorkTileInfo new_work_tile_info;
|
||
|
|
advance_to_next_work();
|
||
|
|
new_work_tile_info = get_current_work();
|
||
|
|
return new_work_tile_info;
|
||
|
|
}
|
||
|
|
|
||
|
|
};
|
||
|
|
|
||
|
|
///////////////////////////////////////////////////////////////////////////////
|
||
|
|
|
||
|
|
class SingleTileScheduler {
|
||
|
|
|
||
|
|
public:
|
||
|
|
|
||
|
|
// Host side kernel arguments
|
||
|
|
struct Arguments {
|
||
|
|
int const num_blocks_m, num_head, num_batch;
|
||
|
|
int const* tile_count_semaphore = nullptr;
|
||
|
|
};
|
||
|
|
|
||
|
|
// Device side kernel params
|
||
|
|
struct Params {};
|
||
|
|
|
||
|
|
static Params
|
||
|
|
to_underlying_arguments(Arguments const& args) {
|
||
|
|
return {};
|
||
|
|
}
|
||
|
|
|
||
|
|
static dim3
|
||
|
|
get_grid_dim(Arguments const& args, int num_sm) {
|
||
|
|
return {uint32_t(args.num_blocks_m), uint32_t(args.num_head), uint32_t(args.num_batch)};
|
||
|
|
}
|
||
|
|
|
||
|
|
struct WorkTileInfo {
|
||
|
|
int M_idx = 0;
|
||
|
|
int H_idx = 0;
|
||
|
|
int B_idx = 0;
|
||
|
|
bool is_valid_tile = false;
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
bool
|
||
|
|
is_valid(Params const& params) const {
|
||
|
|
return is_valid_tile;
|
||
|
|
}
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
cute::tuple<int32_t, int32_t, int32_t>
|
||
|
|
get_block_coord(Params const& params) const {
|
||
|
|
return {M_idx, H_idx, B_idx};
|
||
|
|
}
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
WorkTileInfo
|
||
|
|
get_next_work(Params const& params) const {
|
||
|
|
return {-1, -1, -1, false};
|
||
|
|
}
|
||
|
|
|
||
|
|
};
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
WorkTileInfo
|
||
|
|
get_initial_work() const {
|
||
|
|
return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
|
||
|
|
}
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
WorkTileInfo
|
||
|
|
get_next_work(Params const& params, WorkTileInfo const& current_work) const {
|
||
|
|
return {-1, -1, -1, false};
|
||
|
|
}
|
||
|
|
|
||
|
|
};
|
||
|
|
|
||
|
|
///////////////////////////////////////////////////////////////////////////////
|
||
|
|
|
||
|
|
class StaticPersistentTileScheduler {
|
||
|
|
|
||
|
|
public:
|
||
|
|
|
||
|
|
// Host side kernel arguments
|
||
|
|
struct Arguments {
|
||
|
|
int const num_blocks_m, num_head, num_batch;
|
||
|
|
int const* tile_count_semaphore = nullptr;
|
||
|
|
};
|
||
|
|
|
||
|
|
// Device side kernel params
|
||
|
|
struct Params {
|
||
|
|
int total_blocks;
|
||
|
|
cutlass::FastDivmod m_block_divmod, head_divmod;
|
||
|
|
};
|
||
|
|
|
||
|
|
static Params
|
||
|
|
to_underlying_arguments(Arguments const& args) {
|
||
|
|
return {args.num_blocks_m * args.num_head * args.num_batch,
|
||
|
|
cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head)};
|
||
|
|
}
|
||
|
|
|
||
|
|
static dim3
|
||
|
|
get_grid_dim(Arguments const& args, int num_sm) {
|
||
|
|
return {uint32_t(num_sm)};
|
||
|
|
}
|
||
|
|
|
||
|
|
struct WorkTileInfo {
|
||
|
|
int tile_idx;
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
bool
|
||
|
|
is_valid(Params const& params) const {
|
||
|
|
return tile_idx < params.total_blocks;
|
||
|
|
}
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
cute::tuple<int32_t, int32_t, int32_t>
|
||
|
|
get_block_coord(Params const& params) const {
|
||
|
|
int m_block, bidh, bidb;
|
||
|
|
bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
|
||
|
|
return {m_block, bidh, bidb};
|
||
|
|
}
|
||
|
|
|
||
|
|
};
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
WorkTileInfo
|
||
|
|
get_initial_work() const {
|
||
|
|
return {int(blockIdx.x)};
|
||
|
|
}
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
WorkTileInfo
|
||
|
|
get_next_work(Params const& params, WorkTileInfo const& current_work) const {
|
||
|
|
return {current_work.tile_idx + int(gridDim.x)};
|
||
|
|
}
|
||
|
|
|
||
|
|
};
|
||
|
|
|
||
|
|
class DynamicPersistentTileScheduler {
|
||
|
|
|
||
|
|
public:
|
||
|
|
|
||
|
|
// Host side kernel arguments
|
||
|
|
struct Arguments {
|
||
|
|
int const num_blocks_m, num_head, num_batch;
|
||
|
|
int const* tile_count_semaphore;
|
||
|
|
};
|
||
|
|
|
||
|
|
// Device side kernel params
|
||
|
|
struct Params {
|
||
|
|
int const total_blocks;
|
||
|
|
cutlass::FastDivmod const m_block_divmod, head_divmod;
|
||
|
|
int const* tile_count_semaphore;
|
||
|
|
};
|
||
|
|
|
||
|
|
static Params
|
||
|
|
to_underlying_arguments(Arguments const& args) {
|
||
|
|
return {args.num_blocks_m * args.num_head * args.num_batch,
|
||
|
|
cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head),
|
||
|
|
args.tile_count_semaphore};
|
||
|
|
}
|
||
|
|
|
||
|
|
static dim3
|
||
|
|
get_grid_dim(Arguments const& args, int num_sm) {
|
||
|
|
return {uint32_t(num_sm)};
|
||
|
|
}
|
||
|
|
|
||
|
|
using WorkTileInfo = StaticPersistentTileScheduler::WorkTileInfo;
|
||
|
|
// struct WorkTileInfo {
|
||
|
|
// int tile_idx;
|
||
|
|
|
||
|
|
// CUTLASS_DEVICE
|
||
|
|
// bool
|
||
|
|
// is_valid(Params const& params) const {
|
||
|
|
// return tile_idx < params.total_blocks;
|
||
|
|
// }
|
||
|
|
|
||
|
|
// CUTLASS_DEVICE
|
||
|
|
// cute::tuple<int32_t, int32_t, int32_t>
|
||
|
|
// get_block_coord(Params const& params) const {
|
||
|
|
// int m_block, bidh, bidb;
|
||
|
|
// bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
|
||
|
|
// return {m_block, bidh, bidb};
|
||
|
|
// }
|
||
|
|
|
||
|
|
// };
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
WorkTileInfo
|
||
|
|
get_initial_work() const {
|
||
|
|
return {int(blockIdx.x)};
|
||
|
|
}
|
||
|
|
|
||
|
|
CUTLASS_DEVICE
|
||
|
|
WorkTileInfo
|
||
|
|
get_next_work(Params const& params, WorkTileInfo const& current_work) const {
|
||
|
|
return {current_work.tile_idx + int(gridDim.x)};
|
||
|
|
}
|
||
|
|
|
||
|
|
};
|
||
|
|
|
||
|
|
} // flash
|