 66d9cddc83
			
		
	
	
		66d9cddc83
		
			
		
	
	
	
	
		
			
			* New updates. * Minor profiler updates Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
		
			
				
	
	
		
			341 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			341 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| /***************************************************************************************************
 | |
|  * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 | |
|  * SPDX-License-Identifier: BSD-3-Clause
 | |
|  *
 | |
|  * Redistribution and use in source and binary forms, with or without
 | |
|  * modification, are permitted provided that the following conditions are met:
 | |
|  *
 | |
|  * 1. Redistributions of source code must retain the above copyright notice, this
 | |
|  * list of conditions and the following disclaimer.
 | |
|  *
 | |
|  * 2. Redistributions in binary form must reproduce the above copyright notice,
 | |
|  * this list of conditions and the following disclaimer in the documentation
 | |
|  * and/or other materials provided with the distribution.
 | |
|  *
 | |
|  * 3. Neither the name of the copyright holder nor the names of its
 | |
|  * contributors may be used to endorse or promote products derived from
 | |
|  * this software without specific prior written permission.
 | |
|  *
 | |
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 | |
|  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 | |
|  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 | |
|  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 | |
|  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 | |
|  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 | |
|  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 | |
|  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 | |
|  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
|  *
 | |
|  **************************************************************************************************/
 | |
| 
 | |
