Relax stream K gemm alignment constraints (#717)

* Relax stream K gemm alignment constraints

The current alignment requirements are too strict. Make them identical
to the checks for the regular universal gemm.

* Revert "Relax stream K gemm alignment constraints"

This reverts commit 31e80a250e2b0ac4bda2e4b437b39dc5bcd5e845.

* Relax stream K gemm alignment constraints

The current alignment requirements are too strict. Make them identical
to the checks for the regular universal gemm.

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Mike Iovine 2022-12-07 11:17:49 -05:00 committed by GitHub
parent 9c0518608e
commit d6117ca362
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -30,7 +30,7 @@
**************************************************************************************************/
/*! \file
\brief
\brief
*/
#pragma once
@ -555,31 +555,73 @@ public:
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size)
{
CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()");
CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");
static int const kAlignmentA = (platform::is_same<typename Mma::IteratorA::Layout,
static int const kAlignmentA = (platform::is_same<LayoutA,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Mma::IteratorA::Layout,
: (platform::is_same<LayoutA,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = (platform::is_same<typename Mma::IteratorB::Layout,
layout::RowMajorInterleaved<32>>::value)
static int const kAlignmentB = (platform::is_same<LayoutB,
layout::RowMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Mma::IteratorB::Layout,
: (platform::is_same<LayoutB,
layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = (platform::is_same<LayoutC,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<LayoutC,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Epilogue::OutputTileIterator::kElementsPerAccess;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC))
{
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand");
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}