/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines structural properties of mixed-precision integer GEMM. Multiplicands are assumed
      to be packed 8bit integers, accumulators are assumed to be 32b signed integers, and output
      formats vary.
*/ #pragma once #include "cutlass/convert.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/gemm_epilogue.h" #include "cutlass/gemm/gemm_epilogue_traits.h" #include "cutlass/gemm/gemm_global_tile.h" #include "cutlass/gemm/gemm_shared_tile.h" #include "cutlass/gemm/gemm_traits.h" #include "cutlass/gemm/igemm_epilogue.h" #include "cutlass/gemm/igemm_global_tile.h" #include "cutlass/gemm/igemm_multiply_add.h" #include "cutlass/gemm/igemm_swizzle.h" #include "cutlass/reshape_tile.h" namespace cutlass { namespace gemm { //////////////////////////////////////////////////////////////////////////////////////////////////// template < /// The tile size for the GEMM KxNxM. typename OutputTile_, /// The output type. typename ScalarD_, /// Tile size for thread-level GEMM (K-by-N-by-M) typename ThreadGemmShape_> struct IgemmConfig : public GemmConfig< /// The scalar type for A. int8_t, /// The scalar type for B. int8_t, /// The scalar type for C. ScalarD_, /// The scalar type for D. ScalarD_, /// The tile size for the GEMM KxNxM. OutputTile_, /// The functor to do the math in the main loop. ThreadMultiplyAdd, int8_t, int8_t, int>, /// The number of scalars per LDG for A. 4, /// The number of scalars per STS for A. 4, /// The number of scalars per LDS for A. 16, /// The number of scalars per LDG for B. 4, /// The number of scalars per STS for B. 4, /// The number of scalars per LDS for B. 16, /// The number of scalars per LDG for C and STG for D. 1, /// The number of scalars per STS for D. 4, /// The number of scalars per LDS for D. 1, /// The number of stages in shared memory. 2, /// kResidueSeparate false, /// kResidueInPrologue false, /// kLaunchBounds false> {}; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct IgemmConfig : public GemmConfig< /// The scalar type for A. int8_t, /// The scalar type for B. int8_t, /// The scalar type for C. int8_t, /// The scalar type for D. 
int8_t, /// The tile size for the GEMM KxNxM. OutputTile_, /// The functor to do the math in the main loop. ThreadMultiplyAdd, int8_t, int8_t, int>, /// The number of scalars per LDG for A. 4, /// The number of scalars per STS for A. 4, /// The number of scalars per LDS for A. 16, /// The number of scalars per LDG for B. 4, /// The number of scalars per STS for B. 4, /// The number of scalars per LDS for B. 16, /// The number of scalars per LDG for C and STG for D. 4, /// The number of scalars per STS for D. 4, /// The number of scalars per LDS for D. 4, /// The number of stages in shared memory. 2, /// If true, separate mainloop is instantiated from residue false, /// Compute residue in prolog? true, /// Launch bounds? false> {}; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct IgemmTileTraitsHelperA : public GemmTileTraitsHelperA {}; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct IgemmTileTraitsHelperA : public GemmTileTraitsHelperA { /// The base config. typedef GemmTileTraitsHelperA Base; /// The number of scalars per LDG/STS/LDS for A. static int const kScalarsPerStsA = 16; /// The traits class to build the iterator to load data from global memory for A^N. typedef IgemmGlobalTileTraits< GemmOperand::kA, // The layout. MatrixLayout::kColumnMajor, // The pointer is float const. int8_t const, // The tile has size KxM in GEMM's terminology. Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW>, // The threads are distributed as warps x 32 (the traits may reorganize). Shape<1, ShapeCount::kCount, GemmConfig_::kWarpSize>, // The number of scalars per LDG (LDG.32 or LDG.128, etc). GemmConfig_::kScalarsPerLdgA> GlobalTileTraits; /// The global load iterator. typedef GemmGlobalIteratorAb GlobalLoadIterator; /// The traits class to build the iterator to store data to shared memory for A^N. 
typedef GemmSharedStoreTileAbTraits< // The pointer is float. int8_t, // The tile has size KxM in GEMM's terminology. Shape, // The threads are distributed as warps x 32 (the traits may reorganize). typename GlobalTileTraits::Threads, // The number of scalars per STS (STS.32 or STS.128, etc). kScalarsPerStsA> SharedStoreTileTraits; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct IgemmTileTraitsHelperA { /// The layout. static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor; /// The input scalar. typedef int8_t Scalar; /// The scalar stored in shared memory. typedef int8_t MultiplyAddScalar; /// The number of scalars per LDG/STS/LDS for A. static int const kScalarsPerStsA = 16; /// The traits class to build the iterator to load data from global memory for A^T. typedef IgemmGlobalTileTraits< GemmOperand::kA, // The layout. MatrixLayout::kRowMajor, // The pointer is float const. int8_t const, // The tile has size NxK in GEMM's terminology. Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>, // The threads are distributed as warps x 32 (the traits may reorganize). Shape<1, ShapeCount::kCount, GemmConfig_::kWarpSize>, // The number of scalars per LDG (LDG.32 or LDG.128, etc). GemmConfig_::kScalarsPerLdgA> GlobalTileTraits; /// The global load iterator. typedef IgemmGlobalIteratorAb GlobalLoadIterator; /// The traits class to build the iterator to store data to shared memory for A^N. typedef GemmSharedStoreWithSkewTileAbTraits< // The pointer is int8. int8_t, // The tile has size KxN in GEMM's terminology. Shape, // The threads are distributed as (threads / K) x K (the traits may reorganize). typename GlobalTileTraits::Threads, // The number of scalars per STS. kScalarsPerStsA, // The skew to avoid bank conflicts added in the tile W dimension. 16> SharedStoreTileTraits; /// The traits class to build the iterator to load from shared memory for A^N. 
typedef GemmSharedLoadTileATraits< // The pointer is float const. int8_t const, // The output tile size. typename GemmConfig_::OutputTile, // The number of warps. typename GemmConfig_::Warps, // The number of threads per warp. typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, // The shape of the FMA instruction. typename GemmConfig_::InstructionShape, // The number of stages. GemmConfig_::kStages, // The number of scalars per LDS. 16, // The skew. SharedStoreTileTraits::kSkew> SharedLoadTileTraits; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct IgemmTileTraitsHelperB : public GemmTileTraitsHelperB {}; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct IgemmTileTraitsHelperB { /// The layout. static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor; /// The input scalar. typedef int8_t Scalar; /// The scalar stored in shared memory. typedef int8_t MultiplyAddScalar; /// The number of scalars per LDG/STS/LDS for B. static int const kScalarsPerStsB = 16; /// The traits class to build the iterator to load data from global memory for B^T. typedef IgemmGlobalTileTraits< GemmOperand::kB, // The layout. MatrixLayout::kColumnMajor, // The pointer is float const. int8_t const, // The tile has size NxK in GEMM's terminology. Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>, // The threads are distributed as warps x 32 (the traits may reorganize). Shape<1, ShapeCount::kCount, GemmConfig_::kWarpSize>, // The number of scalars per LDG (LDG.32 or LDG.128, etc). GemmConfig_::kScalarsPerLdgB> GlobalTileTraits; /// The global load iterator. typedef IgemmGlobalIteratorAb GlobalLoadIterator; /// The traits class to build the iterator to store data to shared memory for B^N. typedef GemmSharedStoreWithSkewTileAbTraits< // The pointer is int8. int8_t, // The tile has size KxN in GEMM's terminology. 
Shape, // The threads are distributed as (threads / K) x K (the traits may reorganize). typename GlobalTileTraits::Threads, // The number of scalars per STS. kScalarsPerStsB, // The skew to avoid bank conflicts added in the tile W dimension. 16> SharedStoreTileTraits; /// The traits class to build the iterator to load from shared memory for B^N. typedef GemmSharedLoadTileBTraits< // The pointer is float const. int8_t const, // The output tile size. typename GemmConfig_::OutputTile, // The number of warps. typename GemmConfig_::Warps, // The number of threads per warp. typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, // The shape of the FMA instruction. typename GemmConfig_::InstructionShape, // The number of stages. GemmConfig_::kStages, // The number of scalars per LDS. 16, // The skew. SharedStoreTileTraits::kSkew> SharedLoadTileTraits; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct IgemmTileTraitsHelperB : public GemmTileTraitsHelperB { /// The base config. typedef GemmTileTraitsHelperB Base; /// The number of scalars per LDG/STS/LDS for B. static int const kScalarsPerStsB = 16; /// The traits class to build the iterator to load data from global memory for B^T. typedef IgemmGlobalTileTraits< GemmOperand::kB, // The layout. MatrixLayout::kRowMajor, // The pointer is float const. int8_t const, // The tile has size KxM in GEMM's terminology. Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH>, // The threads are distributed as warps x 32 (the traits may reorganize). Shape<1, ShapeCount::kCount, GemmConfig_::kWarpSize>, // The number of scalars per LDG (LDG.32 or LDG.128, etc). GemmConfig_::kScalarsPerLdgB> GlobalTileTraits; /// The global load iterator. typedef GemmGlobalIteratorAb GlobalLoadIterator; /// The traits class to build the iterator to store data to shared memory for B^N. typedef GemmSharedStoreTileAbTraits< // The pointer is float. 
int8_t, // The tile has size KxM in GEMM's terminology. Shape, // The threads are distributed as warps x 32 (the traits may reorganize). typename GlobalTileTraits::Threads, // The number of scalars per STS (STS.32 or STS.128, etc). kScalarsPerStsB> SharedStoreTileTraits; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct IgemmTransformerA {}; template struct IgemmTransformerA { typedef Copy Transformer; }; template struct IgemmTransformerA { typedef IgemmSwizzle Transformer; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct IgemmTransformerB {}; template struct IgemmTransformerB { typedef Copy Transformer; }; template struct IgemmTransformerB { typedef IgemmSwizzle Transformer; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template < /// The layout for A. MatrixLayout::Kind kLayoutA_, /// The layout for B. MatrixLayout::Kind kLayoutB_, /// The output tile. typename OutputTile_, /// The output type. typename ScalarD_, /// The functor to do the math in the epilogue. typename EpilogueFunctor_, /// Tile size for thread-level GEMM (K-by-N-by-M) typename ThreadGemmShape_ = Shape<32, 8, 8>, /// The index. typename Index_ = int> struct IgemmTraitsHelper { /// The IGEMM config. typedef IgemmConfig GemmConfig; /// The GEMM config for A. typedef IgemmTileTraitsHelperA GemmTileTraitsHelperA; /// The GEMM config for B. typedef IgemmTileTraitsHelperB GemmTileTraitsHelperB; /// The iterator to load A from global memory. typedef typename GemmTileTraitsHelperA::GlobalLoadIterator GlobalLoadIteratorA; /// The default transformer for A. typedef typename IgemmTransformerA::Transformer GlobalTransformerA; /// The iterator to store A to shared memory. typedef TileStoreIterator SharedStoreIteratorA; /// The stream to load A from global memory to shared memory. 
typedef GlobalLoadStream GlobalLoadStreamA; /// The iterator to load B from global memory. typedef typename GemmTileTraitsHelperB::GlobalLoadIterator GlobalLoadIteratorB; // The default transformer for B. typedef typename IgemmTransformerB::Transformer GlobalTransformerB; /// The iterator to store B to shared memory. typedef TileStoreIterator SharedStoreIteratorB; /// The stream to load B from global memory to shared memory. typedef GlobalLoadStream GlobalLoadStreamB; /// The iterator to load A from shared memory. typedef TileLoadIterator SharedLoadIteratorA; /// The stream to load A from shared memory. typedef SharedLoadStream > SharedLoadStreamA; /// The iterator to load B from shared memory. typedef TileLoadIterator SharedLoadIteratorB; /// The stream to load B from shared memory. typedef SharedLoadStream > SharedLoadStreamB; /// The multiply-add functor. typedef typename GemmConfig::MultiplyAdd MultiplyAdd; /// The object to clear accumulators. typedef ClearAccumulators ClearAccumulators; /// The epilogue. typedef IgemmEpilogue > Epilogue; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct IgemmEpilogueScalar { typedef float Scalar; }; template <> struct IgemmEpilogueScalar { typedef int Scalar; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template < /// The layout for A. MatrixLayout::Kind kLayoutA_, /// The layout for B. MatrixLayout::Kind kLayoutB_, /// The output tile. typename OutputTile_ = Shape<32, 128, 128>, /// The output type. typename ScalarD_ = int, /// The functor to do the math in the epilogue. typename EpilogueFunctor_ = LinearScaling::Scalar>, /// Tile size for thread-level GEMM (K-by-N-by-M) typename ThreadGemmShape_ = Shape<32, 8, 8>, /// The index. typename Index_ = int, /// The helper class. typename Helper_ = IgemmTraitsHelper > struct IgemmTraits : public GemmTraits< // The config. 
typename Helper_::GemmConfig, // The stream to load A from global memory to shared memory. typename Helper_::GlobalLoadStreamA, // The stream to load B from global memory to shared memory. typename Helper_::GlobalLoadStreamB, // The stream to load A from shared memory. typename Helper_::SharedLoadStreamA, // The stream to load B from shared memory. typename Helper_::SharedLoadStreamB, // The epilogue. typename Helper_::Epilogue, // The block swizzle to reorganize the grid. IdentityBlockSwizzle, // The index. Index_, // The tool used to clear accumulators. typename Helper_::ClearAccumulators> {}; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace gemm } // namespace cutlass