 3c995c7606
			
		
	
	
		3c995c7606
		
			
		
	
	
	
	
		
			
			* Fix MHA kernel Summary: ATT Test Plan: Reviewers: Subscribers: Tasks: Tags: * Extend DualGemm to support batched mode (#5) Following the GemmUniversalMode::kBatched implementation, batched mode is added to the DualGemm (under examples/45_dual_gemm). DualGemmMode::kBatched and SplitKSerial are not compatible: Status::kErrorInvalidProblem is returned if both are set. * Decouple LayoutB0 and LayoutB1 in DualGemm The DualGemm template assumed the same layout, LayoutB, for both right operand matrices B0 and B1. This is problematic if the layout of the two matrices is different. In particular, this may be the case when one of the matrices is row-major, while the other is a (column) vector that has to be broadcasted in column-major with zero stride (e.g., as {B1.device_data(), 0}) for the DualGemm implementation to be able to process B0 and B1 simultaneously. In this commit, LayoutB0 and LayoutB1 are decoupled throughout the DualGemm code (device, kernel, and mma). Additionally, the batch strides of B0 and B1 are also decoupled to accommodate the column vector B1 case described above. * Remove comment as no longer relevant * Revert Fix MHA kernel --------- Co-authored-by: mikeiovine <mikeiovine@fb.com>
		
			
				
	
	
		
			500 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			500 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /***************************************************************************************************
 | |
|  * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 | |
|  * SPDX-License-Identifier: BSD-3-Clause
 | |
|  *
 | |
|  * Redistribution and use in source and binary forms, with or without
 | |
|  * modification, are permitted provided that the following conditions are met:
 | |
|  *
 | |
|  * 1. Redistributions of source code must retain the above copyright notice, this
 | |
|  * list of conditions and the following disclaimer.
 | |
|  *
 | |
|  * 2. Redistributions in binary form must reproduce the above copyright notice,
 | |
|  * this list of conditions and the following disclaimer in the documentation
 | |
|  * and/or other materials provided with the distribution.
 | |
|  *
 | |
|  * 3. Neither the name of the copyright holder nor the names of its
 | |
|  * contributors may be used to endorse or promote products derived from
 | |
|  * this software without specific prior written permission.
 | |
|  *
 | |
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 | |
|  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 | |
|  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 | |
|  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 | |
|  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 | |
|  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 | |
|  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 | |
|  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 | |
|  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
|  *
 | |
|  **************************************************************************************************/
 | |
| /*! \file
 | |
|     \brief Performs a dual gemm in one fused kernel:
 | |
| ```
 | |
| D0 = epilogue0(X @ B0, C0)
 | |
| D1 = epilogue1(X @ B1, C1)
 | |
| D2 = element_wise(D0, D1)
 | |
| ```
 | |
| */
 | |
| 
 | |
| #pragma once
 | |
| 
 | |
| #include "cutlass/cutlass.h"
 | |
| #include "cutlass/numeric_types.h"
 | |
| #include "cutlass/arch/arch.h"
 | |
| #include "cutlass/device_kernel.h"
 | |
| 
 | |
| #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
 | |
| 
 | |
| #include "cutlass/gemm/device/default_gemm_configuration.h"
 | |
| #include "cutlass/gemm/threadblock/default_mma.h"
 | |
| #include "cutlass/epilogue/thread/linear_combination_relu.h"
 | |
| #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
 | |
| 
 | |
| #include "../kernel/dual_gemm.h"
 | |
| #include "../dual_gemm_common.h"
 | |
| 
 | |
| ////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| namespace cutlass {
 | |
| namespace gemm {
 | |
| namespace device {
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| template <
 | |
|     /// Element type for A matrix operand
 | |
|     typename ElementA_,
 | |
|     /// Layout type for A matrix operand
 | |
|     typename LayoutA_,
 | |
|     /// Element type for B matrix operand
 | |
|     typename ElementB_,
 | |
|     /// Layout type for B0 matrix operand
 | |
|     typename LayoutB0_,
 | |
|     /// Layout type for B1 matrix operand
 | |
|     typename LayoutB1_,
 | |
|     /// Element type for C and D matrix operands
 | |
|     typename ElementC_,
 | |
|     /// Layout type for C and D matrix operands
 | |
|     typename LayoutC_,
 | |
|     /// Element type for internal accumulation
 | |
|     typename ElementAccumulator_,
 | |
|     /// Operator class tag
 | |
|     typename OperatorClass_,
 | |
|     /// Tag indicating architecture to tune for
 | |
|     typename ArchTag_,
 | |
|     /// Threadblock-level tile size (concept: GemmShape)
 | |
|     typename ThreadblockShape_,
 | |
|     /// Warp-level tile size (concept: GemmShape)
 | |
|     typename WarpShape_,
 | |
|     /// Instruction-level tile size (concept: GemmShape)
 | |
|     typename InstructionShape_,
 | |
|     /// Epilogue output operator
 | |
|     typename EpilogueOutputOp0_,
 | |
|     typename EpilogueOutputOp1_,
 | |
|     typename EpilogueOutputOp2_,
 | |
|     /// Threadblock-level swizzling operator
 | |
|     typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
 | |
|     /// Number of stages used in the pipelined mainloop
 | |
|     int Stages =
 | |
|         DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
 | |
|                                  ElementC_, ElementAccumulator_>::kStages,
 | |
|     bool StoreD0 = true,
 | |
|     bool StoreD1 = true,
 | |
|     /// If true, kernel supports split-K with serial reduction
 | |
|     bool SplitKSerial = false,
 | |
|     /// Access granularity of A matrix in units of elements
 | |
|     int AlignmentA =
 | |
|         DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
 | |
|                                  ElementC_, ElementAccumulator_>::kAlignmentA,
 | |
|     /// Access granularity of B matrix in units of elements
 | |
|     int AlignmentB =
 | |
|         DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
 | |
|                                  ElementC_, ElementAccumulator_>::kAlignmentB,
 | |
|     /// Operation performed by GEMM
 | |
|     typename Operator_ = typename DefaultGemmConfiguration<
 | |
|         OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
 | |
|         ElementAccumulator_>::Operator>
 | |
| class DualGemm {
 | |
|  public:
 | |
| 
 | |
|   using ElementA = ElementA_;
 | |
|   using LayoutA = LayoutA_;
 | |
|   using TensorRefA = TensorRef<ElementA const, LayoutA>;
 | |
|   using ElementB = ElementB_;
 | |
|   using LayoutB0 = LayoutB0_;
 | |
|   using LayoutB1 = LayoutB1_;
 | |
|   using TensorRefB0 = TensorRef<ElementB const, LayoutB0>;
 | |
|   using TensorRefB1 = TensorRef<ElementB const, LayoutB1>;
 | |
|   using ElementC = ElementC_;
 | |
|   using LayoutC = LayoutC_;
 | |
|   using TensorRefC = TensorRef<ElementC const, LayoutC>;
 | |
|   using TensorRefD = TensorRef<ElementC, LayoutC>;
 | |
|   using ElementAccumulator = ElementAccumulator_;
 | |
|   using OperatorClass = OperatorClass_;
 | |
|   using ArchTag = ArchTag_;
 | |
|   using ThreadblockShape = ThreadblockShape_;
 | |
|   using WarpShape = WarpShape_;
 | |
|   using InstructionShape = InstructionShape_;
 | |
|   using EpilogueOutputOp0 = EpilogueOutputOp0_;
 | |
|   using EpilogueOutputOp1 = EpilogueOutputOp1_;
 | |
|   using EpilogueOutputOp2 = EpilogueOutputOp2_;
 | |
|   using ThreadblockSwizzle = ThreadblockSwizzle_;
 | |
|   using Operator = Operator_;
 | |
|   static int const kStages = Stages;
 | |
|   static int const kAlignmentA = AlignmentA;
 | |
|   static int const kAlignmentB = AlignmentB;
 | |
|   static int const kAlignmentC = EpilogueOutputOp1::kCount;
 | |
|   static bool const kSplitKSerial = SplitKSerial;
 | |
|   static bool constexpr kStoreD0 = StoreD0;
 | |
|   static bool constexpr kStoreD1 = StoreD1;
 | |
|   static ComplexTransform const kTransformA = ComplexTransform::kNone;
 | |
|   static ComplexTransform const kTransformB = ComplexTransform::kNone;
 | |
| 
 | |
|   using LayoutScaleBias = layout::RowMajor;
 | |
|   /// Define the kernel
 | |
|   /// Define the threadblock-scoped matrix multiply-accumulate
 | |
|   static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented");
 | |
|   static_assert(kStages >= 3, "Only multistage is implemented");
 | |
|   using Mma0 = typename cutlass::gemm::threadblock::DefaultMma<
 | |
|       ElementA, LayoutA, kAlignmentA, ElementB, LayoutB0, kAlignmentB,
 | |
|       ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
 | |
|       ThreadblockShape, WarpShape, 
 | |
|       InstructionShape, Stages, Operator>::ThreadblockMma;
 | |
|   using Mma1 = typename cutlass::gemm::threadblock::DefaultMma<
 | |
|       ElementA, LayoutA, kAlignmentA, ElementB, LayoutB1, kAlignmentB,
 | |
|       ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
 | |
|       ThreadblockShape, WarpShape, 
 | |
|       InstructionShape, Stages, Operator>::ThreadblockMma;
 | |
|   using DualMma = threadblock::DualMmaMultistage<
 | |
|     typename Mma0::Shape,
 | |
|     typename Mma0::IteratorA,
 | |
|     typename Mma0::SmemIteratorA,
 | |
|     Mma0::kCacheOpA,
 | |
|     typename Mma0::IteratorB,
 | |
|     typename Mma0::SmemIteratorB,
 | |
|     Mma0::kCacheOpB,
 | |
|     typename Mma1::IteratorB,
 | |
|     typename Mma1::SmemIteratorB,
 | |
|     typename Mma0::ElementC,
 | |
|     typename Mma0::LayoutC,
 | |
|     typename Mma0::Policy,
 | |
|     typename Mma1::Policy,
 | |
|     Mma0::kStages,
 | |
|     SharedMemoryClearOption::kNone
 | |
|   >;
 | |
| 
 | |
|   static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
 | |
| 
 | |
|   /// Define the epilogue
 | |
|   using Epilogue0 =
 | |
|       typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
 | |
|           ThreadblockShape, typename DualMma::Operator0, kPartitionsK, EpilogueOutputOp0,
 | |
|           EpilogueOutputOp0::kCount>::Epilogue;
 | |
|   using Epilogue1 =
 | |
|       typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
 | |
|           ThreadblockShape, typename DualMma::Operator1, kPartitionsK, EpilogueOutputOp1,
 | |
|           EpilogueOutputOp1::kCount>::Epilogue;
 | |
| 
 | |
|   /// Define the kernel-level GEMM operator.
 | |
|   using DualGemmKernel = kernel::DualGemm<
 | |
|     DualMma,
 | |
|     Epilogue0, Epilogue1, EpilogueOutputOp2,
 | |
|     ThreadblockSwizzle, kSplitKSerial,
 | |
|     kStoreD0, kStoreD1>;
 | |
| 
 | |
|   /// Argument structure
 | |
|   struct Arguments {
 | |
| 
 | |
|     //
 | |
|     // Data members
 | |
|     //
 | |
| 
 | |
|     DualGemmMode mode;
 | |
|     GemmCoord problem_size;
 | |
|     TensorRef<ElementA const, LayoutA> ref_A0;
 | |
|     TensorRef<ElementB const, LayoutB0> ref_B0;
 | |
|     TensorRef<ElementC const, LayoutC> ref_C0;
 | |
|     TensorRef<ElementC, LayoutC> ref_D0;
 | |
|     TensorRef<ElementB const, LayoutB1> ref_B1;
 | |
|     TensorRef<ElementC const, LayoutC> ref_C1;
 | |
|     TensorRef<ElementC, LayoutC> ref_D1;
 | |
|     TensorRef<ElementC, LayoutC> ref_D2;
 | |
|     typename EpilogueOutputOp0::Params epilogue0;
 | |
|     typename EpilogueOutputOp1::Params epilogue1;
 | |
|     typename EpilogueOutputOp2::Params epilogue2;
 | |
|     int split_k_slices;
 | |
| 
 | |
|     int batch_count;
 | |
|     int64_t batch_stride_A;
 | |
|     int64_t batch_stride_B0;
 | |
|     int64_t batch_stride_B1;
 | |
|     int64_t batch_stride_C;
 | |
|     int64_t batch_stride_D;
 | |
| 
 | |
|     //
 | |
|     // Methods
 | |
|     //
 | |
| 
 | |
|     /// Default ctor
 | |
|     CUTLASS_HOST_DEVICE
 | |
|     Arguments(): problem_size(0, 0, 0), split_k_slices(1) {
 | |
| 
 | |
|     }
 | |
| 
 | |
|     /// Constructs an Arguments structure 
 | |
|     CUTLASS_HOST_DEVICE
 | |
|     Arguments(
 | |
|       DualGemmMode mode,
 | |
|       GemmCoord problem_size_,
 | |
|       TensorRef<ElementA const, LayoutA> ref_A0_,
 | |
|       TensorRef<ElementB const, LayoutB0> ref_B0_,
 | |
|       TensorRef<ElementC const, LayoutC> ref_C0_,
 | |
|       TensorRef<ElementC, LayoutC> ref_D0_,
 | |
|       TensorRef<ElementB const, LayoutB1> ref_B1_,
 | |
|       TensorRef<ElementC const, LayoutC> ref_C1_,
 | |
|       TensorRef<ElementC, LayoutC> ref_D1_,
 | |
|       TensorRef<ElementC, LayoutC> ref_D2_,
 | |
|       typename EpilogueOutputOp0::Params epilogue0_ =
 | |
|         typename EpilogueOutputOp0::Params(),
 | |
|       typename EpilogueOutputOp1::Params epilogue1_ =
 | |
|         typename EpilogueOutputOp1::Params(),
 | |
|       typename EpilogueOutputOp2::Params epilogue2_ =
 | |
|         typename EpilogueOutputOp2::Params(),
 | |
|       int split_k_slices_ = 1,
 | |
|       int batch_count = 1,
 | |
|       int64_t batch_stride_A = 0,
 | |
|       int64_t batch_stride_B0 = 0,
 | |
|       int64_t batch_stride_B1 = 0,
 | |
|       int64_t batch_stride_C = 0,
 | |
|       int64_t batch_stride_D = 0
 | |
|     ):
 | |
|       mode(mode),
 | |
|       problem_size(problem_size_),
 | |
|       ref_A0(ref_A0_),
 | |
|       ref_B0(ref_B0_),
 | |
|       ref_C0(ref_C0_),
 | |
|       ref_D0(ref_D0_),
 | |
|       ref_B1(ref_B1_),
 | |
|       ref_C1(ref_C1_),
 | |
|       ref_D1(ref_D1_),
 | |
|       ref_D2(ref_D2_),
 | |
|       epilogue0(epilogue0_),
 | |
|       epilogue1(epilogue1_),
 | |
|       epilogue2(epilogue2_),
 | |
|       split_k_slices(split_k_slices_),
 | |
|       batch_count(batch_count),
 | |
|       batch_stride_A(batch_stride_A),
 | |
|       batch_stride_B0(batch_stride_B0),
 | |
|       batch_stride_B1(batch_stride_B1),
 | |
|       batch_stride_C(batch_stride_C),
 | |
|       batch_stride_D(batch_stride_D) {
 | |
| 
 | |
|     }
 | |
|   };
 | |
| 
 | |
| private:
 | |
| 
 | |
|   /// Kernel parameters object
 | |
|   typename DualGemmKernel::Params params_;
 | |
| 
 | |
| public:
 | |
| 
 | |
|   /// Constructs the GEMM.
 | |
|   DualGemm() = default;
 | |
| 
 | |
|   /// Determines whether the GEMM can execute the given problem.
 | |
|   static Status can_implement(Arguments const &args) {
 | |
| 
 | |
|     if (args.mode == DualGemmMode::kBatched && kSplitKSerial) {
 | |
|       return Status::kErrorInvalidProblem;
 | |
|     }
 | |
|     if (!kSplitKSerial && args.split_k_slices > 1) {
 | |
|       return Status::kErrorInvalidProblem;
 | |
|     }
 | |
|     if (kStoreD0 != (args.ref_D0.data() != nullptr)) {
 | |
|       return Status::kErrorInternal;
 | |
|     }
 | |
|     if (kStoreD1 != (args.ref_D1.data() != nullptr)) {
 | |
|       return Status::kErrorInternal;
 | |
|     }
 | |
| 
 | |
|     Status status = DualGemmKernel::can_implement(
 | |
|       args.problem_size,
 | |
|       args.ref_A0.non_const_ref(),
 | |
|       args.ref_B0.non_const_ref(),
 | |
|       args.ref_C0.non_const_ref(),
 | |
|       args.ref_D0,
 | |
|       args.ref_B1.non_const_ref(),
 | |
|       args.ref_C1.non_const_ref(),
 | |
|       args.ref_D1,
 | |
|       args.ref_D2
 | |
|     );
 | |
| 
 | |
|     if (status != Status::kSuccess) {
 | |
|       return status;
 | |
|     }
 | |
| 
 | |
|     return Status::kSuccess;
 | |
|   }
 | |
| 
 | |
|   /// Gets the workspace size
 | |
|   static size_t get_workspace_size(Arguments const &args) {
 | |
| 
 | |
|     size_t bytes = 0;
 | |
| 
 | |
|     if (kSplitKSerial && args.split_k_slices > 1) {
 | |
|       // Determine grid shape
 | |
|       ThreadblockSwizzle threadblock_swizzle;
 | |
| 
 | |
|       cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
 | |
|         args.problem_size, 
 | |
|         {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
 | |
|         args.split_k_slices);
 | |
| 
 | |
|       bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
 | |
|     }
 | |
| 
 | |
|     return bytes;
 | |
|   }
 | |
| 
 | |
|   /// Initializes GEMM state from arguments.
 | |
|   Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
 | |
| 
 | |
|     // Determine grid shape
 | |
|     ThreadblockSwizzle threadblock_swizzle;
 | |
| 
 | |
|     cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
 | |
|       args.problem_size, 
 | |
|       {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
 | |
|       args.mode == DualGemmMode::kBatched ? args.batch_count : args.split_k_slices);
 | |
| 
 | |
|     if (kSplitKSerial) {
 | |
|       if (args.split_k_slices > 1) {
 | |
|         if (!workspace) {
 | |
|           return Status::kErrorWorkspaceNull;
 | |
|         }
 | |
| 
 | |
|         size_t bytes = get_workspace_size(args);
 | |
|       
 | |
|         cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
 | |
| 
 | |
|         if (result != cudaSuccess) {
 | |
|           return Status::kErrorInternal;
 | |
|         }
 | |
|       }
 | |
|     }
 | |
|     else {
 | |
| 
 | |
|       if (args.split_k_slices > 1) {
 | |
|         return Status::kErrorInvalidProblem;
 | |
|       }
 | |
|     }
 | |
| 
 | |
|     // Initialize the Params structure
 | |
|     params_ = typename DualGemmKernel::Params{
 | |
|       args.mode,
 | |
|       args.problem_size,
 | |
|       grid_shape,
 | |
|       args.ref_A0.non_const_ref(),
 | |
|       args.ref_B0.non_const_ref(),
 | |
|       args.ref_C0.non_const_ref(),
 | |
|       args.ref_D0,
 | |
|       args.ref_B1.non_const_ref(),
 | |
|       args.ref_C1.non_const_ref(),
 | |
|       args.ref_D1,
 | |
|       args.ref_D2,
 | |
|       args.epilogue0,
 | |
|       args.epilogue1,
 | |
|       args.epilogue2,
 | |
|       reinterpret_cast<int *>(workspace),
 | |
|       args.batch_stride_A,
 | |
|       args.batch_stride_B0,
 | |
|       args.batch_stride_B1,
 | |
|       args.batch_stride_C,
 | |
|       args.batch_stride_D,
 | |
|     };
 | |
| 
 | |
|     return Status::kSuccess;
 | |
|   }
 | |
| 
 | |
|   /// Lightweight update given a subset of arguments
 | |
|   Status update(Arguments const &args, void *workspace = nullptr) {
 | |
|     
 | |
|     if (kSplitKSerial && args.split_k_slices > 1) {  
 | |
|       if (!workspace) {
 | |
|         return Status::kErrorWorkspaceNull;
 | |
|       }
 | |
|     }
 | |
| 
 | |
|     params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
 | |
|     params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
 | |
|     params_.ref_C0.reset(args.ref_C0.non_const_ref().data());
 | |
|     params_.ref_D0.reset(args.ref_D0.data());
 | |
|     params_.ref_B1.reset(args.ref_B1.non_const_ref().data());
 | |
|     params_.ref_C1.reset(args.ref_C1.non_const_ref().data());
 | |
|     params_.ref_D1.reset(args.ref_D1.data());
 | |
|     params_.ref_D2.reset(args.ref_D2.data());
 | |
|     params_.output_op_0 = args.epilogue0;
 | |
|     params_.output_op_1 = args.epilogue1;
 | |
|     params_.output_op_2 = args.epilogue2;
 | |
|     params_.semaphore = reinterpret_cast<int *>(workspace);
 | |
| 
 | |
|     return Status::kSuccess;
 | |
|   }
 | |
| 
 | |
|   /// Runs the kernel using initialized state.
 | |
|   Status run(cudaStream_t stream = nullptr) {
 | |
| 
 | |
|     ThreadblockSwizzle threadblock_swizzle;
 | |
| 
 | |
|     dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
 | |
|     dim3 block(DualGemmKernel::kThreadCount, 1, 1);
 | |
| 
 | |
|     cudaError_t result;
 | |
| 
 | |
|     int smem_size = int(sizeof(typename DualGemmKernel::SharedStorage));
 | |
|     if (smem_size >= (48 << 10)) {
 | |
|       result = cudaFuncSetAttribute(Kernel<DualGemmKernel>,
 | |
|                                     cudaFuncAttributeMaxDynamicSharedMemorySize,
 | |
|                                     smem_size);
 | |
| 
 | |
|       if (result != cudaSuccess) {
 | |
|         return Status::kErrorInternal;
 | |
|       }
 | |
|     }
 | |
| 
 | |
|     cutlass::Kernel<DualGemmKernel><<<grid, block, smem_size, stream>>>(params_);
 | |
| 
 | |
|     result = cudaGetLastError();
 | |
| 
 | |
|     return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
 | |
|   }
 | |
| 
 | |
|   /// Runs the kernel using initialized state.
 | |
|   Status operator()(cudaStream_t stream = nullptr) {
 | |
|     return run(stream);
 | |
|   }
 | |
| 
 | |
|   /// Runs the kernel using initialized state.
 | |
|   Status operator()(
 | |
|     Arguments const &args, 
 | |
|     void *workspace = nullptr, 
 | |
|     cudaStream_t stream = nullptr) {
 | |
|     
 | |
|     Status status = initialize(args, workspace, stream);
 | |
|     
 | |
|     if (status == Status::kSuccess) {
 | |
|       status = run(stream);
 | |
|     }
 | |
| 
 | |
|     return status;
 | |
|   }
 | |
| };
 | |
| 
 | |
| } // namespace device
 | |
| } // namespace gemm
 | |
| } // namespace cutlass
 | |
| 
 | |
| ////////////////////////////////////////////////////////////////////////////////
 |