diff --git a/include/cutlass/gemm/kernel/gemm_grouped.h b/include/cutlass/gemm/kernel/gemm_grouped.h index 8b68c9d3..c9fe2c3e 100644 --- a/include/cutlass/gemm/kernel/gemm_grouped.h +++ b/include/cutlass/gemm/kernel/gemm_grouped.h @@ -546,6 +546,9 @@ public: // Compute threadblock-scoped matrix multiply-add int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); + // Compute threadblock-scoped matrix multiply-add mma( gemm_k_iterations,