From e8a8b693657a98527004ba1e75847ff0ca15b5da Mon Sep 17 00:00:00 2001
From: azhurkevich <101208641+azhurkevich@users.noreply.github.com>
Date: Fri, 25 Oct 2024 17:14:01 -0700
Subject: [PATCH] Refactor some GroupedGEMM logic (#1899)

---
 include/cute/arch/copy_sm90_desc.hpp          | 22 +++++++++++++++++--
 .../cutlass/epilogue/collective/detail.hpp    |  1 -
 ...m90_epilogue_array_tma_warpspecialized.hpp | 17 +++++---------
 ..._mma_array_tma_gmma_ss_warpspecialized.hpp | 14 +++++-------
 ..._array_tma_warpspecialized_cooperative.hpp | 20 ++++++++---------
 ...emm_array_tma_warpspecialized_pingpong.hpp | 22 +++++++++----------
 6 files changed, 50 insertions(+), 46 deletions(-)

diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp
index 25a252a8..cc0bf4a3 100644
--- a/include/cute/arch/copy_sm90_desc.hpp
+++ b/include/cute/arch/copy_sm90_desc.hpp
@@ -323,8 +323,8 @@ tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc,
 CUTE_HOST_DEVICE
 void
 tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor                 & smem_desc,
-                                                  cute::array<uint32_t, 3> const& prob_shape,
-                                                  cute::array<uint64_t, 3> const& prob_stride)
+                                                  cute::array<uint32_t, 5> const& prob_shape,
+                                                  cute::array<uint64_t, 5> const& prob_stride)
 {
 #if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED)
   uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc);
@@ -341,6 +341,12 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor
   asm volatile (
     "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;"
     :: "l"(smem_int64_desc), "r"(prob_shape[2]));
+  asm volatile (
+    "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 3, %1;"
+    :: "l"(smem_int64_desc), "r"(prob_shape[3]));
+  asm volatile (
+    "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 4, %1;"
+    :: "l"(smem_int64_desc), "r"(prob_shape[4]));
   // Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1
 #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5)))
   asm volatile (
@@ -349,6 +355,12 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor
   asm volatile (
     "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;"
     :: "l"(smem_int64_desc), "l"(prob_stride[2]));
+  asm volatile (
+    "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 2, %1;"
+    :: "l"(smem_int64_desc), "l"(prob_stride[3]));
+  asm volatile (
+    "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 3, %1;"
+    :: "l"(smem_int64_desc), "l"(prob_stride[4]));
 #else
   // 4 LSBs are not included
   asm volatile (
@@ -357,6 +369,12 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor
   asm volatile (
     "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;"
     :: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4));
+  asm volatile (
+    "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 2, %1;"
+    :: "l"(smem_int64_desc), "l"(prob_stride[3] >> 4));
+  asm volatile (
+    "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 3, %1;"
+    :: "l"(smem_int64_desc), "l"(prob_stride[4] >> 4));
 #endif
 #else
   CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3");
diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp
index a6e13bc7..6c0368e0 100644
--- a/include/cutlass/epilogue/collective/detail.hpp
+++ b/include/cutlass/epilogue/collective/detail.hpp
@@ -456,7 +456,6 @@ public:
   tensormaps_cp_fence_release(
       [[maybe_unused]] TensorMapStorage& shared_tensormaps,
       [[maybe_unused]] cute::TmaDescriptor const* tensormap,
-      [[maybe_unused]] uint32_t lane_predicate,
       [[maybe_unused]] int32_t warp_group_idx) {
   }
   template
diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp
index 56bdd843..84b6e14e 100644
--- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp
+++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp
@@ -1080,13 +1080,10 @@ public:
     int32_t warp_group_idx) {
     const uint32_t M = get<0>(problem_shape_mnkl);
     const uint32_t N = get<1>(problem_shape_mnkl);
-    // Only consider dimensions and strides that we need to recalculate and replace for each group
-    constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // excluding either M or N
-    static_assert(TensorRank == Int<3>{},
-                  "Descriptor modification for global dims & strides expects rank as 3.");
-
-    cute::array<uint32_t, TensorRank> prob_shape  = {1,1,1};
-    cute::array<uint64_t, TensorRank> prob_stride = {0,0,0};
+    // Replace all dims for consistency
+    constexpr int MaxTensorRank = 5;
+    cute::array<uint32_t, MaxTensorRank> prob_shape  = {1,1,1,1,1};
+    cute::array<uint64_t, MaxTensorRank> prob_stride = {0,0,0,0,0};
 
     if constexpr (IsLoad) {
       if constexpr (is_source_supported) {
@@ -1106,9 +1103,6 @@ public:
     }
     else if constexpr (is_destination_supported) {
       ElementD const* ptr_D = nullptr;
-
-      // tma_store_c should be a gmem_tensor, second argument should be a stride
-
       Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group]));
 
       cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d,