| /**
 | |
| This example shows how to use split-k version of matrix multiplication using functions and data
 | |
| structures provided by CUTLASS, which we run on an NVIDIA Volta GPU.
 | |
| 
 | |
| What is split-k?
 | |
| Consider a problem size of M = 128, N = 128, K = 4096. In this case, if my thread-block tile size (a
 | |
| tile can be viewed as a 2d matrix) is 128x128x4096, then we launch a single thread-block taking
 | |
| up a single SM of 80 SMs present on V100. Hence the efficiency of computation is really low. So, how
 | |
| to solve it? This is where split-k comes in. It is a way of partitioning K-dimension of matrix
 | |
| multiplication and distribute across multiple SMs and get better efficiency than single SM. In the
 | |
| above example, we can partition K-dimension with split-k factor of 16 i.e., thread-block tile size
 | |
| will be 128x128x256 and will be launching on 16 SMs. Once each thread-block computes their partial
 | |
| inner product (1/16th of output), they accumulate to single output matrix.
 | |
| 
 | |
| Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing
 | |
| high performance kernels at scale which works for multiple problem sizes with good abstractions is
 | |
| really hard. CUTLASS solves this problem by providing simplified abstractions to compose
 | |
| multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU
 | |
| easily.
 | |
| 
 | |
| CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp
 | |
| and thread-block level, they compute on their own tile-size with higher level of tile sizes being
 | |
| composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used
 | |
| to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute
 | |
| threadblock-tile (tile size computed by a threadblock).
 | |
| 
 | |
| In this example, we split variable initialization into
 | |
| 1. Setting up data properties : describes how matrices are laid out in the memory and how the kernel
 | |
| can view them (logical to physical mapping)
 | |
| 2. Setting up computation properties : describes how the above set matrices will be used to compute
 | |
| output of matrix multiplication.
 | |
| 
 | |
| First, we setup the data types of matrices A, B, C and D along with alpha, beta as the equation for
 | |
| GEMM is D = alpha * A * B + beta * C. In CUTLASS, the kernels first compute A * B and leaves the
 | |
| rest of the computation to end of the kernel as alpha * X + beta * C is a simple element-wise
 | |
| operation on X (A * B) and C. We call this as epilogue of kernel. Hence, we setup data types for
 | |
| alpha and beta to be equal to ElementComputeEpilogue = float. As we want to use MMA instructions on
 | |
| Volta and they support only half-precision floating point (fp16 or half), we use data type for
 | |
| elements in input matrix A and B as cutlass::half_t. Volta also supports accumulation of partial dot
 | |
| product to fp32, which can store wider range of numbers, we use it as data type of output matrix
 | |
| elements and accumulation. We convey this to CUTLASS kernel by initializing template variables
 | |
| ElementAccumulator (float), ElementComputeEpilogue (float), ElementInputA (cutlass::half_t),
 | |
| ElementInputB (cutlass::half_t), ElementOutput (float). Communicating just the data type is not
 | |
| enough. As the data is laid out linearly in memory, we have to convey the layout of matrices. We do
 | |
| that by initializing template variable LayoutInputA to column major cutlass variable, LayoutInputB
 | |
| to row major and LayoutOutput to row major. Next, we setup rules to compute alpha * X + beta * C
 | |
| which is called epilogue of the kernel. We initialize template variable EpilogueOp, which takes the
 | |
| data type of output ElementOutput (float), the number of elements per vector memory access (4),
 | |
| data type of accumulator (float) and data type of computation of linear combination (alpha * X +
 | |
| beta * C).
 | |
| 
 | |
| Now that we setup the properties of data, we have to setup properties of computation.
 | |
| 
 | |
| Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x32,
 | |
| 64x64x32, 8x8x4 (MxNxK) respectively. When passed to instantiate CUTLASS GEMM kernel, it internally
 | |
| deduces the number of threads needed per thread-block, amount of shared memory, storing data in
 | |
| bank-conflict free manner, and ton of other variables required to compose, initialize and launch a
 | |
| high performance GEMM kernel. This is the beauty of CUTLASS, it relieves developer from
 | |
| understanding and coding complicated hardware optimizations which can easily go wrong.
 | |
| 
 | |
| There are few more template variables initialized such as, which threadblock tile of output matrix
 | |
| is done by which threadblock launched on an SM, and the CUDA SM architecture of the GPU you want to run on.
 | |
| 
 | |
| These are all put together to create a template variable which describes CUTLASS GEMM kernel using
 | |
| cutlass::gemm::device::GemmSplitKParallel template.
 | |
| 
 | |
| The next step is to initialize physical data, instantiate and initialize CUTLASS kernel and run it.
 | |
| We use CUTLASS utilities to initialize, fill, compare matrices as they are simple and don't come
 | |
| in the way of learning CUTLASS.
 | |
| 
 | |
| Once all the matrices are initialized and filled with data, create arguments tuple to launch CUTLASS
 | |
| kernel which takes problem size (M = 5120, N = 4096 and K = 4096), matrices, alpha, beta and the
 | |
| important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space
 | |
| memory is required by the kernel we instantiated. If yes, we create it and pass it along with other
 | |
| arguments created to initialize CUTLASS kernel then, the kernel is launched.
 | |
| 
 | |
| In this example, we later on launch a reference gemm kernel (from CUTLASS utilities) to compare if
 | |
| the output from CUTLASS kernel is same as reference GEMM kernel.
 | |
| */
 | |
| 
 | |
| #include <iostream>
 | |
| 
 | |
| #include "cutlass/cutlass.h"
 | |
| #include "cutlass/gemm/device/gemm_splitk_parallel.h"
 | |
| #include "cutlass/util/host_tensor.h"
 | |
| #include "cutlass/util/reference/device/gemm.h"
 | |
| #include "cutlass/util/reference/host/tensor_compare.h"
 | |
| #include "cutlass/util/reference/host/tensor_copy.h"
 | |
| #include "cutlass/util/reference/host/tensor_fill.h"
 | |
| #include "cutlass/util/tensor_view_io.h"
 | |
| #include "helper.h"
 | |
| 
 | |
| // The code section below describes the data types of the input and output matrices and of the
 | |
| // computation performed on their elements.
 | |
| using ElementAccumulator = float;                   // <- data type of accumulator
 | |
| using ElementComputeEpilogue = ElementAccumulator;  // <- data type of epilogue operations (alpha, beta)
 | |
