From dadc881a9606f95cba1b20acda03c9d07c286239 Mon Sep 17 00:00:00 2001
From: Ying Zhang
Date: Fri, 30 Sep 2022 07:00:38 -0700
Subject: [PATCH] Bug fix for gemm broadcast (#650)

* gemm_universal_with_broadcast, +2 sources.

* Revert "gemm_universal_with_broadcast, +2 sources."

This reverts commit fb063251f2144a091f12c9abfce7e1713f2d1c9e.

* gemm broadcast bug fix
---
 .../gemm/kernel/gemm_with_fused_epilogue_v2.h | 46 +++++++++++++++++--
 1 file changed, 42 insertions(+), 4 deletions(-)

diff --git a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h
index 871a6604..4d62b914 100644
--- a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h
+++ b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h
@@ -425,11 +425,49 @@ public:
     static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
     static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
 
-    if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
-      (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
-      (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {
+    bool isAMisaligned = false;
+    bool isBMisaligned = false;
+    bool isCMisaligned = false;
 
-      CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand");
+    if (platform::is_same<LayoutA, layout::RowMajor>::value) {
+      isAMisaligned = problem_size.k() % kAlignmentA;
+    } else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
+      isAMisaligned = problem_size.m() % kAlignmentA;
+    } else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
+            || platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
+      isAMisaligned = problem_size.k() % kAlignmentA;
+    }
+
+    if (platform::is_same<LayoutB, layout::RowMajor>::value) {
+      isBMisaligned = problem_size.n() % kAlignmentB;
+    } else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
+      isBMisaligned = problem_size.k() % kAlignmentB;
+    } else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
+            || platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
+      isBMisaligned = problem_size.k() % kAlignmentB;
+    }
+
+    if (platform::is_same<LayoutC, layout::RowMajor>::value) {
+      isCMisaligned = problem_size.n() % kAlignmentC;
+    } else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
+      isCMisaligned = problem_size.m() % kAlignmentC;
+    } else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
+            || platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
+      isCMisaligned = problem_size.n() % kAlignmentC;
+    }
+
+    if (isAMisaligned) {
+      CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
+      return Status::kErrorMisalignedOperand;
+    }
+
+    if (isBMisaligned) {
+      CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
+      return Status::kErrorMisalignedOperand;
+    }
+
+    if (isCMisaligned) {
+      CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
       return Status::kErrorMisalignedOperand;
     }
 