Changes to iterators to support s8 gemm with f16 outputs (#812)

* Changes to iterators to support s8 gemm with f16 outputs

* should work

---------

Co-authored-by: Sujan Gonugondla <gsujan@amaon.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Authored by Sujan Kumar Gonugondla on 2023-02-16 18:37:51 -05:00, committed by GitHub
parent 34bed24af3
commit d8359c804b
3 changed files with 47 additions and 7 deletions
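
In practice, the combination this commit enables is an s8 GEMM that accumulates in int32 and writes f16 output through a vectorized epilogue. Below is a minimal configuration sketch, assuming the device-level cutlass::gemm::device::Gemm API and an SM80 tensor-op target; the layouts, tile shapes, and epilogue settings are illustrative assumptions, not taken from this commit.

// Hypothetical end-to-end configuration (not part of this commit): int8 A/B,
// int32 accumulation, f16 (half_t) output -- the epilogue combination the new
// iterator specialization serves. Shapes, layouts, and the SM80 target are
// illustrative assumptions.
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"

using ElementA           = int8_t;
using ElementB           = int8_t;
using ElementOutput      = cutlass::half_t;   // f16 output tensor D
using ElementAccumulator = int32_t;           // s8 x s8 accumulates in s32
using ElementCompute     = float;             // alpha/beta applied in fp32

// 8 half_t elements per access = one 128-bit store, matching the
// "half <= int32_t x 8" specialization added below.
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, 8, ElementAccumulator, ElementCompute>;

using GemmS8F16 = cutlass::gemm::device::Gemm<
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,   // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 64>,     // warp tile
    cutlass::gemm::GemmShape<16, 8, 32>,      // tensor core instruction (s8)
    EpilogueOp>;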


@@ -224,6 +224,44 @@ struct DefaultIteratorsTensorOp<
static int const kFragmentsPerIteration = 2;
};
+ /// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts.
+ template <
+ typename ThreadblockShape,
+ typename WarpShape,
+ typename InstructionShape,
+ typename ThreadMap
+ >
+ struct DefaultIteratorsTensorOp<
+ half_t,
+ int32_t,
+ 8,
+ ThreadblockShape,
+ WarpShape,
+ InstructionShape,
+ ThreadMap> {
+ using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
+ WarpShape,
+ InstructionShape,
+ int32_t,
+ 32,
+ 16,
+ 8,
+ 8
+ >;
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
+ ThreadMap,
+ int32_t,
+ 32,
+ 16,
+ 8,
+ 8
+ >;
+ static int const kFragmentsPerIteration = 2;
+ };
/// Partial specialization for int8/int4b_t <= int32 x 16/8 epilogues avoids shared memory bank conflicts.
/// Threadblock::kN = 256 still has bank conflicts.
template <

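The hunk above adds a DefaultIteratorsTensorOp partial specialization keyed on (half_t output, int32_t accumulator, 8 elements per access) that selects the bank-conflict-avoiding mixed iterators. The following standalone sketch does not use CUTLASS headers; all type names are simplified stand-ins, shown only to illustrate the partial-specialization dispatch at work.

// Simplified, self-contained illustration of the dispatch pattern used by
// DefaultIteratorsTensorOp. Types here are placeholders, not the real
// CUTLASS iterators.
#include <cstdint>
#include <type_traits>

struct half_t {};                   // stand-in for cutlass::half_t
struct GenericWarpTileIterator {};  // stand-in for TileIteratorTensorOp
struct MixedWarpTileIterator {};    // stand-in for TileIteratorTensorOpMixed

// Primary template: generic epilogue iterators.
template <typename ElementOutput, typename ElementAccumulator,
          int ElementsPerAccess, typename ThreadMap>
struct DefaultIterators {
  using WarpTileIterator = GenericWarpTileIterator;
  static int const kFragmentsPerIteration = 1;
};

// Partial specialization mirroring the one added in this commit:
// f16 output from int32 accumulators, 8 elements per access.
template <typename ThreadMap>
struct DefaultIterators<half_t, int32_t, 8, ThreadMap> {
  using WarpTileIterator = MixedWarpTileIterator;
  static int const kFragmentsPerIteration = 2;
};

struct DummyThreadMap {};

static_assert(std::is_same<DefaultIterators<half_t, int32_t, 8, DummyThreadMap>::WarpTileIterator,
                           MixedWarpTileIterator>::value,
              "f16 <- s32 x 8 resolves to the mixed (bank-conflict-free) iterator");
static_assert(std::is_same<DefaultIterators<float, int32_t, 4, DummyThreadMap>::WarpTileIterator,
                           GenericWarpTileIterator>::value,
              "other combinations keep the generic path");

int main() { return 0; }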

@@ -70,8 +70,9 @@ template <
int ElementSizeBits_, ///< Size of accumulator in bits
int OutputSizeBits_, ///< Size of output element in bits
int ElementsPerAccess, ///< Vector length of output vector
- int ContiguousLanes ///< Number of lanes in the warp writing to contiguous elements
+ int ContiguousLanes, ///< Number of lanes in the warp writing to contiguous elements
+ /// in the global memory tensor
+ bool EightBitsOutputOrLess = (OutputSizeBits_ <= 8)
>
class SharedLoadIteratorMixed;
@@ -85,7 +86,7 @@ template <
typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
typename Element_ ///< Accumulator data type
>
- class SharedLoadIteratorMixed<ThreadMap_, Element_, 32, 16, 8, 8> {
+ class SharedLoadIteratorMixed<ThreadMap_, Element_, 32, 16, 8, 8, false> {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
@@ -253,7 +254,7 @@ template <
typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
int OutputSizeBits_ ///< Size of output element in bits
>
- class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 16, 8> {
+ class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 16, 8, true> {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
@@ -418,7 +419,7 @@ template <
typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
int OutputSizeBits_
>
- class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 8, 8> {
+ class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 8, 8, true> {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;

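The SharedLoadIteratorMixed changes above rely on a defaulted boolean template parameter, EightBitsOutputOrLess = (OutputSizeBits_ <= 8), so existing call sites keep compiling unchanged while the specializations split into an 8-bit-or-narrower output path and a wider (e.g. 16-bit half_t) output path. Below is a self-contained sketch of that technique with illustrative names only; it is not the CUTLASS implementation.

// Standalone demonstration of dispatching on a defaulted boolean template
// parameter computed from another parameter, as done for
// SharedLoadIteratorMixed / TileIteratorTensorOpMixed. Names are illustrative.
#include <cstdint>
#include <cstdio>

template <typename Element, int OutputSizeBits,
          bool EightBitsOutputOrLess = (OutputSizeBits <= 8)>
class SharedLoadSketch;

// Selected for int8 / int4b_t outputs (OutputSizeBits <= 8).
template <typename Element, int OutputSizeBits>
class SharedLoadSketch<Element, OutputSizeBits, true> {
public:
  static const char *path() { return "narrow (<= 8-bit) output path"; }
};

// Selected for wider outputs such as 16-bit half_t.
template <typename Element, int OutputSizeBits>
class SharedLoadSketch<Element, OutputSizeBits, false> {
public:
  static const char *path() { return "wide (> 8-bit) output path"; }
};

int main() {
  // Callers spell only <Element, OutputSizeBits>; the bool defaults in.
  std::printf("%s\n", SharedLoadSketch<int32_t, 8>::path());   // e.g. int8 output
  std::printf("%s\n", SharedLoadSketch<int32_t, 16>::path());  // e.g. half_t output
  return 0;
}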

@@ -64,7 +64,8 @@ template <
int ElementSizeBits, ///< Size of accumulator element in bits
int OutputSizeBits, ///< Size of output element in bits
int OutputElementCount, ///< number of elements in output vector
- int ContiguousLanes ///< Number of consecutive lanes writing to contiguous memory
+ int ContiguousLanes, ///< Number of consecutive lanes writing to contiguous memory
+ bool EightBitsOutputOrLess = (OutputSizeBits <= 8)
>
class TileIteratorTensorOpMixed {
public:
@@ -319,7 +320,7 @@ template <
typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape),
int OutputSizeBits ///< Size of output element in bits
>
- class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 16, 8> {
+ class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 16, 8, true> {
public:
using WarpShape = WarpShape_;
@@ -526,7 +527,7 @@ template <
typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape)
int OutputSizeBits ///< Size of output element in bits
>
- class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 8, 8> {
+ class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 8, 8, true> {
public:
using WarpShape = WarpShape_;
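
A note on the numeric arguments in the new specializations: for the SharedLoadIteratorMixed used above, the sequence int32_t, 32, 16, 8, 8 reads as accumulator type, accumulator bits, output bits, elements per access, and contiguous lanes, per the parameter list in the second file. The elements-per-access values are consistent with 128-bit vectorized epilogue accesses, which is where the "x 8" (half_t) and "x 16" (int8) in the specialization comments come from. The helper below is an illustrative sketch of that relationship, not CUTLASS code (CUTLASS expresses it via sizeof_bits<>).

// Illustrative helper relating output element width to the "x 8" / "x 16"
// vector lengths in the specialization comments, assuming 128-bit vectorized
// epilogue accesses.
#include <cstdint>

constexpr int elements_per_128b_access(int output_size_bits) {
  return 128 / output_size_bits;
}

static_assert(elements_per_128b_access(16) == 8,
              "half_t output: 8 elements per access ('half <= int32_t x 8')");
static_assert(elements_per_128b_access(8) == 16,
              "int8_t output: 16 elements per access ('int8 <= int32 x 16')");

int main() { return 0; }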