Changes to iterators to support s8 gemm with f16 outputs (#812)

* Changes to iterators to support s8 gemm with f16 outputs

* should work

---------

Co-authored-by: Sujan Gonugondla <gsujan@amaon.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
commit d8359c804b (parent 34bed24af3)
Author: Sujan Kumar Gonugondla
Date:   2023-02-16 18:37:51 -05:00 (committed by GitHub)
3 changed files with 47 additions and 7 deletions
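
For context, the combination this commit targets can be expressed with the CUTLASS 2.x device-level GEMM API. The configuration below is an illustrative sketch rather than code from this commit: the layouts, tile shapes, architecture tag, and epilogue compute type are assumptions chosen to exercise the int8-input / int32-accumulator / half_t-output path with 8 output elements per access.

#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"

// Illustrative s8 x s8 -> f16 GEMM: int8 operands, int32 accumulation,
// half-precision output written 8 elements per access (assumed shapes).
using GemmS8F16 = cutlass::gemm::device::Gemm<
    int8_t, cutlass::layout::RowMajor,             // A: s8
    int8_t, cutlass::layout::ColumnMajor,          // B: s8
    cutlass::half_t, cutlass::layout::RowMajor,    // C/D: f16 output
    int32_t,                                       // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,        // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 64>,          // warp tile
    cutlass::gemm::GemmShape<16, 8, 32>,           // tensor core instruction
    cutlass::epilogue::thread::LinearCombination<
        cutlass::half_t, 8, int32_t, float>        // 8 f16 elements per access
>;

The epilogue's 8-wide half_t accesses are what the new DefaultIteratorsTensorOp specialization below is keyed on.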

File 1 of 3:

@@ -224,6 +224,44 @@ struct DefaultIteratorsTensorOp<
   static int const kFragmentsPerIteration = 2;
 };
 
+/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts.
+template <
+  typename ThreadblockShape,
+  typename WarpShape,
+  typename InstructionShape,
+  typename ThreadMap
+>
+struct DefaultIteratorsTensorOp<
+  half_t,
+  int32_t,
+  8,
+  ThreadblockShape,
+  WarpShape,
+  InstructionShape,
+  ThreadMap> {
+
+  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
+    WarpShape,
+    InstructionShape,
+    int32_t,
+    32,
+    16,
+    8,
+    8
+  >;
+
+  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
+    ThreadMap,
+    int32_t,
+    32,
+    16,
+    8,
+    8
+  >;
+
+  static int const kFragmentsPerIteration = 2;
+};
+
 /// Partial specialization for int8/int4b_t <= int32 x 16/8 epilogues avoids shared memory bank conflicts.
 /// Threadblock::kN = 256 still has bank conflicts.
 template <
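
The hunk above adds the DefaultIteratorsTensorOp specialization that routes half_t outputs with int32_t accumulators and 8-wide accesses to the bank-conflict-avoiding "Mixed" warp and shared-load iterators. A minimal sketch of an instantiation that would pick it up, assuming the struct lives in the usual cutlass::epilogue::threadblock::detail namespace and using illustrative tile shapes and thread map:

#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/gemm/gemm.h"

// Assumed example shapes; any shapes consistent with the kernel would do.
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape        = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
    ThreadblockShape, WarpShape, /*PartitionsK=*/1,
    cutlass::half_t, /*ElementsPerAccess=*/8>::Type;

// Resolves to the new partial specialization: WarpTileIterator and SharedLoadIterator
// are the Mixed variants (int32_t accumulators in shared memory, 16-bit outputs,
// 8 elements per access, 8 contiguous lanes), with kFragmentsPerIteration == 2.
using Iterators = cutlass::epilogue::threadblock::detail::DefaultIteratorsTensorOp<
    cutlass::half_t, int32_t, /*ElementsPerAccess=*/8,
    ThreadblockShape, WarpShape, InstructionShape, ThreadMap>;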

File 2 of 3:

@@ -70,8 +70,9 @@ template <
   int ElementSizeBits_,    ///< Size of accumulator in bits
   int OutputSizeBits_,     ///< Size of output element in bits
   int ElementsPerAccess,   ///< Vector length of output vector
-  int ContiguousLanes      ///< Number of lanes in the warp writing to contiguous elements
+  int ContiguousLanes,     ///< Number of lanes in the warp writing to contiguous elements
                            ///  in the global memory tensor
+  bool EightBitsOutputOrLess = (OutputSizeBits_ <= 8)
 >
 class SharedLoadIteratorMixed;
@@ -85,7 +86,7 @@ template <
   typename ThreadMap_,       ///< Thread map (concept: OutputTileThreadMap)
   typename Element_          ///< Accumulator data type
 >
-class SharedLoadIteratorMixed<ThreadMap_, Element_, 32, 16, 8, 8> {
+class SharedLoadIteratorMixed<ThreadMap_, Element_, 32, 16, 8, 8, false> {
 public:
   using ThreadMap = ThreadMap_;
   using Shape = typename ThreadMap::Shape;
@@ -253,7 +254,7 @@ template <
   typename ThreadMap_,       ///< Thread map (concept: OutputTileThreadMap)
   int OutputSizeBits_        ///< Size of output element in bits
 >
-class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 16, 8> {
+class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 16, 8, true> {
 public:
   using ThreadMap = ThreadMap_;
   using Shape = typename ThreadMap::Shape;
@@ -418,7 +419,7 @@ template <
   typename ThreadMap_,      ///< Thread map (concept: OutputTileThreadMap)
   int OutputSizeBits_
 >
-class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 8, 8> {
+class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 8, 8, true> {
 public:
   using ThreadMap = ThreadMap_;
   using Shape = typename ThreadMap::Shape;
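
The defaulted EightBitsOutputOrLess parameter added above keeps partial-specialization matching unambiguous now that int32_t accumulators can pair with 16-bit outputs: without it, the instantiation sketched below would match both the generic <ThreadMap_, Element_, 32, 16, 8, 8> specialization and the int8-oriented <ThreadMap_, int32_t, 32, OutputSizeBits_, 8, 8> one. Shapes and the thread map here are illustrative assumptions, not taken from the commit.

#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
#include "cutlass/gemm/gemm.h"

// Assumed example shapes and thread map.
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape        = cutlass::gemm::GemmShape<64, 64, 64>;
using ThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
    ThreadblockShape, WarpShape, /*PartitionsK=*/1,
    cutlass::half_t, /*ElementsPerAccess=*/8>::Type;

// half_t (16-bit) output from int32_t accumulators: EightBitsOutputOrLess defaults to
// (16 <= 8) == false, so this resolves to the <ThreadMap_, Element_, 32, 16, 8, 8, false>
// specialization instead of colliding with <ThreadMap_, int32_t, 32, OutputSizeBits_, 8, 8, true>.
// Outputs of 8 bits or narrower evaluate the default to true and keep their existing paths.
using HalfOutLoader = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
    ThreadMap, int32_t, /*ElementSizeBits=*/32, /*OutputSizeBits=*/16,
    /*ElementsPerAccess=*/8, /*ContiguousLanes=*/8>;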

File 3 of 3:

@@ -64,7 +64,8 @@ template <
   int ElementSizeBits,       ///< Size of accumulator element in bits
   int OutputSizeBits,        ///< Size of output element in bits
   int OutputElementCount,    ///< number of elements in output vector
-  int ContiguousLanes        ///< Number of consecutive lanes writing to contiguous memory
+  int ContiguousLanes,       ///< Number of consecutive lanes writing to contiguous memory
+  bool EightBitsOutputOrLess = (OutputSizeBits <= 8)
 >
 class TileIteratorTensorOpMixed {
 public:
@@ -319,7 +320,7 @@ template <
   typename OperatorShape_,    ///< matrix multiply operation shape (concept: gemm::GemmShape),
   int OutputSizeBits          ///< Size of output element in bits
 >
-class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 16, 8> {
+class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 16, 8, true> {
 public:
   using WarpShape = WarpShape_;
@@ -526,7 +527,7 @@ template <
   typename OperatorShape_,   ///< matrix multiply operation shape (concept: gemm::GemmShape)
   int OutputSizeBits         ///< Size of output element in bits
 >
-class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 8, 8> {
+class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 8, 8, true> {
 public:
   using WarpShape = WarpShape_;
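
The warp-level tile iterator gets the same dispatch fix. With the new defaulted parameter, the 16-bit-output instantiation used by the half <= int32_t epilogue no longer matches the int8-oriented partial specializations (now restricted to EightBitsOutputOrLess == true) and, per the specializations touched in this diff, falls through to the primary template. A sketch with assumed shapes:

#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
#include "cutlass/gemm/gemm.h"

// Assumed example shapes.
using WarpShape        = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

// int32_t accumulators, 16-bit outputs, 8 output elements, 8 contiguous lanes:
// EightBitsOutputOrLess defaults to (16 <= 8) == false, so neither
// <..., int32_t, 32, OutputSizeBits, 16, 8, true> nor <..., int32_t, 32, OutputSizeBits, 8, 8, true>
// applies, and the primary TileIteratorTensorOpMixed template handles the half_t-output case.
using HalfOutWarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
    WarpShape, InstructionShape, int32_t,
    /*ElementSizeBits=*/32, /*OutputSizeBits=*/16,
    /*OutputElementCount=*/8, /*ContiguousLanes=*/8>;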