| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | /***************************************************************************************************
 | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							|  |  |  |  * SPDX-License-Identifier: BSD-3-Clause | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |  * | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  |  * modification, are permitted provided that the following conditions are met: | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |  * | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  |  * list of conditions and the following disclaimer. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  |  * this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  |  * and/or other materials provided with the distribution. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 3. Neither the name of the copyright holder nor the names of its | 
					
						
							|  |  |  |  * contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  |  * this software without specific prior written permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | /*! \file
 | 
					
						
							|  |  |  |     \brief Template for GEMM performing a reduction over K partitions in parallel. | 
					
						
							|  |  |  | */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/cutlass.h"
 | 
					
						
							|  |  |  | #include "cutlass/numeric_types.h"
 | 
					
						
							|  |  |  | #include "cutlass/arch/arch.h"
 | 
					
						
							|  |  |  | #include "cutlass/device_kernel.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/kernel/gemm.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/gemm/kernel/default_gemm_splitk_parallel.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/device/default_gemm_configuration.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/epilogue/thread/conversion_op.h"
 | 
					
						
							|  |  |  | #include "cutlass/reduction/kernel/reduce_split_k.h"
 | 
					
						
							|  |  |  | #include "cutlass/reduction/thread/reduction_operators.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace cutlass { | 
					
						
							|  |  |  | namespace gemm { | 
					
						
							|  |  |  | namespace device { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /*! 
 | 
					
						
							|  |  |  |   Gemm device-level operator performing parallel reduction over the K partition. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | */ | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |     /// Element type for A matrix operand
 | 
					
						
							|  |  |  |     typename ElementA_, | 
					
						
							|  |  |  |     /// Layout type for A matrix operand
 | 
					
						
							|  |  |  |     typename LayoutA_, | 
					
						
							|  |  |  |     /// Element type for B matrix operand
 | 
					
						
							|  |  |  |     typename ElementB_, | 
					
						
							|  |  |  |     /// Layout type for B matrix operand
 | 
					
						
							|  |  |  |     typename LayoutB_, | 
					
						
							|  |  |  |     /// Element type for C and D matrix operands
 | 
					
						
							|  |  |  |     typename ElementC_, | 
					
						
							|  |  |  |     /// Layout type for C and D matrix operands
 | 
					
						
							|  |  |  |     typename LayoutC_, | 
					
						
							|  |  |  |     /// Element type for internal accumulation
 | 
					
						
							|  |  |  |     typename ElementAccumulator_ = ElementC_, | 
					
						
							|  |  |  |     /// Operator class tag
 | 
					
						
							|  |  |  |     typename OperatorClass_ = arch::OpClassSimt, | 
					
						
							| 
									
										
										
										
											2021-02-26 22:58:26 +08:00
										 |  |  |     /// Tag indicating architecture to tune for.  This is the minimum SM that
 | 
					
						
							|  |  |  |       /// supports the intended feature. The device kernel can be built
 | 
					
						
							|  |  |  |       /// targeting any SM larger than this number.
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |     typename ArchTag_ = arch::Sm70, | 
					
						
							|  |  |  |     /// Threadblock-level tile size (concept: GemmShape)
 | 
					
						
							|  |  |  |     typename ThreadblockShape_ = typename DefaultGemmConfiguration< | 
					
						
							|  |  |  |         OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, | 
					
						
							|  |  |  |         ElementAccumulator_>::ThreadblockShape, | 
					
						
							|  |  |  |     /// Warp-level tile size (concept: GemmShape)
 | 
					
						
							|  |  |  |     typename WarpShape_ = typename DefaultGemmConfiguration< | 
					
						
							|  |  |  |         OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, | 
					
						
							|  |  |  |         ElementAccumulator_>::WarpShape, | 
					
						
							|  |  |  |     /// Instruction-level tile size (concept: GemmShape)
 | 
					
						
							|  |  |  |     typename InstructionShape_ = typename DefaultGemmConfiguration< | 
					
						
							|  |  |  |         OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, | 
					
						
							|  |  |  |         ElementAccumulator_>::InstructionShape, | 
					
						
							|  |  |  |     /// Epilogue output operator
 | 
					
						
							|  |  |  |     typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< | 
					
						
							|  |  |  |         OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, | 
					
						
							|  |  |  |         ElementAccumulator_>::EpilogueOutputOp, | 
					
						
							|  |  |  |     /// Epilogue output operator
 | 
					
						
							|  |  |  |     typename ConvertScaledOp_ = cutlass::epilogue::thread::Convert< | 
					
						
							|  |  |  |         ElementAccumulator_, | 
					
						
							|  |  |  |         DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_, | 
					
						
							|  |  |  |                                  ElementAccumulator_, | 
					
						
							|  |  |  |                                  ElementAccumulator_>::EpilogueOutputOp::kCount, | 
					
						
							|  |  |  |         ElementAccumulator_>, | 
					
						
							|  |  |  |     /// Reduction operator
 | 
					
						
							|  |  |  |     typename ReductionOp_ = cutlass::reduction::thread::ReduceAdd< | 
					
						
							|  |  |  |         ElementAccumulator_, typename EpilogueOutputOp_::ElementAccumulator, | 
					
						
							|  |  |  |         EpilogueOutputOp_::kCount>, | 
					
						
							|  |  |  |     /// Threadblock-level swizzling operator
 | 
					
						
							|  |  |  |     typename ThreadblockSwizzle_ = | 
					
						
							|  |  |  |         threadblock::GemmSplitKHorizontalThreadblockSwizzle, | 
					
						
							|  |  |  |     /// Number of stages used in the pipelined mainloop
 | 
					
						
							|  |  |  |     int Stages = | 
					
						
							|  |  |  |         DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_, | 
					
						
							|  |  |  |                                  ElementC_, ElementAccumulator_>::kStages, | 
					
						
							|  |  |  |     /// Access granularity of A matrix in units of elements
 | 
					
						
							|  |  |  |     int kAlignmentA = | 
					
						
							|  |  |  |         DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_, | 
					
						
							|  |  |  |                                  ElementC_, ElementAccumulator_>::kAlignmentA, | 
					
						
							|  |  |  |     /// Access granularity of B matrix in units of elements
 | 
					
						
							|  |  |  |     int kAlignmentB = | 
					
						
							|  |  |  |         DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_, | 
					
						
							|  |  |  |                                  ElementC_, ElementAccumulator_>::kAlignmentB, | 
					
						
							|  |  |  |     /// Operation performed by GEMM
 | 
					
						
							|  |  |  |     typename Operator_ = typename DefaultGemmConfiguration< | 
					
						
							|  |  |  |         OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, | 
					
						
							|  |  |  |         ElementAccumulator_>::Operator> | 
					
						
							|  |  |  | class GemmSplitKParallel { | 
					
						
							|  |  |  |  public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementA = ElementA_; | 
					
						
							|  |  |  |   using LayoutA = LayoutA_; | 
					
						
							|  |  |  |   using ElementB = ElementB_; | 
					
						
							|  |  |  |   using LayoutB = LayoutB_; | 
					
						
							|  |  |  |   using ElementC = ElementC_; | 
					
						
							|  |  |  |   using LayoutC = LayoutC_; | 
					
						
							|  |  |  |   using ElementAccumulator = ElementAccumulator_; | 
					
						
							|  |  |  |   using OperatorClass = OperatorClass_; | 
					
						
							|  |  |  |   using ArchTag = ArchTag_; | 
					
						
							|  |  |  |   using ThreadblockShape = ThreadblockShape_; | 
					
						
							|  |  |  |   using WarpShape = WarpShape_; | 
					
						
							|  |  |  |   using InstructionShape = InstructionShape_; | 
					
						
							|  |  |  |   using ConvertScaledOp = ConvertScaledOp_; | 
					
						
							|  |  |  |   using EpilogueOutputOp = EpilogueOutputOp_; | 
					
						
							|  |  |  |   using ReductionOp = ReductionOp_; | 
					
						
							|  |  |  |   using ThreadblockSwizzle = ThreadblockSwizzle_; | 
					
						
							|  |  |  |   using Operator = Operator_; | 
					
						
							|  |  |  |   static int const kStages = Stages; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// GEMM kernel 
 | 
					
						
							|  |  |  |   using GemmKernel = typename kernel::DefaultGemmSplitKParallel< | 
					
						
							|  |  |  |     ElementA, | 
					
						
							|  |  |  |     LayoutA, | 
					
						
							|  |  |  |     kAlignmentA, | 
					
						
							|  |  |  |     ElementB, | 
					
						
							|  |  |  |     LayoutB, | 
					
						
							|  |  |  |     kAlignmentB, | 
					
						
							|  |  |  |     ElementAccumulator, | 
					
						
							|  |  |  |     LayoutC, | 
					
						
							|  |  |  |     ElementAccumulator, | 
					
						
							|  |  |  |     OperatorClass, | 
					
						
							|  |  |  |     ArchTag, | 
					
						
							|  |  |  |     ThreadblockShape, | 
					
						
							|  |  |  |     WarpShape, | 
					
						
							|  |  |  |     InstructionShape, | 
					
						
							|  |  |  |     ConvertScaledOp, | 
					
						
							|  |  |  |     ThreadblockSwizzle, | 
					
						
							|  |  |  |     kStages, | 
					
						
							|  |  |  |     Operator | 
					
						
							|  |  |  |   >::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Reduction kernel
 | 
					
						
							|  |  |  |   using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< | 
					
						
							|  |  |  |     cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, | 
					
						
							|  |  |  |     EpilogueOutputOp, | 
					
						
							|  |  |  |     ReductionOp | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Argument structure
 | 
					
						
							|  |  |  |   struct Arguments { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // Data members
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     GemmCoord problem_size; | 
					
						
							|  |  |  |     TensorRef<ElementA const, LayoutA> ref_A; | 
					
						
							|  |  |  |     TensorRef<ElementB const, LayoutB> ref_B; | 
					
						
							|  |  |  |     TensorRef<ElementC const, LayoutC> ref_C; | 
					
						
							|  |  |  |     TensorRef<ElementC, LayoutC> ref_D; | 
					
						
							|  |  |  |     typename EpilogueOutputOp::Params epilogue; | 
					
						
							|  |  |  |     int split_k_slices; | 
					
						
							|  |  |  |     typename ConvertScaledOp::Params convert; | 
					
						
							|  |  |  |     typename ReductionOp::Params reduction; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // Methods
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     /// Default ctor
 | 
					
						
							|  |  |  |     CUTLASS_HOST_DEVICE | 
					
						
							|  |  |  |     Arguments() { } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     /// Constructs an Arguments structure 
 | 
					
						
							|  |  |  |     CUTLASS_HOST_DEVICE | 
					
						
							|  |  |  |     Arguments( | 
					
						
							|  |  |  |       GemmCoord problem_size_, | 
					
						
							|  |  |  |       TensorRef<ElementA const, LayoutA> ref_A_, | 
					
						
							|  |  |  |       TensorRef<ElementB const, LayoutB> ref_B_, | 
					
						
							|  |  |  |       TensorRef<ElementC const, LayoutC> ref_C_, | 
					
						
							|  |  |  |       TensorRef<ElementC, LayoutC> ref_D_, | 
					
						
							|  |  |  |       typename EpilogueOutputOp::Params epilogue_ =  | 
					
						
							|  |  |  |         typename EpilogueOutputOp::Params(), | 
					
						
							|  |  |  |       int split_k_slices = 1, | 
					
						
							|  |  |  |       typename ConvertScaledOp::Params convert_ =  | 
					
						
							|  |  |  |         typename ConvertScaledOp::Params(), | 
					
						
							|  |  |  |       typename ReductionOp::Params reduction_ = | 
					
						
							|  |  |  |         typename ReductionOp::Params() | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |       problem_size(problem_size_), | 
					
						
							|  |  |  |       ref_A(ref_A_), | 
					
						
							|  |  |  |       ref_B(ref_B_), | 
					
						
							|  |  |  |       ref_C(ref_C_), | 
					
						
							|  |  |  |       ref_D(ref_D_), | 
					
						
							|  |  |  |       epilogue(epilogue_), | 
					
						
							|  |  |  |       split_k_slices(split_k_slices), | 
					
						
							|  |  |  |       convert(convert_), | 
					
						
							|  |  |  |       reduction(reduction_) { } | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | private: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Kernel parameters object
 | 
					
						
							|  |  |  |   typename GemmKernel::Params gemm_params_; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Reduction kernel parameters object
 | 
					
						
							|  |  |  |   typename ReductionKernel::Params reduction_params_; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Constructs the GEMM.
 | 
					
						
							|  |  |  |   GemmSplitKParallel() { } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Determines whether the GEMM can execute the given problem.
 | 
					
						
							|  |  |  |   static Status can_implement(Arguments const &args) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // TODO
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return Status::kSuccess; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Gets the workspace size
 | 
					
						
							|  |  |  |   static size_t get_workspace_size(Arguments const &args) { | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  |     // Determine grid shape
 | 
					
						
							|  |  |  |     ThreadblockSwizzle threadblock_swizzle; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( | 
					
						
							|  |  |  |       args.problem_size,  | 
					
						
							|  |  |  |       {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, | 
					
						
							|  |  |  |       args.split_k_slices); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return sizeof(ElementAccumulator_) * size_t(args.problem_size.m()) * size_t(args.problem_size.n()) * grid_shape.k(); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Initializes GEMM state from arguments.
 | 
					
						
							|  |  |  |   Status initialize(Arguments const &args, void *workspace) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Determine grid shape
 | 
					
						
							|  |  |  |     ThreadblockSwizzle threadblock_swizzle; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( | 
					
						
							|  |  |  |       args.problem_size,  | 
					
						
							|  |  |  |       {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, | 
					
						
							|  |  |  |       args.split_k_slices); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Define a reference to the workspace - this is an aligned region in device memory.
 | 
					
						
							|  |  |  |     if (!workspace) { | 
					
						
							|  |  |  |       return Status::kErrorWorkspaceNull; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  |     TensorRef<ElementAccumulator_, layout::RowMajor> ref_workspace( | 
					
						
							|  |  |  |       static_cast<ElementAccumulator_ *>(workspace),  | 
					
						
							|  |  |  |       args.problem_size.n()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t partition_stride = int64_t(args.problem_size.m()) * int64_t(args.problem_size.n()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Initialize the Params structure
 | 
					
						
							|  |  |  |     gemm_params_ = typename GemmKernel::Params{ | 
					
						
							|  |  |  |       args.problem_size, | 
					
						
							|  |  |  |       grid_shape, | 
					
						
							|  |  |  |       args.ref_A.non_const_ref(), | 
					
						
							|  |  |  |       args.ref_B.non_const_ref(), | 
					
						
							|  |  |  |       ref_workspace, | 
					
						
							|  |  |  |       args.convert, | 
					
						
							|  |  |  |       partition_stride | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     reduction_params_ = typename ReductionKernel::Params( | 
					
						
							|  |  |  |       args.problem_size.mn(), | 
					
						
							|  |  |  |       grid_shape.k(), | 
					
						
							|  |  |  |       partition_stride, | 
					
						
							|  |  |  |       ref_workspace, | 
					
						
							|  |  |  |       args.ref_D, | 
					
						
							|  |  |  |       args.ref_C.non_const_ref(), | 
					
						
							|  |  |  |       args.epilogue | 
					
						
							|  |  |  |     ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return Status::kSuccess; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Lightweight update given a subset of arguments
 | 
					
						
							|  |  |  |   Status update(Arguments const &args, void *workspace = nullptr) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (!workspace) { | 
					
						
							|  |  |  |       return Status::kErrorWorkspaceNull; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     gemm_params_.ref_A.reset(args.ref_A.data()); | 
					
						
							|  |  |  |     gemm_params_.ref_B.reset(args.ref_B.data()); | 
					
						
							|  |  |  |     gemm_params_.ref_D.reset(workspace);      | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     reduction_params_.ref_D.reset(args.ref_D.data()); | 
					
						
							|  |  |  |     reduction_params_.ref_C.reset(args.ref_C.data()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return Status::kSuccess; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Runs the kernel using initialized state.
 | 
					
						
							|  |  |  |   Status run(cudaStream_t stream = nullptr) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // Launch GEMM kernel
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ThreadblockSwizzle threadblock_swizzle; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     dim3 grid = threadblock_swizzle.get_grid_shape(gemm_params_.grid_tiled_shape); | 
					
						
							|  |  |  |     dim3 block(GemmKernel::kThreadCount, 1, 1); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     cudaError_t result; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); | 
					
						
							|  |  |  |     if (smem_size >= (48 << 10)) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       result = cudaFuncSetAttribute( | 
					
						
							|  |  |  |         Kernel<GemmKernel>, | 
					
						
							|  |  |  |         cudaFuncAttributeMaxDynamicSharedMemorySize, | 
					
						
							|  |  |  |         smem_size); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if (result != cudaSuccess) { | 
					
						
							|  |  |  |         return Status::kErrorInternal; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(gemm_params_); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     result = cudaGetLastError(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (result != cudaSuccess) { | 
					
						
							|  |  |  |       return Status::kErrorInternal; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // Launch reduction kernel
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     block = ReductionKernel::block_shape(); | 
					
						
							|  |  |  |     grid = ReductionKernel::grid_shape(gemm_params_.problem_size.mn()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Kernel<ReductionKernel><<< grid, block, 0, stream >>>(reduction_params_); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     result = cudaGetLastError(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (result != cudaSuccess) { | 
					
						
							|  |  |  |       return Status::kErrorInternal; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Runs the kernel using initialized state.
 | 
					
						
							|  |  |  |   Status operator()(cudaStream_t stream = nullptr) { | 
					
						
							|  |  |  |     return run(stream); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Runs the kernel using initialized state.
 | 
					
						
							|  |  |  |   Status operator()( | 
					
						
							|  |  |  |     Arguments const &args,  | 
					
						
							|  |  |  |     void *workspace = nullptr,  | 
					
						
							|  |  |  |     cudaStream_t stream = nullptr) { | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  |     Status status = initialize(args, workspace); | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  |     if (status == Status::kSuccess) { | 
					
						
							|  |  |  |       status = run(stream); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return status; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Partial specialization for column-major output
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |     /// Element type for A matrix operand
 | 
					
						
							|  |  |  |     typename ElementA_, | 
					
						
							|  |  |  |     /// Layout type for A matrix operand
 | 
					
						
							|  |  |  |     typename LayoutA_, | 
					
						
							|  |  |  |     /// Element type for B matrix operand
 | 
					
						
							|  |  |  |     typename ElementB_, | 
					
						
							|  |  |  |     /// Layout type for B matrix operand
 | 
					
						
							|  |  |  |     typename LayoutB_, | 
					
						
							|  |  |  |     /// Element type for C and D matrix operands
 | 
					
						
							|  |  |  |     typename ElementC_, | 
					
						
							|  |  |  |     /// Element type for internal accumulation
 | 
					
						
							|  |  |  |     typename ElementAccumulator_, | 
					
						
							|  |  |  |     /// Operator class tag
 | 
					
						
							|  |  |  |     typename OperatorClass_, | 
					
						
							| 
									
										
										
										
											2021-02-26 22:58:26 +08:00
										 |  |  |     /// Tag indicating architecture to tune for.  This is the minimum SM that
 | 
					
						
							|  |  |  |       /// supports the intended feature. The device kernel can be built
 | 
					
						
							|  |  |  |       /// targeting any SM larger than this number.
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |     typename ArchTag_, | 
					
						
							|  |  |  |     /// Threadblock-level tile size (concept: GemmShape)
 | 
					
						
							|  |  |  |     typename ThreadblockShape_, | 
					
						
							|  |  |  |     /// Warp-level tile size (concept: GemmShape)
 | 
					
						
							|  |  |  |     typename WarpShape_, | 
					
						
							|  |  |  |     /// Instruction-level tile size (concept: GemmShape)
 | 
					
						
							|  |  |  |     typename InstructionShape_, | 
					
						
							|  |  |  |     /// Epilogue output operator
 | 
					
						
							|  |  |  |     typename EpilogueOutputOp_, | 
					
						
							|  |  |  |     /// Epilogue output operator
 | 
					
						
							|  |  |  |     typename ConvertScaledOp_, | 
					
						
							|  |  |  |     /// Reduction operator
 | 
					
						
							|  |  |  |     typename ReductionOp_, | 
					
						
							|  |  |  |     /// Threadblock-level swizzling operator
 | 
					
						
							|  |  |  |     typename ThreadblockSwizzle_, | 
					
						
							|  |  |  |     /// Number of stages used in the pipelined mainloop
 | 
					
						
							|  |  |  |     int Stages, int kAlignmentA, int kAlignmentB, | 
					
						
							|  |  |  |     /// Operation performed by GEMM
 | 
					
						
							|  |  |  |     typename Operator_> | 
					
						
							|  |  |  | class GemmSplitKParallel<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, | 
					
						
							|  |  |  |                          layout::ColumnMajor, ElementAccumulator_, | 
					
						
							|  |  |  |                          OperatorClass_, ArchTag_, ThreadblockShape_, | 
					
						
							|  |  |  |                          WarpShape_, InstructionShape_, EpilogueOutputOp_, | 
					
						
							|  |  |  |                          ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, | 
					
						
							|  |  |  |                          Stages, kAlignmentA, kAlignmentB, Operator_> { | 
					
						
							|  |  |  |  public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementA = ElementA_; | 
					
						
							|  |  |  |   using LayoutA = LayoutA_; | 
					
						
							|  |  |  |   using ElementB = ElementB_; | 
					
						
							|  |  |  |   using LayoutB = LayoutB_; | 
					
						
							|  |  |  |   using ElementC = ElementC_; | 
					
						
							|  |  |  |   using LayoutC = layout::ColumnMajor; | 
					
						
							|  |  |  |   using ElementAccumulator = ElementAccumulator_; | 
					
						
							|  |  |  |   using OperatorClass = OperatorClass_; | 
					
						
							|  |  |  |   using ArchTag = ArchTag_; | 
					
						
							|  |  |  |   using ThreadblockShape = ThreadblockShape_; | 
					
						
							|  |  |  |   using WarpShape = WarpShape_; | 
					
						
							|  |  |  |   using InstructionShape = InstructionShape_; | 
					
						
							|  |  |  |   using ConvertScaledOp = ConvertScaledOp_; | 
					
						
							|  |  |  |   using EpilogueOutputOp = EpilogueOutputOp_; | 
					
						
							|  |  |  |   using ReductionOp = ReductionOp_; | 
					
						
							|  |  |  |   using ThreadblockSwizzle = ThreadblockSwizzle_; | 
					
						
							|  |  |  |   using Operator = Operator_; | 
					
						
							|  |  |  |   static int const kStages = Stages; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using UnderlyingOperator = GemmSplitKParallel<  | 
					
						
							|  |  |  |     ElementB, | 
					
						
							|  |  |  |     typename layout::LayoutTranspose<LayoutB>::type, | 
					
						
							|  |  |  |     ElementA, | 
					
						
							|  |  |  |     typename layout::LayoutTranspose<LayoutA>::type, | 
					
						
							|  |  |  |     ElementC, | 
					
						
							|  |  |  |     layout::RowMajor,     | 
					
						
							|  |  |  |     ElementAccumulator, | 
					
						
							|  |  |  |     OperatorClass, | 
					
						
							|  |  |  |     ArchTag, | 
					
						
							|  |  |  |     ThreadblockShape, | 
					
						
							|  |  |  |     WarpShape, | 
					
						
							|  |  |  |     InstructionShape, | 
					
						
							|  |  |  |     EpilogueOutputOp, | 
					
						
							|  |  |  |     ConvertScaledOp, | 
					
						
							|  |  |  |     ReductionOp, | 
					
						
							|  |  |  |     ThreadblockSwizzle, | 
					
						
							|  |  |  |     Stages, | 
					
						
							|  |  |  |     kAlignmentA, | 
					
						
							|  |  |  |     kAlignmentB, | 
					
						
							|  |  |  |     Operator | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using UnderlyingArguments = typename UnderlyingOperator::Arguments; | 
					
						
							|  |  |  |   using GemmKernel = typename UnderlyingOperator::GemmKernel; | 
					
						
							|  |  |  |   using ReductionKernel = typename UnderlyingOperator::ReductionKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Argument structure
 | 
					
						
							|  |  |  |   struct Arguments { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // Data members
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     GemmCoord problem_size; | 
					
						
							|  |  |  |     TensorRef<ElementA const, LayoutA> ref_A; | 
					
						
							|  |  |  |     TensorRef<ElementB const, LayoutB> ref_B; | 
					
						
							|  |  |  |     TensorRef<ElementC const, LayoutC> ref_C; | 
					
						
							|  |  |  |     TensorRef<ElementC, LayoutC> ref_D; | 
					
						
							|  |  |  |     typename EpilogueOutputOp::Params epilogue; | 
					
						
							|  |  |  |     int split_k_slices; | 
					
						
							|  |  |  |     typename ConvertScaledOp::Params convert; | 
					
						
							|  |  |  |     typename ReductionOp::Params reduction; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // Methods
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     /// Default ctor
 | 
					
						
							|  |  |  |     CUTLASS_HOST_DEVICE | 
					
						
							|  |  |  |     Arguments() { } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     /// Constructs an Arguments structure 
 | 
					
						
							|  |  |  |     CUTLASS_HOST_DEVICE | 
					
						
							|  |  |  |     Arguments( | 
					
						
							|  |  |  |       GemmCoord problem_size_, | 
					
						
							|  |  |  |       TensorRef<ElementA const, LayoutA> ref_A_, | 
					
						
							|  |  |  |       TensorRef<ElementB const, LayoutB> ref_B_, | 
					
						
							|  |  |  |       TensorRef<ElementC const, LayoutC> ref_C_, | 
					
						
							|  |  |  |       TensorRef<ElementC, LayoutC> ref_D_, | 
					
						
							|  |  |  |       typename EpilogueOutputOp::Params epilogue_ =  | 
					
						
							|  |  |  |         typename EpilogueOutputOp::Params(), | 
					
						
							|  |  |  |       int split_k_slices = 1, | 
					
						
							|  |  |  |       typename ConvertScaledOp::Params convert_ =  | 
					
						
							|  |  |  |         typename ConvertScaledOp::Params(), | 
					
						
							|  |  |  |       typename ReductionOp::Params reduction_ = | 
					
						
							|  |  |  |         typename ReductionOp::Params() | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |       problem_size(problem_size_), | 
					
						
							|  |  |  |       ref_A(ref_A_), | 
					
						
							|  |  |  |       ref_B(ref_B_), | 
					
						
							|  |  |  |       ref_C(ref_C_), | 
					
						
							|  |  |  |       ref_D(ref_D_), | 
					
						
							|  |  |  |       epilogue(epilogue_), | 
					
						
							|  |  |  |       split_k_slices(split_k_slices), | 
					
						
							|  |  |  |       convert(convert_), | 
					
						
							|  |  |  |       reduction(reduction_) { } | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
private:

  /// Instance of the transposed (row-major) operator; initialize(), update()
  /// and run() forward to it.
  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM. The underlying operator is default-constructed.
  GemmSplitKParallel() { }
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Helper to construct a transposed equivalent for the underying GEMM operator
 | 
					
						
							|  |  |  |   static UnderlyingArguments to_underlying_arguments(Arguments const &args) { | 
					
						
							|  |  |  |     return UnderlyingArguments( | 
					
						
							|  |  |  |       {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, | 
					
						
							|  |  |  |       {args.ref_B.data(), args.ref_B.stride(0)}, | 
					
						
							|  |  |  |       {args.ref_A.data(), args.ref_A.stride(0)}, | 
					
						
							|  |  |  |       {args.ref_C.data(), args.ref_C.stride(0)}, | 
					
						
							|  |  |  |       {args.ref_D.data(), args.ref_D.stride(0)}, | 
					
						
							|  |  |  |       args.epilogue, | 
					
						
							|  |  |  |       args.split_k_slices, | 
					
						
							|  |  |  |       args.convert, | 
					
						
							|  |  |  |       args.reduction | 
					
						
							|  |  |  |     ); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Determines whether the GEMM can execute the given problem.
 | 
					
						
							|  |  |  |   static Status can_implement(Arguments const &args) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return UnderlyingOperator::can_implement(to_underlying_arguments(args)); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Gets the workspace size
 | 
					
						
							|  |  |  |   static size_t get_workspace_size(Arguments const &args) { | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  |     return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Initializes GEMM state from arguments.
 | 
					
						
							|  |  |  |   Status initialize(Arguments const &args, void *workspace) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return underlying_operator_.initialize(to_underlying_arguments(args), workspace); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Lightweight update given a subset of arguments
 | 
					
						
							|  |  |  |   Status update(Arguments const &args, void *workspace = nullptr) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return underlying_operator_.update(to_underlying_arguments(args), workspace); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Runs the kernel using initialized state.
 | 
					
						
							|  |  |  |   Status run(cudaStream_t stream = nullptr) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return underlying_operator_.run(stream); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Runs the kernel using initialized state.
 | 
					
						
							|  |  |  |   Status operator()(cudaStream_t stream = nullptr) { | 
					
						
							|  |  |  |     return run(stream); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Runs the kernel using initialized state.
 | 
					
						
							|  |  |  |   Status operator()( | 
					
						
							|  |  |  |     Arguments const &args,  | 
					
						
							|  |  |  |     void *workspace = nullptr,  | 
					
						
							|  |  |  |     cudaStream_t stream = nullptr) { | 
					
						
							|  |  |  |      | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |     Status status = initialize(args, workspace, stream); | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |      | 
					
						
							|  |  |  |     if (status == Status::kSuccess) { | 
					
						
							|  |  |  |       status = run(stream); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return status; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace device
 | 
					
						
							|  |  |  | } // namespace gemm
 | 
					
						
							|  |  |  | } // namespace cutlass
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ////////////////////////////////////////////////////////////////////////////////
 |