diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index b6b3af22..95861402 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -270,7 +270,7 @@ struct DefaultIteratorsTensorOp< >; using WarpTileIterator = typename platform::conditional< - (ThreadblockShape::kN == 256), + (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8), WarpTileIteratorNotMixed, WarpTileIteratorMixed>::type; @@ -289,7 +289,7 @@ struct DefaultIteratorsTensorOp< >; using SharedLoadIterator = typename platform::conditional< - (ThreadblockShape::kN == 256), + (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8), SharedLoadIteratorNotMixed, SharedLoadIteratorMixed>::type; @@ -337,7 +337,7 @@ struct DefaultIteratorsTensorOp< >; using WarpTileIterator = typename platform::conditional< - (ThreadblockShape::kN == 256), + (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8), WarpTileIteratorNotMixed, WarpTileIteratorMixed>::type; @@ -356,7 +356,7 @@ struct DefaultIteratorsTensorOp< >; using SharedLoadIterator = typename platform::conditional< - (ThreadblockShape::kN == 256), + (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8), SharedLoadIteratorNotMixed, SharedLoadIteratorMixed>::type; @@ -404,7 +404,7 @@ struct DefaultIteratorsTensorOp< >; using WarpTileIterator = typename platform::conditional< - (ThreadblockShape::kN == 256), + (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8), WarpTileIteratorNotMixed, WarpTileIteratorMixed>::type; @@ -423,7 +423,7 @@ struct DefaultIteratorsTensorOp< >; using SharedLoadIterator = typename platform::conditional< - (ThreadblockShape::kN == 256), + (ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8), SharedLoadIteratorNotMixed, SharedLoadIteratorMixed>::type;