@@ -1154,8 +1148,7 @@ public:
   tensormaps_cp_fence_release(
       TensorMapStorage& shared_tensormaps,
       cute::TmaDescriptor const* tensormap,
-      [[maybe_unused]] uint32_t lane_predicate,
-      int32_t warp_group_idx = 0) {
+      const int32_t warp_group_idx = 0) {
     // Entire warp must do this (ie its aligned)
     if constexpr (IsLoad) {
 
diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp
index 9825a165..628750fc 100644
--- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp
+++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp
@@ -674,14 +674,12 @@ struct CollectiveMma<
     const uint32_t M = get<0>(problem_shape_mnkl);
     const uint32_t N = get<1>(problem_shape_mnkl);
     const uint32_t K = get<2>(problem_shape_mnkl);
-    // Only consider dimensions and strides that we need to recalculate and replace for each group
-    constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // excluding either M or N
-    static_assert(TensorRank == Int<3>{},
-                  "Descriptor modification for global dims & strides expects rank as 3.");
-    cute::array<uint32_t, TensorRank> prob_shape_A  = {1,1,1};
-    cute::array<uint64_t, TensorRank> prob_stride_A = {0,0,0};
-    cute::array<uint32_t, TensorRank> prob_shape_B  = {1,1,1};
-    cute::array<uint64_t, TensorRank> prob_stride_B = {0,0,0};
+    // Replace all dims for consistency
+    constexpr int MaxTensorRank = 5;
+    cute::array<uint32_t, MaxTensorRank> prob_shape_A  = {1,1,1,1,1};
+    cute::array<uint64_t, MaxTensorRank> prob_stride_A = {0,0,0,0,0};
+    cute::array<uint32_t, MaxTensorRank> prob_shape_B  = {1,1,1,1,1};
+    cute::array<uint64_t, MaxTensorRank> prob_stride_B = {0,0,0,0,0};
 
     InternalElementA const* ptr_A = nullptr;
     Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]);
diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp
index 961dcb8a..823e919e 100644
--- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp
+++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp
@@ -499,7 +499,7 @@ public:
     }
 
     // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
-    auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+    auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
 
     // Prepare and partition the input tensors. Expects a tuple of tensors where:
     // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
@@ -595,7 +595,7 @@ public:
       if (work_tile_info.is_valid() && did_batch_change) {
         curr_batch = next_batch;
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), 1);
         }
         // Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates
         // Since this state is waiting for loads to finish, it must start in the inverted phase.
@@ -644,7 +644,7 @@ public:
 
         // Converge before issuing tensormap fence release since fence is aligned
         __syncwarp();
-        collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
+        collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0);
       }
 
       load_order_barrier.wait();
@@ -657,7 +657,7 @@ public:
 
       if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
         }
 
         // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
@@ -692,7 +692,7 @@ public:
 
       if (work_tile_info.is_valid() && did_batch_change) {
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
         }
 
         // tensormap update
@@ -708,7 +708,7 @@ public:
 
        // Converge before issuing tensormap fence release since fence is aligned
        __syncwarp();
-       collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
+       collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0);
       }
     }
 
@@ -749,16 +749,15 @@ public:
 
         // Converge before issuing tensormap fence release since fence is aligned
         __syncwarp();
-        collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, 
+        collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue,
                                                         epi_store_tensormap,