| using ElementInputA = cutlass::half_t;              // <- data type of elements in input matrix A
 | |
| using ElementInputB = cutlass::half_t;              // <- data type of elements in input matrix B
 | |
| using ElementOutput = float;                        // <- data type of elements in matrices C and D
 | |
| 
 | |
| // The code section below describes the memory layout of the input and output matrices: Column
 | |
| // Major for Matrix A, Row Major for Matrix B, and Row Major for the output layout used by C and D
 | |
| using LayoutInputA = cutlass::layout::ColumnMajor;
 | |
| using LayoutInputB = cutlass::layout::RowMajor;
 | |
| using LayoutOutput = cutlass::layout::RowMajor;
 | |
| 
 | |
| // This code section selects the operator class: tensor cores (chosen here) or regular SIMT cores
 | |
| using MMAOp = cutlass::arch::OpClassTensorOp;
 | |
| 
 | |
| // This code section describes the CUDA SM architecture number (Sm70)
 | |
| using SmArch = cutlass::arch::Sm70;
 | |
| 
 | |
| // This code section describes the tile size a thread block will compute
 | |
| using ShapeMMAThreadBlock =
 | |
|     cutlass::gemm::GemmShape<128, 128, 32>;  // <- threadblock tile M = 128, N = 128, K = 32
 | |
| // This code section describes the tile size a warp will compute
 | |
| using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;  // <- warp tile M = 64, N = 64, K = 32
 | |
| // This code section describes the tile size of a single MMA op
 | |
| using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>;  // <- MMA Op tile M = 8, N = 8, K = 4
 | |
| 
 | |
| // This code section describes the epilogue: the linear combination applied to the accumulators
 | |
| using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
 | |
|     ElementOutput,                                     // <- data type of output matrix
 | |
|     128 / cutlass::sizeof_bits<ElementOutput>::value,  // <- This is the number of elements per
 | |
|                                                        // vectorized memory access (128 bits per
 | |
|                                                        // access). With ElementOutput = float it
 | |
|                                                        // is 4 elements. This becomes the vector
 | |
|                                                        // width of math instructions in epilogue too
 | |
|     ElementAccumulator,                                // <- data type of accumulator
 | |
|     ElementComputeEpilogue>;  // <- data type for alpha/beta in linear combination function
 | |
| 
 | |
| // Combine all the template parameters defined above into the GemmSplitKParallel type used by run()
 | |
| using Gemm = cutlass::gemm::device::GemmSplitKParallel<ElementInputA,
 | |
|                                                        LayoutInputA,
 | |
|                                                        ElementInputB,
 | |
|                                                        LayoutInputB,
 | |
|                                                        ElementOutput,
 | |
|                                                        LayoutOutput,
 | |
|                                                        ElementAccumulator,
 | |
|                                                        MMAOp,
 | |
|                                                        SmArch,
 | |
|                                                        ShapeMMAThreadBlock,
 | |
|                                                        ShapeMMAWarp,
 | |
|                                                        ShapeMMAOp,
 | |
|                                                        EpilogueOp>;
 | |
| 
 | |
| // Builds, runs and verifies the split-K GEMM example.
 | |
| //
 | |
| // Fills A, B and C on the host, copies them to the device, runs the CUTLASS GemmSplitKParallel
 | |
| // kernel and the CUTLASS device reference GEMM on the same inputs, and compares both outputs.
 | |
| //
 | |
| // Returns 0 if the outputs are equal, or if the device's compute capability major version is not 7
 | |
| // (the example is skipped). Returns -1 if a CUDA runtime call reports an error or the outputs
 | |
| // differ. Failing CUTLASS statuses are handled by CUTLASS_CHECK, which is defined outside this file.
 | |
