Re-enable all alignments for int accumulators (#807)

This commit is contained in:
Jack Kosaian 2023-02-06 22:01:15 -05:00 committed by GitHub
parent add4ba622f
commit 5921043981
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -136,14 +136,15 @@ struct DefaultIteratorsTensorOp<float, float, 4, ThreadblockShape, WarpShape, In
static int const kFragmentsPerIteration = 2; static int const kFragmentsPerIteration = 2;
}; };
/// Partial specialization for int32_t <= int32_t x 4 /// Partial specialization for int32_t <= int32_t
template < template <
int ElementsPerAccess,
typename ThreadblockShape, typename ThreadblockShape,
typename WarpShape, typename WarpShape,
typename InstructionShape, typename InstructionShape,
typename ThreadMap typename ThreadMap
> >
struct DefaultIteratorsTensorOp<int32_t, int32_t, 4, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> { struct DefaultIteratorsTensorOp<int32_t, int32_t, ElementsPerAccess, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
WarpShape, WarpShape,
@ -160,14 +161,15 @@ struct DefaultIteratorsTensorOp<int32_t, int32_t, 4, ThreadblockShape, WarpShape
static int const kFragmentsPerIteration = 1; static int const kFragmentsPerIteration = 1;
}; };
/// Partial specialization for float <= int32_t x 4 /// Partial specialization for float <= int32_t
template < template <
int ElementsPerAccess,
typename ThreadblockShape, typename ThreadblockShape,
typename WarpShape, typename WarpShape,
typename InstructionShape, typename InstructionShape,
typename ThreadMap typename ThreadMap
> >
struct DefaultIteratorsTensorOp<float, int32_t, 4, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> { struct DefaultIteratorsTensorOp<float, int32_t, ElementsPerAccess, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
WarpShape, WarpShape,