-                                                        lane_predicate,
                                                         consumer_warp_group_idx);
       }
     }
 
     while (work_tile_info.is_valid()) {
       if constexpr (IsGroupedGemmKernel) {
-        problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+        problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
       }
 
       int32_t curr_batch = work_tile_info.L_idx;
@@ -841,7 +840,7 @@ public:
       did_batch_change = curr_batch != work_tile_info.L_idx;
       if (work_tile_info.is_valid() && did_batch_change) {
         if constexpr (IsGroupedGemmKernel) {
-          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+          problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
         }
         if (warp_idx_in_warp_group == 0) {
           collective_epilogue.tensormaps_perform_update(
@@ -857,7 +856,6 @@ public:
           __syncwarp();
           collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue,
                                                           epi_store_tensormap,
-                                                          lane_predicate,
                                                           consumer_warp_group_idx);
         }
       }
diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp
index d58d4d61..38633764 100644
--- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp
+++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp
@@ -514,7 +514,7 @@ public:
     }
 
     // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
-    auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+    auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
 
     if (warp_group_role == WarpGroupRole::Consumer1) {
       // Advance 2nd Math WG to the next work tile for the startup
@@ -531,7 +531,7 @@ public:
       epi_load_pipe_consumer_state.advance(c_tile_count);
       epi_store_pipe_producer_state.advance(d_tile_count);
 
-      problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+      problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
     }
 
     // Prepare and partition the input tensors. Expects a tuple of tensors where:
@@ -628,7 +628,7 @@ public:
      if (work_tile_info.is_valid() && did_batch_change) {
        curr_batch = next_batch;
        if constexpr (IsGroupedGemmKernel) {
-         problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), curr_batch);
+         problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(curr_batch), 1);
        }
        // Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates
        // Since this state is waiting for loads to finish, it must start in the inverted phase.
@@ -677,7 +677,7 @@ public:
 
        // Converge before issuing tensormap fence release since fence is aligned
        __syncwarp();
-       collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
+       collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0);
      }
 
      load_order_barrier.wait();
@@ -690,7 +690,7 @@ public:
 
      if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
        if constexpr (IsGroupedGemmKernel) {
-         problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+         problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
        }
 
        // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
@@ -725,7 +725,7 @@ public:
 
      if (work_tile_info.is_valid() && did_batch_change) {
        if constexpr (IsGroupedGemmKernel) {
-         problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+         problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
        }
 
        // tensormap update
@@ -741,7 +741,7 @@ public:
 
        // Converge before issuing tensormap fence release since fence is aligned
        __syncwarp();
-       collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, lane_predicate, 0);
+       collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0);
      }
    }
 
@@ -784,14 +784,13 @@ public:
        __syncwarp();
        collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue,
                                                        epi_store_tensormap,
-                                                       lane_predicate,
                                                        consumer_warp_group_idx);
      }
    }
 
    while (work_tile_info.is_valid()) {
      if constexpr (IsGroupedGemmKernel) {
-       problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+       problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
      }
 
      int32_t curr_batch = work_tile_info.L_idx;
@@ -880,7 +879,7 @@ public:
      // Skip a tile for pingpong
      if (work_tile_info.is_valid()) {
        if constexpr (IsGroupedGemmKernel) {
-         problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+         problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
        }
        work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
        mainloop_pipe_consumer_state.advance(work_k_tile_count);
@@ -895,7 +894,7 @@ public:
      did_batch_change = curr_batch != work_tile_info.L_idx;
      if (work_tile_info.is_valid() && did_batch_change) {
        if constexpr (IsGroupedGemmKernel) {
-         problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), work_tile_info.L_idx);
+         problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1);
        }
        if (warp_idx_in_warp_group == 0) {
          collective_epilogue.tensormaps_perform_update(
@@ -911,7 +910,6 @@ public:
          __syncwarp();
          collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue,
                                                          epi_store_tensormap,
-                                                         lane_predicate,
                                                          consumer_warp_group_idx);
        }
      }