Refactor some GroupedGEMM logic (#1899)

azhurkevich authored 2024-10-25 17:14:01 -07:00, committed by GitHub
parent 08a49953a0
commit e8a8b69365
6 changed files with 50 additions and 46 deletions

View File

@ -323,8 +323,8 @@ tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc,
CUTE_HOST_DEVICE
void
tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor & smem_desc,
cute::array<uint32_t, 3> const& prob_shape,
cute::array<uint64_t, 3> const& prob_stride)
cute::array<uint32_t, 5> const& prob_shape,
cute::array<uint64_t, 5> const& prob_stride)
{
#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc);
@ -341,6 +341,12 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor
asm volatile (
"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;"
:: "l"(smem_int64_desc), "r"(prob_shape[2]));
asm volatile (
"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 3, %1;"
:: "l"(smem_int64_desc), "r"(prob_shape[3]));
asm volatile (
"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 4, %1;"
:: "l"(smem_int64_desc), "r"(prob_shape[4]));
// Strides must be a multiple of 16. Also, the stride for the innermost dimension is implicitly 1
#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5)))
asm volatile (
@ -349,6 +355,12 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor
asm volatile (
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;"
:: "l"(smem_int64_desc), "l"(prob_stride[2]));
asm volatile (
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 2, %1;"
:: "l"(smem_int64_desc), "l"(prob_stride[3]));
asm volatile (
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 3, %1;"
:: "l"(smem_int64_desc), "l"(prob_stride[4]));
#else
// 4 LSBs are not included
asm volatile (
@ -357,6 +369,12 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor
asm volatile (
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;"
:: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4));
asm volatile (
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 2, %1;"
:: "l"(smem_int64_desc), "l"(prob_stride[3] >> 4));
asm volatile (
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 3, %1;"
:: "l"(smem_int64_desc), "l"(prob_stride[4] >> 4));
#endif
#else
CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3");

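For orientation, here is a minimal, hypothetical caller sketch (not part of this commit) for the widened helper: unused trailing dimensions are padded with shape 1 and stride 0, mirroring what the collectives below now do. Judging from the asm above, shape slot i is taken from prob_shape[i], while stride slot i is taken from prob_stride[i+1] (the innermost stride is implicit), and strides are byte strides that must be multiples of 16. Names, values, and the include path are illustrative assumptions.

// Hypothetical usage sketch only; names and values are illustrative.
#include <cute/arch/copy_sm90_desc.hpp>   // assumed location of the helper above

__device__ void replace_dims_example(cute::TmaDescriptor& smem_desc,   // descriptor resident in shared memory
                                     uint32_t M, uint32_t N,
                                     uint64_t dim1_stride_bytes) {     // byte stride, multiple of 16
  // Pad to the maximum rank (5) the helper now expects: shape 1 / stride 0 for unused dims.
  cute::array<uint32_t, 5> prob_shape  = {1, 1, 1, 1, 1};
  cute::array<uint64_t, 5> prob_stride = {0, 0, 0, 0, 0};
  prob_shape[0]  = M;                  // innermost dimension
  prob_shape[1]  = N;
  prob_stride[1] = dim1_stride_bytes;  // dim-0 stride is implicit, so prob_stride[0] stays 0
  // Requires CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED; otherwise the helper hits the invalid-path trap.
  cute::tma_descriptor_replace_dims_strides_in_shared_mem(smem_desc, prob_shape, prob_stride);
}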
View File

@ -456,7 +456,6 @@ public:
tensormaps_cp_fence_release(
[[maybe_unused]] TensorMapStorage& shared_tensormaps,
[[maybe_unused]] cute::TmaDescriptor const* tensormap,
[[maybe_unused]] uint32_t lane_predicate,
[[maybe_unused]] int32_t warp_group_idx) { }
template <bool IsLoad>

View File

@ -1080,13 +1080,10 @@ public:
int32_t warp_group_idx) {
const uint32_t M = get<0>(problem_shape_mnkl);
const uint32_t N = get<1>(problem_shape_mnkl);
// Only consider dimensions and strides that we need to recalculate and replace for each group
constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // excluding either M or N
static_assert(TensorRank == Int<3>{},
"Descriptor modification for global dims & strides expects rank as 3.");
cute::array<uint32_t, TensorRank> prob_shape = {1,1,1};
cute::array<uint64_t, TensorRank> prob_stride = {0,0,0};
// Replace all dims for consistency
constexpr int MaxTensorRank = 5;
cute::array<uint32_t, MaxTensorRank> prob_shape = {1,1,1,1,1};
cute::array<uint64_t, MaxTensorRank> prob_stride = {0,0,0,0,0};
if constexpr (IsLoad) {
if constexpr (is_source_supported) {
@ -1106,9 +1103,6 @@ public:
}
else if constexpr (is_destination_supported) {
ElementD const* ptr_D = nullptr;
// tma_store_c should be a gmem_tensor, second argument should be a stride
Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group]));
cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d,
@ -1154,8 +1148,7 @@ public:
tensormaps_cp_fence_release(
TensorMapStorage& shared_tensormaps,
cute::TmaDescriptor const* tensormap,
[[maybe_unused]] uint32_t lane_predicate,
int32_t warp_group_idx = 0) {
const int32_t warp_group_idx = 0) {
// Entire warp must do this (i.e., it's aligned)
if constexpr (IsLoad) {

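Since lane_predicate is dropped from tensormaps_cp_fence_release here (and in the stub above), the call sites in the kernels later in this diff reduce to the shape sketched below; the preceding __syncwarp() converges the warp before the aligned release fence. This fragment simply mirrors those call sites for reference and is not additional new code.

// Sketch mirroring the updated kernel call sites below.
__syncwarp();  // converge: the entire warp participates in the copy + fence release
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
                                                           epi_load_tensormap,
                                                           /*warp_group_idx=*/0);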
View File

@ -674,14 +674,12 @@ struct CollectiveMma<
const uint32_t M = get<0>(problem_shape_mnkl);
const uint32_t N = get<1>(problem_shape_mnkl);
const uint32_t K = get<2>(problem_shape_mnkl);
// Only consider dimensions and strides that we need to recalculate and replace for each group
constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // excluding either M or N
static_assert(TensorRank == Int<3>{},
"Descriptor modification for global dims & strides expects rank as 3.");
cute::array<uint32_t, TensorRank> prob_shape_A = {1,1,1};
cute::array<uint64_t, TensorRank> prob_stride_A = {0,0,0};
cute::array<uint32_t, TensorRank> prob_shape_B = {1,1,1};
cute::array<uint64_t, TensorRank> prob_stride_B = {0,0,0};
// Replace all dims for consistency
constexpr int MaxTensorRank = 5;
cute::array<uint32_t, MaxTensorRank> prob_shape_A = {1,1,1,1,1};
cute::array<uint64_t, MaxTensorRank> prob_stride_A = {0,0,0,0,0};
cute::array<uint32_t, MaxTensorRank> prob_shape_B = {1,1,1,1,1};
cute::array<uint64_t, MaxTensorRank> prob_stride_B = {0,0,0,0,0};
InternalElementA const* ptr_A = nullptr;
Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]);

View File

@ -499,7 +499,7 @@ public:
}
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
// Prepare and partition the input tensors. Expects a tuple of tensors where:
// get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
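As a quick illustration of the recurring change in this kernel (and in the pingpong kernel further down): the batch mode appended to each group's problem shape is now a literal 1 rather than the group index, since each group is handled as a standalone (M,N,K) problem and the group itself is selected through work_tile_info.L_idx when indexing per-group pointers and strides. A minimal sketch with hypothetical sizes:

// Illustration only (hypothetical per-group sizes).
auto shape_mnk  = cute::make_shape(1024, 512, 256);  // one group's (M,N,K)
auto shape_mnkl = cute::append<4>(shape_mnk, 1);     // pad the batch mode L with 1 -> (1024,512,256,1)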
@ -595,7 +595,7 @@ public:
if (work_tile_info.is_valid() && did_batch_change) {
curr_batch = next_batch;
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), 1);
}
// Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates
// Since this state is waiting for loads to finish, it must start in the inverted phase.
@ -644,7 +644,7 @@ public:
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0);
}
load_order_barrier.wait();
@ -657,7 +657,7 @@ public:
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
}
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
@ -692,7 +692,7 @@ public:
if (work_tile_info.is_valid() && did_batch_change) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
}
// tensormap update
@ -708,7 +708,7 @@ public:
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0);
}
}
@ -749,16 +749,15 @@ public:
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
epi_store_tensormap,
lane_predicate,
consumer_warp_group_idx);
}
}
while (work_tile_info.is_valid()) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
}
int32_t curr_batch = work_tile_info.L_idx;
@ -841,7 +840,7 @@ public:
did_batch_change = curr_batch != work_tile_info.L_idx;
if (work_tile_info.is_valid() && did_batch_change) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
}
if (warp_idx_in_warp_group == 0) {
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
@ -857,7 +856,6 @@ public:
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
epi_store_tensormap,
lane_predicate,
consumer_warp_group_idx);
}
}

View File

@ -514,7 +514,7 @@ public:
}
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
if (warp_group_role == WarpGroupRole::Consumer1) {
// Advance 2nd Math WG to the next work tile for the startup
@ -531,7 +531,7 @@ public:
epi_load_pipe_consumer_state.advance(c_tile_count);
epi_store_pipe_producer_state.advance(d_tile_count);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
}
// Prepare and partition the input tensors. Expects a tuple of tensors where:
@ -628,7 +628,7 @@ public:
if (work_tile_info.is_valid() && did_batch_change) {
curr_batch = next_batch;
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), 1);
}
// Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates
// Since this state is waiting for loads to finish, it must start in the inverted phase.
@ -677,7 +677,7 @@ public:
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0);
}
load_order_barrier.wait();
@ -690,7 +690,7 @@ public:
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
}
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
@ -725,7 +725,7 @@ public:
if (work_tile_info.is_valid() && did_batch_change) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
}
// tensormap update
@ -741,7 +741,7 @@ public:
// Converge before issuing tensormap fence release since fence is aligned
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0);
}
}
@ -784,14 +784,13 @@ public:
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
epi_store_tensormap,
lane_predicate,
consumer_warp_group_idx);
}
}
while (work_tile_info.is_valid()) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
}
int32_t curr_batch = work_tile_info.L_idx;
@ -880,7 +879,7 @@ public:
// Skip a tile for pingpong
if (work_tile_info.is_valid()) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
}
work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
mainloop_pipe_consumer_state.advance(work_k_tile_count);
@ -895,7 +894,7 @@ public:
did_batch_change = curr_batch != work_tile_info.L_idx;
if (work_tile_info.is_valid() && did_batch_change) {
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
}
if (warp_idx_in_warp_group == 0) {
collective_epilogue.tensormaps_perform_update<IsEpiLoad>(
@ -911,7 +910,6 @@ public:
__syncwarp();
collective_epilogue.tensormaps_cp_fence_release<IsEpiLoad>(shared_storage.tensormaps.epilogue,
epi_store_tensormap,
lane_predicate,
consumer_warp_group_idx);
}
}