| int run() {
 | |
| 
 | |
|   cudaDeviceProp props;
 | |
| 
 | |
|   cudaError_t error = cudaGetDeviceProperties(&props, 0);
 | |
|   if (error != cudaSuccess) {
 | |
|     std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
 | |
|     return -1;
 | |
|   }
 | |
| 
 | |
|   if (props.major != 7) {
 | |
|     std::cerr << "Volta Tensor Ops must be run on a machine with compute capability of 70, 72, or 75."
 | |
|               << std::endl;
 | |
| 
 | |
|     // Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits.
 | |
|     return 0;
 | |
|   }
 | |
| 
 | |
|   //
 | |
|   // Define problem size
 | |
|   //
 | |
| 
 | |
|   const int length_m = 5120;
 | |
|   const int length_n = 4096;
 | |
|   const int length_k = 4096;
 | |
| 
 | |
|   // Create a tuple of problem size for matrix multiplication
 | |
|   cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
 | |
| 
 | |
|   // Initialize tensors using CUTLASS helper functions
 | |
|   cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
 | |
|       problem_size.mk());  // <- Create matrix A with dimensions M x K
 | |
|   cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
 | |
|       problem_size.kn());  // <- Create matrix B with dimensions K x N
 | |
|   cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
 | |
|       problem_size.mn());  // <- Create matrix C with dimensions M x N
 | |
|   cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
 | |
|       problem_size.mn());  // <- Create matrix D with dimensions M x N used to store output from
 | |
|                            // CUTLASS kernel
 | |
|   cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
 | |
|       problem_size.mn());  // <- Create matrix D with dimensions M x N used to store output from
 | |
|                            // reference kernel
 | |
| 
 | |
|   // Fill input and output matrices on host using CUTLASS helper functions
 | |
|   cutlass::reference::host::TensorFillRandomUniform(
 | |
|       tensor_a.host_view(),
 | |
|       1,
 | |
|       ElementInputA(4),
 | |
|       ElementInputA(-4),
 | |
|       0);  // <- Fill matrix A on host with uniform-distribution random data
 | |
|   cutlass::reference::host::TensorFillRandomUniform(
 | |
|       tensor_b.host_view(),
 | |
|       1,
 | |
|       ElementInputB(4),
 | |
|       ElementInputB(-4),
 | |
|       0);  // <- Fill matrix B on host with uniform-distribution random data
 | |
|   cutlass::reference::host::TensorFillRandomUniform(
 | |
|       tensor_c.host_view(),
 | |
|       1,
 | |
|       ElementOutput(4),
 | |
|       ElementOutput(-4),
 | |
|       0);  // <- Fill matrix C on host with uniform-distribution random data
 | |
|   cutlass::reference::host::TensorFill(
 | |
|       tensor_d.host_view());  // <- fill matrix D on host with zeros
 | |
|   cutlass::reference::host::TensorFill(
 | |
|       tensor_ref_d.host_view());  // <- fill matrix D for reference on host with zeros
 | |
| 
 | |
|   // Copy data from host to GPU
 | |
|   tensor_a.sync_device();
 | |
|   tensor_b.sync_device();
 | |
|   tensor_c.sync_device();
 | |
|   tensor_d.sync_device();
 | |
|   tensor_ref_d.sync_device();
 | |
| 
 | |
|   // Initialize alpha and beta for dot product computation
 | |
|   ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
 | |
|   ElementComputeEpilogue beta = ElementComputeEpilogue(0);
 | |
| 
 | |
|   // Split K dimension into 16 partitions
 | |
|   int split_k_slices = 16;
 | |
| 
 | |
|   // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
 | |
|   // instantiated CUTLASS kernel
 | |
|   typename Gemm::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
 | |
|                                      tensor_a.device_ref(),  // <- reference to matrix A on device
 | |
|                                      tensor_b.device_ref(),  // <- reference to matrix B on device
 | |
|                                      tensor_c.device_ref(),  // <- reference to matrix C on device
 | |
