diff --git a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h index 871a6604..4d62b914 100644 --- a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h +++ b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h @@ -425,11 +425,49 @@ public: static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || - (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || - (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand"); + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); return Status::kErrorMisalignedOperand; }