From d6117ca362026cabb8485dda9bbb56abd4398435 Mon Sep 17 00:00:00 2001 From: Mike Iovine Date: Wed, 7 Dec 2022 11:17:49 -0500 Subject: [PATCH] Relax stream K gemm alignment constraints (#717) * Relax stream K gemm alignment constraints The current alignment requirements are too strict. Make them identical to the checks for the regular universal gemm. * Revert "Relax stream K gemm alignment constraints" This reverts commit 31e80a250e2b0ac4bda2e4b437b39dc5bcd5e845. * Relax stream K gemm alignment constraints The current alignment requirements are too strict. Make them identical to the checks for the regular universal gemm. Co-authored-by: Haicheng Wu --- .../gemm/kernel/gemm_universal_streamk.h | 70 +++++++++++++++---- 1 file changed, 56 insertions(+), 14 deletions(-) diff --git a/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_streamk.h index b8bf3f80..e277e4a4 100644 --- a/include/cutlass/gemm/kernel/gemm_universal_streamk.h +++ b/include/cutlass/gemm/kernel/gemm_universal_streamk.h @@ -30,7 +30,7 @@ **************************************************************************************************/ /*! \file - \brief + \brief */ #pragma once @@ -555,31 +555,73 @@ public: static Status can_implement( cutlass::gemm::GemmCoord const & problem_size) { - CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()"); + CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); - static int const kAlignmentA = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) ? 64 : Mma::IteratorA::AccessType::kElements; - - static int const kAlignmentB = (platform::is_same>::value) + static int const kAlignmentB = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) ? 64 : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; - if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || - (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || - (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand"); + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); return Status::kErrorMisalignedOperand; }