diff --git a/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_streamk.h index b8bf3f80..e277e4a4 100644 --- a/include/cutlass/gemm/kernel/gemm_universal_streamk.h +++ b/include/cutlass/gemm/kernel/gemm_universal_streamk.h @@ -30,7 +30,7 @@ **************************************************************************************************/ /*! \file - \brief + \brief */ #pragma once @@ -555,31 +555,73 @@ public: static Status can_implement( cutlass::gemm::GemmCoord const & problem_size) { - CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()"); + CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); - static int const kAlignmentA = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) ? 64 : Mma::IteratorA::AccessType::kElements; - - static int const kAlignmentB = (platform::is_same>::value) + static int const kAlignmentB = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) ? 64 : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; - if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || - (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || - (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand"); + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); return Status::kErrorMisalignedOperand; }