fix alignmentC=8 for imma N=128 (#822)

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Haicheng Wu 2023-02-15 12:06:00 -05:00 committed by GitHub
parent 8f5c242426
commit 9fb38ac048
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -270,7 +270,7 @@ struct DefaultIteratorsTensorOp<
 >;
 using WarpTileIterator = typename platform::conditional<
-(ThreadblockShape::kN == 256),
+(ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8),
 WarpTileIteratorNotMixed,
 WarpTileIteratorMixed>::type;
@@ -289,7 +289,7 @@ struct DefaultIteratorsTensorOp<
 >;
 using SharedLoadIterator = typename platform::conditional<
-(ThreadblockShape::kN == 256),
+(ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8),
 SharedLoadIteratorNotMixed,
 SharedLoadIteratorMixed>::type;
@@ -337,7 +337,7 @@ struct DefaultIteratorsTensorOp<
 >;
 using WarpTileIterator = typename platform::conditional<
-(ThreadblockShape::kN == 256),
+(ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8),
 WarpTileIteratorNotMixed,
 WarpTileIteratorMixed>::type;
@@ -356,7 +356,7 @@ struct DefaultIteratorsTensorOp<
 >;
 using SharedLoadIterator = typename platform::conditional<
-(ThreadblockShape::kN == 256),
+(ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8),
 SharedLoadIteratorNotMixed,
 SharedLoadIteratorMixed>::type;
@@ -404,7 +404,7 @@ struct DefaultIteratorsTensorOp<
 >;
 using WarpTileIterator = typename platform::conditional<
-(ThreadblockShape::kN == 256),
+(ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8),
 WarpTileIteratorNotMixed,
 WarpTileIteratorMixed>::type;
@@ -423,7 +423,7 @@ struct DefaultIteratorsTensorOp<
 >;
 using SharedLoadIterator = typename platform::conditional<
-(ThreadblockShape::kN == 256),
+(ThreadblockShape::kN == 256) || (ThreadblockShape::kN == 128 && ElementsPerAccess == 8),
 SharedLoadIteratorNotMixed,
 SharedLoadIteratorMixed>::type;