Changes to iterators to support s8 gemm with f16 outputs (#812)
* Changes to iterators to support s8 gemm with f16 outputs * should work --------- Co-authored-by: Sujan Gonugondla <gsujan@amaon.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
34bed24af3
commit
d8359c804b
@ -224,6 +224,44 @@ struct DefaultIteratorsTensorOp<
|
|||||||
static int const kFragmentsPerIteration = 2;
|
static int const kFragmentsPerIteration = 2;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts.
|
||||||
|
template <
|
||||||
|
typename ThreadblockShape,
|
||||||
|
typename WarpShape,
|
||||||
|
typename InstructionShape,
|
||||||
|
typename ThreadMap
|
||||||
|
>
|
||||||
|
struct DefaultIteratorsTensorOp<
|
||||||
|
half_t,
|
||||||
|
int32_t,
|
||||||
|
8,
|
||||||
|
ThreadblockShape,
|
||||||
|
WarpShape,
|
||||||
|
InstructionShape,
|
||||||
|
ThreadMap> {
|
||||||
|
|
||||||
|
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
|
||||||
|
WarpShape,
|
||||||
|
InstructionShape,
|
||||||
|
int32_t,
|
||||||
|
32,
|
||||||
|
16,
|
||||||
|
8,
|
||||||
|
8
|
||||||
|
>;
|
||||||
|
|
||||||
|
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
|
||||||
|
ThreadMap,
|
||||||
|
int32_t,
|
||||||
|
32,
|
||||||
|
16,
|
||||||
|
8,
|
||||||
|
8
|
||||||
|
>;
|
||||||
|
|
||||||
|
static int const kFragmentsPerIteration = 2;
|
||||||
|
};
|
||||||
|
|
||||||
/// Partial specialization for int8/int4b_t <= int32 x 16/8 epilogues avoids shared memory bank conflicts.
|
/// Partial specialization for int8/int4b_t <= int32 x 16/8 epilogues avoids shared memory bank conflicts.
|
||||||
/// Threadblock::kN = 256 still has bank conflicts.
|
/// Threadblock::kN = 256 still has bank conflicts.
|
||||||
template <
|
template <
|
||||||
|
@ -70,8 +70,9 @@ template <
|
|||||||
int ElementSizeBits_, ///< Size of accumulator in bits
|
int ElementSizeBits_, ///< Size of accumulator in bits
|
||||||
int OutputSizeBits_, ///< Size of output element in bits
|
int OutputSizeBits_, ///< Size of output element in bits
|
||||||
int ElementsPerAccess, ///< Vector length of output vector
|
int ElementsPerAccess, ///< Vector length of output vector
|
||||||
int ContiguousLanes ///< Number of lanes in the warp writing to contiguous elements
|
int ContiguousLanes, ///< Number of lanes in the warp writing to contiguous elements
|
||||||
/// in the global memory tensor
|
/// in the global memory tensor
|
||||||
|
bool EightBitsOutputOrLess = (OutputSizeBits_ <= 8)
|
||||||
>
|
>
|
||||||
class SharedLoadIteratorMixed;
|
class SharedLoadIteratorMixed;
|
||||||
|
|
||||||
@ -85,7 +86,7 @@ template <
|
|||||||
typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap)
|
typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap)
|
||||||
typename Element_ ///< Accumulator data type
|
typename Element_ ///< Accumulator data type
|
||||||
>
|
>
|
||||||
class SharedLoadIteratorMixed<ThreadMap_, Element_, 32, 16, 8, 8> {
|
class SharedLoadIteratorMixed<ThreadMap_, Element_, 32, 16, 8, 8, false> {
|
||||||
public:
|
public:
|
||||||
using ThreadMap = ThreadMap_;
|
using ThreadMap = ThreadMap_;
|
||||||
using Shape = typename ThreadMap::Shape;
|
using Shape = typename ThreadMap::Shape;
|
||||||
@ -253,7 +254,7 @@ template <
|
|||||||
typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap)
|
typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap)
|
||||||
int OutputSizeBits_ ///< Size of output element in bits
|
int OutputSizeBits_ ///< Size of output element in bits
|
||||||
>
|
>
|
||||||
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 16, 8> {
|
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 16, 8, true> {
|
||||||
public:
|
public:
|
||||||
using ThreadMap = ThreadMap_;
|
using ThreadMap = ThreadMap_;
|
||||||
using Shape = typename ThreadMap::Shape;
|
using Shape = typename ThreadMap::Shape;
|
||||||
@ -418,7 +419,7 @@ template <
|
|||||||
typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap)
|
typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap)
|
||||||
int OutputSizeBits_
|
int OutputSizeBits_
|
||||||
>
|
>
|
||||||
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 8, 8> {
|
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, OutputSizeBits_, 8, 8, true> {
|
||||||
public:
|
public:
|
||||||
using ThreadMap = ThreadMap_;
|
using ThreadMap = ThreadMap_;
|
||||||
using Shape = typename ThreadMap::Shape;
|
using Shape = typename ThreadMap::Shape;
|
||||||
|
@ -64,7 +64,8 @@ template <
|
|||||||
int ElementSizeBits, ///< Size of accumulator element in bits
|
int ElementSizeBits, ///< Size of accumulator element in bits
|
||||||
int OutputSizeBits, ///< Size of output element in bits
|
int OutputSizeBits, ///< Size of output element in bits
|
||||||
int OutputElementCount, ///< number of elements in output vector
|
int OutputElementCount, ///< number of elements in output vector
|
||||||
int ContiguousLanes ///< Number of consecutive lanes writing to contiguous memory
|
int ContiguousLanes, ///< Number of consecutive lanes writing to contiguous memory
|
||||||
|
bool EightBitsOutputOrLess = (OutputSizeBits <= 8)
|
||||||
>
|
>
|
||||||
class TileIteratorTensorOpMixed {
|
class TileIteratorTensorOpMixed {
|
||||||
public:
|
public:
|
||||||
@ -319,7 +320,7 @@ template <
|
|||||||
typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape),
|
typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape),
|
||||||
int OutputSizeBits ///< Size of output element in bits
|
int OutputSizeBits ///< Size of output element in bits
|
||||||
>
|
>
|
||||||
class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 16, 8> {
|
class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 16, 8, true> {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
using WarpShape = WarpShape_;
|
using WarpShape = WarpShape_;
|
||||||
@ -526,7 +527,7 @@ template <
|
|||||||
typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape)
|
typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape)
|
||||||
int OutputSizeBits ///< Size of output element in bits
|
int OutputSizeBits ///< Size of output element in bits
|
||||||
>
|
>
|
||||||
class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 8, 8> {
|
class TileIteratorTensorOpMixed<WarpShape_, OperatorShape_, int32_t, 32, OutputSizeBits, 8, 8, true> {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
using WarpShape = WarpShape_;
|
using WarpShape = WarpShape_;
|
||||||
|
Loading…
Reference in New Issue
Block a user