From 53668799b2e38d3bb4d8245e949301476344fc2c Mon Sep 17 00:00:00 2001 From: Saagar Jha Date: Mon, 14 Oct 2024 16:46:20 -0700 Subject: [PATCH] Handle MNK Sm90{Row, Col}Reduction problem shapes (#1803) --- ...sm90_visitor_store_tma_warpspecialized.hpp | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index 060f8d15..f9ebe739 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -523,7 +523,8 @@ public: CudaHostAdapter* cuda_adapter = nullptr) { #if !defined(CUTLASS_SKIP_REDUCTION_INIT) if constexpr (IsAtomic) { - auto [M, N, K, L] = problem_shape; + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar); if (args.ptr_scalar != nullptr) { return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream, cuda_adapter); @@ -700,7 +701,9 @@ public: reduction_buffer = nullptr; } else if constexpr (FinalReduction) { - auto [M, N, K, L] = problem_shape; + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); @@ -735,7 +738,8 @@ public: } size_t workspace_size = 0; - auto [M, N, K, L] = problem_shape; + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; // Increment by size of reduction buffer workspace_size += product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); @@ -750,8 +754,9 @@ public: initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { #if !defined(CUTLASS_SKIP_REDUCTION_INIT) + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; if constexpr (IsAtomic) { - auto [M, N, K, L] = problem_shape; Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow); if (args.ptr_row != nullptr) { return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter); @@ -761,7 +766,6 @@ public: else #endif if constexpr (FinalReduction) { - auto [M, N, K, L] = problem_shape; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); @@ -1290,7 +1294,9 @@ public: reduction_buffer = nullptr; } else if constexpr (FinalReduction) { - auto [M, N, K, L] = problem_shape; + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); @@ -1325,7 +1331,8 @@ public: } size_t workspace_size = 0; - auto [M, N, K, L] = problem_shape; + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; // Increment by size of reduction buffer @@ -1342,8 +1349,9 @@ public: initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { #if !defined(CUTLASS_SKIP_REDUCTION_INIT) + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; if constexpr (IsAtomic) { - auto [M, N, K, L] = problem_shape; Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol); if (args.ptr_col != nullptr) { return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter); @@ -1353,7 +1361,6 @@ public: else #endif if constexpr (FinalReduction) { - auto [M, N, K, L] = problem_shape; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);