diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index 95861402..6c148742 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -224,6 +224,44 @@ struct DefaultIteratorsTensorOp< static int const kFragmentsPerIteration = 2; }; +/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts. +template < + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename ThreadMap +> +struct DefaultIteratorsTensorOp< + half_t, + int32_t, + 8, + ThreadblockShape, + WarpShape, + InstructionShape, + ThreadMap> { + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed< + WarpShape, + InstructionShape, + int32_t, + 32, + 16, + 8, + 8 + >; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< + ThreadMap, + int32_t, + 32, + 16, + 8, + 8 + >; + + static int const kFragmentsPerIteration = 2; +}; + /// Partial specialization for int8/int4b_t <= int32 x 16/8 epilogues avoids shared memory bank conflicts. /// Threadblock::kN = 256 still has bank conflicts. template < diff --git a/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h b/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h index a4711377..ff7f659c 100644 --- a/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h +++ b/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h @@ -70,8 +70,9 @@ template < int ElementSizeBits_, ///< Size of accumulator in bits int OutputSizeBits_, ///< Size of output element in bits int ElementsPerAccess, ///< Vector length of output vector - int ContiguousLanes ///< Number of lanes in the warp writing to contiguous elements + int ContiguousLanes, ///< Number of lanes in the warp writing to contiguous elements /// in the global memory tensor + bool EightBitsOutputOrLess = (OutputSizeBits_ <= 8) > class SharedLoadIteratorMixed; @@ -85,7 +86,7 @@ template < typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) typename Element_ ///< Accumulator data type > -class SharedLoadIteratorMixed { +class SharedLoadIteratorMixed { public: using ThreadMap = ThreadMap_; using Shape = typename ThreadMap::Shape; @@ -253,7 +254,7 @@ template < typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) int OutputSizeBits_ ///< Size of output element in bits > -class SharedLoadIteratorMixed { +class SharedLoadIteratorMixed { public: using ThreadMap = ThreadMap_; using Shape = typename ThreadMap::Shape; @@ -418,7 +419,7 @@ template < typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) int OutputSizeBits_ > -class SharedLoadIteratorMixed { +class SharedLoadIteratorMixed { public: using ThreadMap = ThreadMap_; using Shape = typename ThreadMap::Shape; diff --git a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h index 3bbc942e..74911b81 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h +++ b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h @@ -64,7 +64,8 @@ template < int ElementSizeBits, ///< Size of accumulator element in bits int OutputSizeBits, ///< Size of output element in bits int OutputElementCount, ///< number of elements in output vector - int ContiguousLanes ///< Number of consecutive lanes writing to contiguous memory + int ContiguousLanes, ///< Number of consecutive lanes writing to contiguous memory + bool EightBitsOutputOrLess = (OutputSizeBits <= 8) > class TileIteratorTensorOpMixed { public: @@ -319,7 +320,7 @@ template < typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape), int OutputSizeBits ///< Size of output element in bits > -class TileIteratorTensorOpMixed { +class TileIteratorTensorOpMixed { public: using WarpShape = WarpShape_; @@ -526,7 +527,7 @@ template < typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) int OutputSizeBits ///< Size of output element in bits > -class TileIteratorTensorOpMixed { +class TileIteratorTensorOpMixed { public: using WarpShape = WarpShape_;