159 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
		
		
			
		
	
	
			159 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
/*! \file
    \brief Cutlass provides helper template functions to figure out the right
   data structures to instantiate to run a GEMM with various parameters (see
   `cutlass/gemm/threadblock/default_mma.h`). However, due to template
   instantiation priority rules, it will only create an MmaMultistage with
   kStages=3 (otherwise it creates an MmaPipelined - which is not compatible
   with FastF32). kStages=3 uses too much shared memory and we want to use
   kStages=2, so we just copy-pasted some code from the `default_mma.h` and
   `default_mma_core.h` files and wrapped this template to allow our use case.

    This is really only for the FastF32 case - aka using TensorCores with fp32.
*/
|  | 
 | ||
|  | #include "cutlass/gemm/threadblock/default_mma.h"
 | ||
|  | #include "cutlass/gemm/threadblock/default_mma_core_simt.h"
 | ||
|  | #include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
 | ||
|  | #include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
 | ||
|  | #include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
 | ||
|  | 
 | ||
|  | namespace cutlass { | ||
|  | namespace gemm { | ||
|  | namespace threadblock { | ||
|  | 
 | ||
|  | template < | ||
|  |     /// Element type for A matrix operand
 | ||
|  |     typename ElementA, | ||
|  |     /// Layout type for A matrix operand
 | ||
|  |     typename LayoutA, | ||
|  |     /// Access granularity of A matrix in units of elements
 | ||
|  |     int kAlignmentA, | ||
|  |     /// Element type for B matrix operand
 | ||
|  |     typename ElementB, | ||
|  |     /// Layout type for B matrix operand
 | ||
|  |     typename LayoutB, | ||
|  |     /// Access granularity of B matrix in units of elements
 | ||
|  |     int kAlignmentB, | ||
|  |     /// Element type for internal accumulation
 | ||
|  |     typename ElementAccumulator, | ||
|  |     /// Layout type for C and D matrix operand
 | ||
|  |     typename LayoutC, | ||
|  |     /// Operator class tag
 | ||
|  |     typename OperatorClass, | ||
|  |     /// Tag indicating architecture to tune for
 | ||
|  |     typename ArchTag, | ||
|  |     /// Threadblock-level tile size (concept: GemmShape)
 | ||
|  |     typename ThreadblockShape, | ||
|  |     /// Warp-level tile size (concept: GemmShape)
 | ||
|  |     typename WarpShape, | ||
|  |     /// Instruction-level tile size (concept: GemmShape)
 | ||
|  |     typename InstructionShape, | ||
|  |     /// Number of stages used in the pipelined mainloop
 | ||
|  |     int Stages, | ||
|  |     /// Operation perfomed by GEMM
 | ||
|  |     typename Operator, | ||
|  |     typename Enable_ = void> | ||
|  | struct FindDefaultMma { | ||
|  |   static constexpr bool AccumulatorsInRowMajor = false; | ||
|  |   static constexpr SharedMemoryClearOption SharedMemoryClear = | ||
|  |       SharedMemoryClearOption::kNone; | ||
|  |   using DefaultMma = cutlass::gemm::threadblock::DefaultMma< | ||
|  |       ElementA, | ||
|  |       LayoutA, | ||
|  |       kAlignmentA, | ||
|  |       ElementB, | ||
|  |       LayoutB, | ||
|  |       kAlignmentB, | ||
|  |       ElementAccumulator, | ||
|  |       LayoutC, | ||
|  |       OperatorClass, | ||
|  |       ArchTag, | ||
|  |       ThreadblockShape, | ||
|  |       WarpShape, | ||
|  |       InstructionShape, | ||
|  |       Stages, | ||
|  |       Operator, | ||
|  |       AccumulatorsInRowMajor, | ||
|  |       SharedMemoryClear>; | ||
|  | }; | ||
|  | 
 | ||
|  | /// Specialization for sm80 / FastF32 / multistage with kStages=2
 | ||
|  | template < | ||
|  |     typename ElementA_, | ||
|  |     /// Layout type for A matrix operand
 | ||
|  |     typename LayoutA_, | ||
|  |     /// Access granularity of A matrix in units of elements
 | ||
|  |     int kAlignmentA, | ||
|  |     typename ElementB_, | ||
|  |     /// Layout type for B matrix operand
 | ||
|  |     typename LayoutB_, | ||
|  |     /// Access granularity of B matrix in units of elements
 | ||
|  |     int kAlignmentB, | ||
|  |     typename ElementAccumulator, | ||
|  |     /// Threadblock-level tile size (concept: GemmShape)
 | ||
|  |     typename ThreadblockShape, | ||
|  |     /// Warp-level tile size (concept: GemmShape)
 | ||
|  |     typename WarpShape, | ||
|  |     /// Instruction-level tile size (concept: GemmShape)
 | ||
|  |     typename InstructionShape, | ||
|  |     int kStages, | ||
|  |     typename Operator> | ||
|  | struct FindDefaultMma< | ||
|  |     ElementA_, | ||
|  |     LayoutA_, | ||
|  |     kAlignmentA, | ||
|  |     ElementB_, | ||
|  |     LayoutB_, | ||
|  |     kAlignmentB, | ||
|  |     ElementAccumulator, | ||
|  |     layout::RowMajor, | ||
|  |     arch::OpClassTensorOp, | ||
|  |     arch::Sm80, | ||
|  |     ThreadblockShape, | ||
|  |     WarpShape, | ||
|  |     InstructionShape, | ||
|  |     kStages, | ||
|  |     Operator, | ||
|  |     typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> { | ||
|  |   using LayoutC = layout::RowMajor; | ||
|  |   using OperatorClass = arch::OpClassTensorOp; | ||
|  |   using ArchTag = arch::Sm80; | ||
|  | 
 | ||
|  |   using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma< | ||
|  |       ElementA_, | ||
|  |       LayoutA_, | ||
|  |       kAlignmentA, | ||
|  |       ElementB_, | ||
|  |       LayoutB_, | ||
|  |       kAlignmentB, | ||
|  |       ElementAccumulator, | ||
|  |       LayoutC, | ||
|  |       OperatorClass, | ||
|  |       ArchTag, | ||
|  |       ThreadblockShape, | ||
|  |       WarpShape, | ||
|  |       InstructionShape, | ||
|  |       3, | ||
|  |       Operator>; | ||
|  |   struct DefaultMma : DefaultMma_ { | ||
|  |     using MmaCore_ = typename DefaultMma_::MmaCore; | ||
|  |     // Define the threadblock-scoped multistage matrix multiply
 | ||
|  |     using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< | ||
|  |         typename MmaCore_::Shape, | ||
|  |         typename DefaultMma_::IteratorA, | ||
|  |         typename MmaCore_::SmemIteratorA, | ||
|  |         MmaCore_::kCacheOpA, | ||
|  |         typename DefaultMma_::IteratorB, | ||
|  |         typename MmaCore_::SmemIteratorB, | ||
|  |         MmaCore_::kCacheOpB, | ||
|  |         ElementAccumulator, | ||
|  |         LayoutC, | ||
|  |         typename MmaCore_::MmaPolicy, | ||
|  |         kStages>; | ||
|  |   }; | ||
|  | }; | ||
|  | 
 | ||
|  | } // namespace threadblock
 | ||
|  | } // namespace gemm
 | ||
|  | } // namespace cutlass
 |