From 56b46e2d13875b46b8f6a03f9f5ac91e2bfdc01a Mon Sep 17 00:00:00 2001 From: Chengquan Jiang Date: Wed, 10 Jul 2024 23:55:22 +0800 Subject: [PATCH] Fix grouped gemm invalid memory access to problem shapes (#1543) --- ...m90_gemm_array_tma_warpspecialized_cooperative.hpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index fefbf831..1c314aad 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -465,6 +465,9 @@ public: TileScheduler scheduler{params.scheduler}; auto work_tile_info = scheduler.get_current_work(); + if (not work_tile_info.is_valid()) { + return; + } // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); @@ -623,7 +626,9 @@ public: // Get next work tile work_tile_info = fetch_next_work(work_tile_info, scheduler); if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + if (work_tile_info.is_valid()) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + } } } // Scheduler work fetch loop @@ -702,7 +707,9 @@ public: // Get next work tile work_tile_info = fetch_next_work(work_tile_info, scheduler); if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + if (work_tile_info.is_valid()) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + } } } // Scheduler work fetch loop