Changes to iterators to support s8 gemm with f16 outputs (#812)
* Changes to iterators to support s8 gemm with f16 outputs

* should work

---------

Co-authored-by: Sujan Gonugondla <gsujan@amaon.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
parent 34bed24af3
commit d8359c804b
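For context, the kind of kernel this change enables can be configured roughly as follows. This is a minimal sketch, not code from this commit: the device-level Gemm template, the Sm80 tile and instruction shapes, and the LinearCombination epilogue are illustrative choices for an s8 x s8 GEMM that accumulates in int32 and writes f16 outputs eight elements per access, which is the epilogue path the iterator changes below serve.

```cpp
#include <cstdint>

#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"

// s8 operands, s32 accumulation, f16 (half_t) output written 8 elements per access.
using GemmS8F16 = cutlass::gemm::device::Gemm<
    int8_t, cutlass::layout::RowMajor,            // A: s8
    int8_t, cutlass::layout::ColumnMajor,         // B: s8
    cutlass::half_t, cutlass::layout::RowMajor,   // C/D: f16 output
    int32_t,                                      // accumulator
    cutlass::arch::OpClassTensorOp,               // tensor core MMA
    cutlass::arch::Sm80,                          // target architecture (illustrative)
    cutlass::gemm::GemmShape<128, 128, 64>,       // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 64>,         // warp tile
    cutlass::gemm::GemmShape<16, 8, 32>,          // s8 tensor-op instruction shape
    cutlass::epilogue::thread::LinearCombination<
        cutlass::half_t, 8, int32_t, float>       // 8 x f16 per access, float scaling
>;

// GemmS8F16 gemm_op;  // constructed and launched with the usual Arguments/run() flow
```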
@@ -224,6 +224,44 @@ struct DefaultIteratorsTensorOp<
   static int const kFragmentsPerIteration = 2;
 };
 
+/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts.
+template <
+  typename ThreadblockShape,
+  typename WarpShape,
+  typename InstructionShape,
+  typename ThreadMap
+>
+struct DefaultIteratorsTensorOp<
+  half_t,
+  int32_t,
+  8,
+  ThreadblockShape,
+  WarpShape,
+  InstructionShape,
+  ThreadMap> {
+
+  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
+    WarpShape,
+    InstructionShape,
+    int32_t,
+    32,
+    16,
+    8,
+    8
+  >;
+
+  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
+    ThreadMap,
+    int32_t,
+    32,
+    16,
+    8,
+    8
+  >;
+
+  static int const kFragmentsPerIteration = 2;
+};
+
 /// Partial specialization for int8/int4b_t <= int32 x 16/8 epilogues avoids shared memory bank conflicts.
 /// Threadblock::kN = 256 still has bank conflicts.
 template <
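A hypothetical instantiation, not taken from the commit, showing how the new specialization above would be reached; the detail namespace, the concrete shapes, and the thread-map construction through DefaultThreadMapTensorOp are assumptions about the surrounding default-epilogue machinery.

```cpp
#include <cstdint>

#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"

using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape        = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

// Output-tile thread map for storing 8 half_t elements per access.
using ThreadMap = cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
    ThreadblockShape, WarpShape, /*PartitionsK=*/1, cutlass::half_t,
    /*ElementsPerAccess=*/8>::Type;

// Matches the new half <= int32_t x 8 specialization ('detail' namespace assumed),
// which supplies the mixed-precision warp and shared-memory iterators.
using Iterators = cutlass::epilogue::threadblock::detail::DefaultIteratorsTensorOp<
    cutlass::half_t, int32_t, /*ElementsPerAccess=*/8,
    ThreadblockShape, WarpShape, InstructionShape, ThreadMap>;

using WarpTileIterator   = Iterators::WarpTileIterator;
using SharedLoadIterator = Iterators::SharedLoadIterator;

static_assert(Iterators::kFragmentsPerIteration == 2, "two staged fragments per iteration");
```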
@@ -70,8 +70,9 @@ template <
   int ElementSizeBits_,      ///< Size of accumulator in bits
   int OutputSizeBits_,       ///< Size of output element in bits
   int ElementsPerAccess,     ///< Vector length of output vector
-  int ContiguousLanes        ///< Number of lanes in the warp writing to contiguous elements
+  int ContiguousLanes,       ///< Number of lanes in the warp writing to contiguous elements
                              ///  in the global memory tensor
+  bool EightBitsOutputOrLess = (OutputSizeBits_ <= 8)
 >
 class SharedLoadIteratorMixed;
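The defaulted tag exists to keep the specializations from colliding: with an int32_t accumulator and a 16-bit output, the generic Element x 16-bit specialization and the int32_t x arbitrary-output-width specializations could otherwise both match. The sketch below is illustrative, not from the commit, and assumes a thread map built with DefaultThreadMapTensorOp; it shows which partial specialization each output width resolves to.

```cpp
#include <cstdint>

#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"

using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape        = cutlass::gemm::GemmShape<64, 64, 64>;

using ThreadMap = cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
    ThreadblockShape, WarpShape, /*PartitionsK=*/1, cutlass::half_t,
    /*ElementsPerAccess=*/8>::Type;

// 16-bit (f16) output: EightBitsOutputOrLess defaults to (16 <= 8) == false,
// so the <ThreadMap, Element, 32, 16, 8, 8, false> specialization is selected.
using F16Loader = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
    ThreadMap, int32_t, /*ElementSizeBits=*/32, /*OutputSizeBits=*/16,
    /*ElementsPerAccess=*/8, /*ContiguousLanes=*/8>;

// 8-bit (s8) output: the tag defaults to true and the existing
// <ThreadMap, int32_t, 32, OutputSizeBits, 8, 8, true> specialization still applies.
using S8Loader = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
    ThreadMap, int32_t, /*ElementSizeBits=*/32, /*OutputSizeBits=*/8,
    /*ElementsPerAccess=*/8, /*ContiguousLanes=*/8>;
```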
@@ -85,7 +86,7 @@ template <
   typename ThreadMap_,       ///< Thread map (conept: OutputTileThreadMap)
   typename Element_          ///< Accumulator data type
 >
-class SharedLoadIteratorMixed<ThreadMap_, Element_, 32, 16, 8, 8> {
+class SharedLoadIteratorMixed<ThreadMap_, Element_, 32, 16, 8, 8, false> {
 public:
   using ThreadMap = ThreadMap_;
   using Shape = typename ThreadMap::Shape;
@@ -253,7 +254,7 @@ template <
   typename ThreadMap_,       ///< Thread map (conept: OutputTileThreadMap)
   int OutputSizeBits_        ///< Size of output element in bits
 >
-class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 16, 8> {
+class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 16, 8, true> {
 public:
   using ThreadMap = ThreadMap_;
   using Shape = typename ThreadMap::Shape;
@@ -418,7 +419,7 @@ template <
   typename ThreadMap_,       ///< Thread map (conept: OutputTileThreadMap)
   int OutputSizeBits_
 >
-class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 8, 8> {
+class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 8, 8, true> {
 public:
   using ThreadMap = ThreadMap_;
   using Shape = typename ThreadMap::Shape;
@@ -64,7 +64,8 @@ template <
   int ElementSizeBits,       ///< Size of accumulator element in bits
   int OutputSizeBits,        ///< Size of output element in bits
   int OutputElementCount,    ///< number of elements in output vector
-  int ContiguousLanes        ///< Number of consecutive lanes writing to contiguous memory
+  int ContiguousLanes,       ///< Number of consecutive lanes writing to contiguous memory
+  bool EightBitsOutputOrLess = (OutputSizeBits <= 8)
 >
 class TileIteratorTensorOpMixed {
 public:
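The warp-level store path gets the same treatment. The sketch below is illustrative rather than part of the commit; it assumes the shapes shown, and, as far as this diff shows, a 16-bit output falls through to the generic TileIteratorTensorOpMixed because only the 8-bit-or-less specializations carry the true tag.

```cpp
#include <cstdint>

#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"

using WarpShape     = cutlass::gemm::GemmShape<64, 64, 64>;
using OperatorShape = cutlass::gemm::GemmShape<16, 8, 32>;

// 16-bit (f16) output: the defaulted tag is false, so none of the conflict-avoiding
// specializations apply and the primary TileIteratorTensorOpMixed template is used.
using F16WarpStore = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
    WarpShape, OperatorShape, int32_t, /*ElementSizeBits=*/32,
    /*OutputSizeBits=*/16, /*OutputElementCount=*/8, /*ContiguousLanes=*/8>;

// 8-bit (s8) output: the tag defaults to true, keeping the existing
// <..., OutputSizeBits, 16, 8, true> specialization selected, unchanged in behavior.
using S8WarpStore = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
    WarpShape, OperatorShape, int32_t, /*ElementSizeBits=*/32,
    /*OutputSizeBits=*/8, /*OutputElementCount=*/16, /*ContiguousLanes=*/8>;
```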
@@ -319,7 +320,7 @@ template <
   typename OperatorShape_,   ///< matrix multiply operation shape (concept: gemm::GemmShape),
   int OutputSizeBits         ///< Size of output element in bits
 >
-class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 16, 8> {
+class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 16, 8, true> {
 public:
 
   using WarpShape = WarpShape_;
@@ -526,7 +527,7 @@ template <
   typename OperatorShape_,   ///< matrix multiply operation shape (concept: gemm::GemmShape)
   int OutputSizeBits         ///< Size of output element in bits
 >
-class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 8, 8> {
+class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 8, 8, true> {
 public:
 
   using WarpShape = WarpShape_;