Relax stream K gemm alignment constraints (#717)
* Relax stream K gemm alignment constraints The current alignment requirements are too strict. Make them identical to the checks for the regular universal gemm. * Revert "Relax stream K gemm alignment constraints" This reverts commit 31e80a250e2b0ac4bda2e4b437b39dc5bcd5e845. * Relax stream K gemm alignment constraints The current alignment requirements are too strict. Make them identical to the checks for the regular universal gemm. Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
9c0518608e
commit
d6117ca362
@ -555,31 +555,73 @@ public:
|
|||||||
static Status can_implement(
|
static Status can_implement(
|
||||||
cutlass::gemm::GemmCoord const & problem_size)
|
cutlass::gemm::GemmCoord const & problem_size)
|
||||||
{
|
{
|
||||||
CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()");
|
CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");
|
||||||
|
|
||||||
static int const kAlignmentA = (platform::is_same<typename Mma::IteratorA::Layout,
|
static int const kAlignmentA = (platform::is_same<LayoutA,
|
||||||
layout::ColumnMajorInterleaved<32>>::value)
|
layout::ColumnMajorInterleaved<32>>::value)
|
||||||
? 32
|
? 32
|
||||||
: (platform::is_same<typename Mma::IteratorA::Layout,
|
: (platform::is_same<LayoutA,
|
||||||
layout::ColumnMajorInterleaved<64>>::value)
|
layout::ColumnMajorInterleaved<64>>::value)
|
||||||
? 64
|
? 64
|
||||||
: Mma::IteratorA::AccessType::kElements;
|
: Mma::IteratorA::AccessType::kElements;
|
||||||
|
static int const kAlignmentB = (platform::is_same<LayoutB,
|
||||||
static int const kAlignmentB = (platform::is_same<typename Mma::IteratorB::Layout,
|
layout::RowMajorInterleaved<32>>::value)
|
||||||
layout::RowMajorInterleaved<32>>::value)
|
|
||||||
? 32
|
? 32
|
||||||
: (platform::is_same<typename Mma::IteratorB::Layout,
|
: (platform::is_same<LayoutB,
|
||||||
layout::RowMajorInterleaved<64>>::value)
|
layout::RowMajorInterleaved<64>>::value)
|
||||||
? 64
|
? 64
|
||||||
: Mma::IteratorB::AccessType::kElements;
|
: Mma::IteratorB::AccessType::kElements;
|
||||||
|
static int const kAlignmentC = (platform::is_same<LayoutC,
|
||||||
|
layout::ColumnMajorInterleaved<32>>::value)
|
||||||
|
? 32
|
||||||
|
: (platform::is_same<LayoutC,
|
||||||
|
layout::ColumnMajorInterleaved<64>>::value)
|
||||||
|
? 64
|
||||||
|
: Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||||
|
|
||||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
bool isAMisaligned = false;
|
||||||
|
bool isBMisaligned = false;
|
||||||
|
bool isCMisaligned = false;
|
||||||
|
|
||||||
if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
|
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
|
||||||
(problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
|
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||||
(problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC))
|
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
|
||||||
{
|
isAMisaligned = problem_size.m() % kAlignmentA;
|
||||||
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand");
|
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
|
||||||
|
|| platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
|
||||||
|
isAMisaligned = problem_size.k() % kAlignmentA;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
|
||||||
|
isBMisaligned = problem_size.n() % kAlignmentB;
|
||||||
|
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
|
||||||
|
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||||
|
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
|
||||||
|
|| platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
|
||||||
|
isBMisaligned = problem_size.k() % kAlignmentB;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
|
||||||
|
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||||
|
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
|
||||||
|
isCMisaligned = problem_size.m() % kAlignmentC;
|
||||||
|
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
|
||||||
|
|| platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
|
||||||
|
isCMisaligned = problem_size.n() % kAlignmentC;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isAMisaligned) {
|
||||||
|
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
|
||||||
|
return Status::kErrorMisalignedOperand;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isBMisaligned) {
|
||||||
|
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
|
||||||
|
return Status::kErrorMisalignedOperand;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isCMisaligned) {
|
||||||
|
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
|
||||||
return Status::kErrorMisalignedOperand;
|
return Status::kErrorMisalignedOperand;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user