|                                      tensor_d.device_ref(),  // <- reference to matrix D on device
 | |
|                                      {alpha, beta},          // <- tuple of alpha and beta
 | |
|                                      split_k_slices};        // <- k-dimension split factor
 | |
| 
 | |
|   // Using the arguments, query for extra workspace required for matrix multiplication computation
 | |
|   size_t workspace_size = Gemm::get_workspace_size(arguments);
 | |
| 
 | |
|   // Allocate workspace memory
 | |
|   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
 | |
| 
 | |
|   // Instantiate CUTLASS kernel depending on templates
 | |
|   Gemm gemm_op;
 | |
| 
 | |
|   // Initialize CUTLASS kernel with arguments and workspace pointer
 | |
|   cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
 | |
|   CUTLASS_CHECK(status);
 | |
| 
 | |
|   // Launch initialized CUTLASS kernel
 | |
|   status = gemm_op();
 | |
|   CUTLASS_CHECK(status);
 | |
| 
 | |
|   // Create instantiation for device reference gemm kernel
 | |
|   cutlass::reference::device::Gemm<ElementInputA,
 | |
|                                    LayoutInputA,
 | |
|                                    ElementInputB,
 | |
|                                    LayoutInputB,
 | |
|                                    ElementOutput,
 | |
|                                    LayoutOutput,
 | |
|                                    ElementComputeEpilogue,
 | |
|                                    ElementComputeEpilogue>
 | |
|       gemm_device;
 | |
| 
 | |
|   // Launch device reference gemm kernel
 | |
|   gemm_device(problem_size,
 | |
|               alpha,
 | |
|               tensor_a.device_ref(),
 | |
|               tensor_b.device_ref(),
 | |
|               beta,
 | |
|               tensor_c.device_ref(),
 | |
|               tensor_ref_d.device_ref());
 | |
| 
 | |
|   // The reference launch returns no status here, so query the runtime for a pending error
 | |
|   error = cudaGetLastError();
 | |
|   if (error != cudaSuccess) {
 | |
|     std::cerr << "CUDA error after launching reference GEMM: " << cudaGetErrorString(error) << std::endl;
 | |
|     return -1;
 | |
|   }
 | |
| 
 | |
|   // Wait for kernels to finish; an asynchronous execution failure is reported by this call
 | |
|   error = cudaDeviceSynchronize();
 | |
|   if (error != cudaSuccess) {
 | |
|     std::cerr << "cudaDeviceSynchronize() returned an error: " << cudaGetErrorString(error) << std::endl;
 | |
|     return -1;
 | |
|   }
 | |
| 
 | |
|   // Copy output data from CUTLASS and reference kernel to host for comparison
 | |
|   tensor_d.sync_host();
 | |
|   tensor_ref_d.sync_host();
 | |
| 
 | |
|   // Check if output from CUTLASS kernel and reference kernel are equal or not
 | |
|   bool passed = cutlass::reference::host::TensorEquals(
 | |
|     tensor_d.host_view(),
 | |
|     tensor_ref_d.host_view());
 | |
| 
 | |
|   std::cout << (passed ? "Passed" : "Failed") << std::endl;
 | |
| 
 | |
|   return (passed ? 0  : -1);
 | |
| }
 | |
| 
 | |
| // Entry point: runs the example only when built with a CUDA Toolkit that supports it.
 | |
| int main() {
 | |
| 
 | |
|   // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1, so the
 | |
|   // example only runs when it was compiled with the CUDA 10.1 Toolkit or a later one.
 | |
|   bool const toolkit_supported =
 | |
|       (__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1);
 | |
| 
 | |
|   if (toolkit_supported) {
 | |
|     return run();
 | |
|   }
 | |
| 
 | |
|   std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
 | |
| 
 | |
|   // Report success so this test passes when built with older CUDA Toolkits; nothing is executed.
 | |
|   return 0;
 | |
| }
 | |
| 
 |