Refactor some GroupedGEMM logic (#1899)
parent 08a49953a0
commit e8a8b69365
@@ -323,8 +323,8 @@ tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc,
 CUTE_HOST_DEVICE
 void
 tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor & smem_desc,
-                                                  cute::array<uint32_t, 3> const& prob_shape,
-                                                  cute::array<uint64_t, 3> const& prob_stride)
+                                                  cute::array<uint32_t, 5> const& prob_shape,
+                                                  cute::array<uint64_t, 5> const& prob_stride)
 {
 #if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
   uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc);
@@ -341,6 +341,12 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor
   asm volatile (
     "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;"
     :: "l"(smem_int64_desc), "r"(prob_shape[2]));
+  asm volatile (
+    "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 3, %1;"
+    :: "l"(smem_int64_desc), "r"(prob_shape[3]));
+  asm volatile (
+    "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 4, %1;"
+    :: "l"(smem_int64_desc), "r"(prob_shape[4]));
   // Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1
 #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5)))
   asm volatile (
@@ -349,6 +355,12 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor
   asm volatile (
     "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;"
     :: "l"(smem_int64_desc), "l"(prob_stride[2]));
+  asm volatile (
+    "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 2, %1;"
+    :: "l"(smem_int64_desc), "l"(prob_stride[3]));
+  asm volatile (
+    "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 3, %1;"
+    :: "l"(smem_int64_desc), "l"(prob_stride[4]));
 #else
   // 4 LSBs are not included
   asm volatile (
@@ -357,6 +369,12 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor
   asm volatile (
     "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;"
     :: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4));
+  asm volatile (
+    "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 2, %1;"
+    :: "l"(smem_int64_desc), "l"(prob_stride[3] >> 4));
+  asm volatile (
+    "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 3, %1;"
+    :: "l"(smem_int64_desc), "l"(prob_stride[4] >> 4));
 #endif
 #else
   CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3");
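Note on the hunks above: the shared-memory descriptor-update helper now always takes rank-5 shape/stride arrays (the maximum tensor rank TMA supports) instead of rank-3 ones, and rewrites every global dim and stride for consistency. The strides are byte strides that must be multiples of 16, and on toolkits older than CUDA 12.5 they are encoded with the 4 least-significant bits dropped (the ">> 4" path). Below is a minimal sketch, not part of this commit, of how a rank-3 (mode, mode, batch) shape/stride pair can be widened to the rank-5 form; the helper name and include path are assumptions.

    #include <cstdint>
    #include <cute/container/array.hpp>   // cute::array (include path assumed)

    // Hypothetical helper: unused trailing modes get extent 1 and stride 0,
    // matching the {1,1,1,1,1} / {0,0,0,0,0} initializers used later in this diff.
    inline void pad_shape_stride_to_rank5(cute::array<uint32_t, 3> const& shape3,
                                          cute::array<uint64_t, 3> const& stride3,
                                          cute::array<uint32_t, 5>&       shape5,
                                          cute::array<uint64_t, 5>&       stride5) {
      for (int i = 0; i < 5; ++i) {
        shape5[i]  = (i < 3) ? shape3[i]  : 1u;   // extent 1 for unused modes
        stride5[i] = (i < 3) ? stride3[i] : 0u;   // stride 0 for unused modes
      }
    }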
@@ -456,7 +456,6 @@ public:
   tensormaps_cp_fence_release(
       [[maybe_unused]] TensorMapStorage& shared_tensormaps,
       [[maybe_unused]] cute::TmaDescriptor const* tensormap,
-      [[maybe_unused]] uint32_t lane_predicate,
       [[maybe_unused]] int32_t warp_group_idx) { }

   template <bool IsLoad>
@@ -1080,13 +1080,10 @@ public:
       int32_t warp_group_idx) {
     const uint32_t M = get<0>(problem_shape_mnkl);
     const uint32_t N = get<1>(problem_shape_mnkl);
-    // Only consider dimensions and strides that we need to recalculate and replace for each group
-    constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // excluding either M or N
-    static_assert(TensorRank == Int<3>{},
-      "Descriptor modification for global dims & strides expects rank as 3.");
-
-    cute::array<uint32_t, TensorRank> prob_shape  = {1,1,1};
-    cute::array<uint64_t, TensorRank> prob_stride = {0,0,0};
+    // Replace all dims for consistency
+    constexpr int MaxTensorRank = 5;
+    cute::array<uint32_t, MaxTensorRank> prob_shape  = {1,1,1,1,1};
+    cute::array<uint64_t, MaxTensorRank> prob_stride = {0,0,0,0,0};

     if constexpr (IsLoad) {
       if constexpr (is_source_supported) {
@@ -1106,9 +1103,6 @@ public:
       }
       else if constexpr (is_destination_supported) {
         ElementD const* ptr_D = nullptr;
-
-        // tma_store_c should be a gmem_tensor, second argument should be a stride
-
         Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group]));

         cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d,
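For context, a rough host-side analogue (not the kernel code) of the tensor_d line above: a fresh gmem view of D is rebuilt per group from that group's stride, with the batch mode fixed to extent 1. The function name and the column-major stride below are illustrative assumptions; the kernel uses params.dD[next_group] instead.

    #include <cute/tensor.hpp>   // include path assumed

    // Hypothetical sketch of forming a per-group (M, N, 1) view of D.
    void build_d_view_sketch(float const* ptr_D, int M, int N, int64_t ldd) {
      using namespace cute;
      // (1, ldd, 0) is an assumed column-major stride for illustration only.
      auto layout_d = make_layout(make_shape(M, N, Int<1>{}),
                                  make_stride(Int<1>{}, ldd, Int<0>{}));
      Tensor tensor_d = make_tensor(make_gmem_ptr(ptr_D), layout_d);
      (void) tensor_d;
    }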
@@ -1154,8 +1148,7 @@ public:
   tensormaps_cp_fence_release(
       TensorMapStorage& shared_tensormaps,
       cute::TmaDescriptor const* tensormap,
-      [[maybe_unused]] uint32_t lane_predicate,
-      int32_t warp_group_idx = 0) {
+      const int32_t warp_group_idx = 0) {

     // Entire warp must do this (ie its aligned)
     if constexpr (IsLoad) {
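The lane_predicate parameter is dropped from this signature because the descriptor copy plus release fence is a warp-aligned operation: the caller only needs to converge the warp, which is exactly what the __syncwarp() calls in the kernel hunks below do. A minimal standalone sketch of that pattern, assuming cute::tma_descriptor_cp_fence_release keeps its (global pointer, shared descriptor) argument order; the helper name is hypothetical and this is not the collective's actual member function.

    #include <cutlass/cutlass.h>             // CUTLASS_DEVICE (include paths assumed)
    #include <cute/arch/copy_sm90_desc.hpp>  // cute::TmaDescriptor, tma_descriptor_cp_fence_release

    CUTLASS_DEVICE void
    cp_fence_release_sketch(cute::TmaDescriptor&       smem_tensormap,
                            cute::TmaDescriptor const* gmem_tensormap) {
      // The copy + release fence below is warp-aligned, so the whole warp
      // converges first; no per-lane predicate argument is needed.
      __syncwarp();
      cute::tma_descriptor_cp_fence_release(gmem_tensormap, smem_tensormap);
    }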
@@ -674,14 +674,12 @@ struct CollectiveMma<
     const uint32_t M = get<0>(problem_shape_mnkl);
     const uint32_t N = get<1>(problem_shape_mnkl);
     const uint32_t K = get<2>(problem_shape_mnkl);
-    // Only consider dimensions and strides that we need to recalculate and replace for each group
-    constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // excluding either M or N
-    static_assert(TensorRank == Int<3>{},
-      "Descriptor modification for global dims & strides expects rank as 3.");
-    cute::array<uint32_t, TensorRank> prob_shape_A  = {1,1,1};
-    cute::array<uint64_t, TensorRank> prob_stride_A = {0,0,0};
-    cute::array<uint32_t, TensorRank> prob_shape_B  = {1,1,1};
-    cute::array<uint64_t, TensorRank> prob_stride_B = {0,0,0};
+    // Replace all dims for consistency
+    constexpr int MaxTensorRank = 5;
+    cute::array<uint32_t, MaxTensorRank> prob_shape_A  = {1,1,1,1,1};
+    cute::array<uint64_t, MaxTensorRank> prob_stride_A = {0,0,0,0,0};
+    cute::array<uint32_t, MaxTensorRank> prob_shape_B  = {1,1,1,1,1};
+    cute::array<uint64_t, MaxTensorRank> prob_stride_B = {0,0,0,0,0};

     InternalElementA const* ptr_A = nullptr;
     Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]);
@@ -499,7 +499,7 @@ public:
     }

     // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
-    auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+    auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);

     // Prepare and partition the input tensors. Expects a tuple of tensors where:
     // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
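The change above is repeated throughout both kernels: the appended L extent becomes 1 instead of work_tile_info.L_idx. Within a group, the problem is a single GEMM (batch count 1); the group itself is selected through per-group pointers and strides, not through the L coordinate. A small illustration of the append, assuming a CuTe include path:

    #include <cute/tensor.hpp>   // include path assumed

    int main() {
      // Per-group problem: (M, N, K) extended to rank 4 with L = 1.
      auto mnk  = cute::make_shape(128, 256, 64);
      auto mnkl = cute::append<4>(mnk, 1);
      cute::print(mnkl);   // prints (128,256,64,1)
      return 0;
    }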
@@ -595,7 +595,7 @@ public:
       if (work_tile_info.is_valid() && did_batch_change) {
         curr_batch = next_batch;
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), 1);
         }
         // Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates
         // Since this state is waiting for loads to finish, it must start in the inverted phase.
@@ -644,7 +644,7 @@ public:

         // Converge before issuing tensormap fence release since fence is aligned
         __syncwarp();
-        collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
+        collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0);
       }

       load_order_barrier.wait();
@@ -657,7 +657,7 @@ public:

       if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
         }

         // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
@@ -692,7 +692,7 @@ public:

       if (work_tile_info.is_valid() && did_batch_change) {
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
         }

         // tensormap update
@@ -708,7 +708,7 @@ public:

           // Converge before issuing tensormap fence release since fence is aligned
           __syncwarp();
-          collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
+          collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0);
         }
       }

@@ -749,16 +749,15 @@ public:

           // Converge before issuing tensormap fence release since fence is aligned
           __syncwarp();
-          collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
+          collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
                                                                      epi_store_tensormap,
-                                                                     lane_predicate,
                                                                      consumer_warp_group_idx);
         }
       }

       while (work_tile_info.is_valid()) {
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
         }

         int32_t curr_batch = work_tile_info.L_idx;
@@ -841,7 +840,7 @@ public:
       did_batch_change = curr_batch != work_tile_info.L_idx;
       if (work_tile_info.is_valid() && did_batch_change) {
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
         }
         if (warp_idx_in_warp_group == 0) {
           collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
@@ -857,7 +856,6 @@ public:
           __syncwarp();
           collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
                                                                      epi_store_tensormap,
-                                                                     lane_predicate,
                                                                      consumer_warp_group_idx);
         }
       }
@@ -514,7 +514,7 @@ public:
     }

     // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
-    auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+    auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);

     if (warp_group_role == WarpGroupRole::Consumer1) {
       // Advance 2nd Math WG to the next work tile for the startup
@@ -531,7 +531,7 @@ public:
       epi_load_pipe_consumer_state.advance(c_tile_count);
       epi_store_pipe_producer_state.advance(d_tile_count);

-      problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+      problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
     }

     // Prepare and partition the input tensors. Expects a tuple of tensors where:
@@ -628,7 +628,7 @@ public:
       if (work_tile_info.is_valid() && did_batch_change) {
         curr_batch = next_batch;
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), 1);
         }
         // Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates
         // Since this state is waiting for loads to finish, it must start in the inverted phase.
@@ -677,7 +677,7 @@ public:

         // Converge before issuing tensormap fence release since fence is aligned
         __syncwarp();
-        collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
+        collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0);
       }

       load_order_barrier.wait();
@@ -690,7 +690,7 @@ public:

       if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
         }

         // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
@@ -725,7 +725,7 @@ public:

       if (work_tile_info.is_valid() && did_batch_change) {
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
         }

         // tensormap update
@@ -741,7 +741,7 @@ public:

           // Converge before issuing tensormap fence release since fence is aligned
           __syncwarp();
-          collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
+          collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0);
         }
       }

@@ -784,14 +784,13 @@ public:
           __syncwarp();
           collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
                                                                      epi_store_tensormap,
-                                                                     lane_predicate,
                                                                      consumer_warp_group_idx);
         }
       }

       while (work_tile_info.is_valid()) {
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
         }

         int32_t curr_batch = work_tile_info.L_idx;
@@ -880,7 +879,7 @@ public:
         // Skip a tile for pingpong
         if (work_tile_info.is_valid()) {
           if constexpr (IsGroupedGemmKernel) {
-            problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+            problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
           }
           work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
           mainloop_pipe_consumer_state.advance(work_k_tile_count);
@@ -895,7 +894,7 @@ public:
       did_batch_change = curr_batch != work_tile_info.L_idx;
       if (work_tile_info.is_valid() && did_batch_change) {
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
         }
         if (warp_idx_in_warp_group == 0) {
           collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
@@ -911,7 +910,6 @@ public:
           __syncwarp();
           collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
                                                                      epi_store_tensormap,
-                                                                     lane_predicate,
                                                                      consumer_warp_group_idx);
         }
       }