diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index fefbf831..1c314aad 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -465,6 +465,9 @@ public: TileScheduler scheduler{params.scheduler}; auto work_tile_info = scheduler.get_current_work(); + if (not work_tile_info.is_valid()) { + return; + } // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); @@ -623,7 +626,9 @@ public: // Get next work tile work_tile_info = fetch_next_work(work_tile_info, scheduler); if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + if (work_tile_info.is_valid()) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + } } } // Scheduler work fetch loop @@ -702,7 +707,9 @@ public: // Get next work tile work_tile_info = fetch_next_work(work_tile_info, scheduler); if constexpr (IsGroupedGemmKernel) { - problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + if (work_tile_info.is_valid()) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + } } } // Scheduler work fetch loop