From 9fb38ac048348e39d2780e0e77ce501d14095221 Mon Sep 17 00:00:00 2001 From: Haicheng Wu <57973641+hwu36@users.noreply.github.com> Date: Wed, 15 Feb 2023 12:06:00 -0500 Subject: [PATCH] fix alignmentC=8 for imma N=128 (#822) Co-authored-by: Haicheng Wu --- .../threadblock/default_epilogue_tensor_op.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index b6b3af22..95861402 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -270,7 +270,7 @@ struct DefaultIteratorsTensorOp< >; using WarpTileIterator = typename platform::conditional< - (ThreadblockShape::kN == 256), + (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8), WarpTileIteratorNotMixed, WarpTileIteratorMixed>::type; @@ -289,7 +289,7 @@ struct DefaultIteratorsTensorOp< >; using SharedLoadIterator = typename platform::conditional< - (ThreadblockShape::kN == 256), + (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8), SharedLoadIteratorNotMixed, SharedLoadIteratorMixed>::type; @@ -337,7 +337,7 @@ struct DefaultIteratorsTensorOp< >; using WarpTileIterator = typename platform::conditional< - (ThreadblockShape::kN == 256), + (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8), WarpTileIteratorNotMixed, WarpTileIteratorMixed>::type; @@ -356,7 +356,7 @@ struct DefaultIteratorsTensorOp< >; using SharedLoadIterator = typename platform::conditional< - (ThreadblockShape::kN == 256), + (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8), SharedLoadIteratorNotMixed, SharedLoadIteratorMixed>::type; @@ -404,7 +404,7 @@ struct DefaultIteratorsTensorOp< >; using WarpTileIterator = typename platform::conditional< - (ThreadblockShape::kN == 256), + (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8), WarpTileIteratorNotMixed, WarpTileIteratorMixed>::type; @@ -423,7 +423,7 @@ struct DefaultIteratorsTensorOp< >; using SharedLoadIterator = typename platform::conditional< - (ThreadblockShape::kN == 256), + (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8), SharedLoadIteratorNotMixed, SharedLoadIteratorMixed>::type;