update float < int32_t * 4 (#488)

Co-authored-by: 赵俊涛 <zhaojuntao@zhaojuntaos-MacBook-Pro.local>
This commit is contained in:
TonyZhao 2022-05-05 01:36:05 +08:00 committed by GitHub
parent ec2b4fd85d
commit ddd8f9cf41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -158,6 +158,30 @@ struct DefaultIteratorsTensorOp<int32_t, int32_t, 4, ThreadblockShape, WarpShape
static int const kFragmentsPerIteration = 1; static int const kFragmentsPerIteration = 1;
}; };
/// Partial specialization for float <= int32_t x 4
template <
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename ThreadMap
>
struct DefaultIteratorsTensorOp<float, int32_t, 4, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
WarpShape,
InstructionShape,
int32_t,
layout::RowMajor
>;
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
ThreadMap,
int32_t
>;
static int const kFragmentsPerIteration = 1;
};
/// Partial specialization for half <= float x 8 epilogues avoids shared memory bank conflicts. /// Partial specialization for half <= float x 8 epilogues avoids shared memory bank conflicts.
template < template <
typename ThreadblockShape, typename ThreadblockShape,