# Quickstart

## Prerequisites

CUTLASS requires:
- NVIDIA CUDA Toolkit (11.4 or later required, [12.0](https://developer.nvidia.com/cuda-toolkit) recommended)
- CMake 3.18+
- host compiler supporting C++17 or greater (minimum g++ 7.5.0)
- Python 3.6+

CUTLASS may be optionally compiled and linked with
- cuBLAS
- cuDNN v7.6 or later

## Initial build steps

Construct a build directory and run CMake.
```bash
$ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc

$ mkdir build && cd build

$ cmake .. -DCUTLASS_NVCC_ARCHS=90a             # compiles for NVIDIA Hopper GPU architecture
```

If your goal is strictly to build only the CUTLASS Profiler and to minimize compilation time, we suggest
executing the following CMake command in an empty `build/` directory.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_ENABLE_TESTS=OFF -DCUTLASS_UNITY_BUILD_ENABLED=ON
```

This reduces overall compilation time by excluding unit tests and enabling the unity build.

You may reduce build times further by compiling only certain operations. Set the `CUTLASS_LIBRARY_OPERATIONS` flag as shown below,
again from an empty `build/` directory. This example compiles only 2-D convolution kernels.

```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_LIBRARY_OPERATIONS=conv2d
```

You may also filter kernels by name by supplying a filter string with the flag `CUTLASS_LIBRARY_KERNELS`. For example, the command below selects only CUTLASS 3.x kernels.

```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_LIBRARY_KERNELS=cutlass3x*
```
See more examples of selectively compiling CUTLASS GEMM and convolution kernels [here](quickstart.md#example-cmake-commands).

You may explicitly exclude cuBLAS and cuDNN as dependencies with the following CMake flags.
- `-DCUTLASS_ENABLE_CUBLAS=OFF`
- `-DCUTLASS_ENABLE_CUDNN=OFF`

## Build and run the CUTLASS Profiler

From the `build/` directory created above, compile the CUTLASS Profiler.
```bash
$ make cutlass_profiler -j12
```

To execute the CUTLASS Profiler for GEMM, run the following command.
```bash
$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=4352 --n=4096 --k=4096

=============================
  Problem ID: 1

    Provider: CUTLASS
   Operation: cutlass_simt_sgemm_128x128_nn

 Disposition: Passed
      Status: Success

   Arguments:  --m=4352 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0  \
               --split_k_slices=1 --batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8  \
               --stages=2 --warps_m=2 --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50  \
               --max_cc=1024

       Bytes: 52428800  bytes
       FLOPs: 146064539648  flops

     Runtime: 10.5424  ms
      Memory: 4.63158 GiB/s

        Math: 13854.9 GFLOP/s
```
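The `FLOPs` and `Math` lines in the sample GEMM output can be reproduced by hand. The sketch below is an inference from the printed numbers, not a documented formula: the FLOP count is consistent with `2*m*n*k` for the multiply-accumulates plus `2*m*n` for the alpha/beta epilogue, and the math rate is that count divided by the runtime. The function names are hypothetical.

```cpp
#include <cstdint>

// FLOP count that agrees with the sample output: 2*m*n*k for the
// multiply-accumulates plus 2*m*n for the alpha/beta epilogue.
// This formula is inferred from the printed numbers (an assumption).
constexpr std::int64_t gemm_flops(std::int64_t m, std::int64_t n, std::int64_t k) {
  return 2 * m * n * k + 2 * m * n;
}

// Converts a FLOP count and a runtime in milliseconds to GFLOP/s.
constexpr double gflops_per_second(std::int64_t flops, double runtime_ms) {
  return static_cast<double>(flops) / (runtime_ms * 1.0e-3) / 1.0e9;
}

// The problem size from the command line reproduces the "FLOPs" line exactly.
static_assert(gemm_flops(4352, 4096, 4096) == 146064539648LL,
              "matches the FLOPs line of the sample output");
```

Calling `gflops_per_second(gemm_flops(4352, 4096, 4096), 10.5424)` gives roughly 13855 GFLOP/s, in line with the `Math` line; the small difference comes from the rounded runtime.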

To execute the CUTLASS Profiler for convolution, run the following example.
```bash
$ ./tools/profiler/cutlass_profiler --kernels=s1688fprop --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3 --pad_h=1 --pad_w=1
```

To execute all CUTLASS 2-D convolution operators, execute the following.
```bash
$ ./tools/profiler/cutlass_profiler --operation=conv2d --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3


=============================
  Problem ID: 1

        Provider: CUTLASS
   OperationKind: conv2d
       Operation: cutlass_simt_sfprop_optimized_128x128_8x2_nhwc

          Status: Success
    Verification: ON
     Disposition: Passed

reference_device: Passed

       Arguments: --conv_kind=fprop --n=8 --h=224 --w=224 --c=128 --k=128 --r=3 --s=3 --p=224 --q=224 --pad_h=1 --pad_w=1  \
                  --stride_h=1 --stride_w=1 --dilation_h=1 --dilation_w=1 --Activation=f32:nhwc --Filter=f32:nhwc --Output=f32:nhwc  \
                  --conv_mode=cross --iterator_algorithm=optimized --alpha=1 --beta=0 --split_k_mode=serial --split_k_slices=1  \
                  --eq_gemm_provider=none --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4  \
                  --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024

           Bytes: 2055798784  bytes
           FLOPs: 118482796544  flops

         Runtime: 8.13237  ms
          Memory: 235.431 GiB/s

            Math: 14569.3 GFLOP/s

```
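The `--p=224 --q=224` values and the `FLOPs` line in the convolution output above follow from the input arguments. The sketch below uses the standard output-size formula for a strided, dilated convolution. The FLOP formula is inferred from the printed number rather than taken from the profiler source: it counts `2*r*s*c` operations per output element plus 2 per output element for the epilogue. The function names are hypothetical.

```cpp
#include <cstdint>

// Output extent along one spatial dimension:
// floor((in + 2*pad - dilation*(filter - 1) - 1) / stride) + 1.
constexpr int conv_output_extent(int in, int filter, int pad, int stride, int dilation) {
  return (in + 2 * pad - dilation * (filter - 1) - 1) / stride + 1;
}

// FLOP count that agrees with the sample output (an inferred formula):
// n*p*q*k output elements, each needing r*s*c multiply-accumulates,
// plus 2 operations per output element for the alpha/beta epilogue.
constexpr std::int64_t conv2d_fprop_flops(std::int64_t n, std::int64_t p, std::int64_t q,
                                          std::int64_t k, std::int64_t r, std::int64_t s,
                                          std::int64_t c) {
  return 2 * n * p * q * k * (r * s * c) + 2 * n * p * q * k;
}

// h = w = 224, r = s = 3, pad = 1, stride = 1, dilation = 1 gives p = q = 224.
static_assert(conv_output_extent(224, 3, 1, 1, 1) == 224, "matches --p and --q");
static_assert(conv2d_fprop_flops(8, 224, 224, 128, 3, 3, 128) == 118482796544LL,
              "matches the FLOPs line of the sample output");
```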

See [documentation for the CUTLASS Profiler](profiler.md) for more details.

## Build and run CUTLASS Unit Tests

From the `build/` directory created above, build the target `test_unit` to compile and run
all unit tests.

```bash
$ make test_unit -j
...
...
...
[----------] Global test environment tear-down
[==========] 946 tests from 57 test cases ran. (10812 ms total)
[  PASSED  ] 946 tests.
$
```
The exact number of tests run is subject to change as we add more functionality.

No tests should fail. Unit tests automatically construct the appropriate runtime filters
to avoid executing on architectures that do not support all features under test.

The unit tests are arranged hierarchically, mirroring the CUTLASS Template Library. This enables
parallelism in building and running tests, and it reduces compilation time when only a specific
set of tests is desired.

For example, the following builds and runs only the warp-level GEMM tests.
```bash
$ make test_unit_gemm_warp -j
...
...
[----------] 3 tests from SM75_warp_gemm_tensor_op_congruous_f16
[ RUN      ] SM75_warp_gemm_tensor_op_congruous_f16.128x128x8_32x128x8_16x8x8
[       OK ] SM75_warp_gemm_tensor_op_congruous_f16.128x128x8_32x128x8_16x8x8 (0 ms)
[ RUN      ] SM75_warp_gemm_tensor_op_congruous_f16.128x128x32_64x64x32_16x8x8
[       OK ] SM75_warp_gemm_tensor_op_congruous_f16.128x128x32_64x64x32_16x8x8 (2 ms)
[ RUN      ] SM75_warp_gemm_tensor_op_congruous_f16.128x128x32_32x32x32_16x8x8
[       OK ] SM75_warp_gemm_tensor_op_congruous_f16.128x128x32_32x32x32_16x8x8 (1 ms)
[----------] 3 tests from SM75_warp_gemm_tensor_op_congruous_f16 (3 ms total)
...
...
[----------] Global test environment tear-down
[==========] 104 tests from 32 test cases ran. (294 ms total)
[  PASSED  ] 104 tests.
[100%] Built target test_unit_gemm_warp
```

## Building for Multiple Architectures

To minimize compilation time, specific GPU architectures can be enabled via the CMake command,
selected by [CUDA Compute Capability](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities).

**NVIDIA Hopper Architecture.**
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=90a              # compiles for NVIDIA Hopper GPU architecture
```

**NVIDIA Ampere Architecture.**
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=80               # compiles for NVIDIA Ampere GPU architecture
```

**NVIDIA Turing Architecture.**
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=75               # compiles for NVIDIA Turing GPU architecture
```

**NVIDIA Volta Architecture.**
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=70               # compiles for NVIDIA Volta GPU architecture
```

**NVIDIA Pascal Architecture.**
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS="60;61"          # compiles for NVIDIA Pascal GPU architecture
```

**NVIDIA Maxwell Architecture.**
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS="50;53"          # compiles for NVIDIA Maxwell GPU architecture
```

## Using CUTLASS within other applications

Applications should list [`/include`](/include) within their include paths. They must be
compiled as C++17 or greater.

**Example:** print the contents of a variable storing half-precision data.
```c++
#include <iostream>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/core_io.h>

int main() {

  cutlass::half_t x = 2.25_hf;

  std::cout << x << std::endl;

  return 0;
}
```

## Launching a GEMM kernel in CUDA

**Example:** launch a mixed-precision GEMM targeting Turing Tensor Cores.

_Note, this example uses CUTLASS Utilities. Be sure `tools/util/include` is listed as an include path._
```c++
#include <cutlass/numeric_types.h>
#include <cutlass/gemm/device/gemm.h>

#include <cutlass/util/host_tensor.h>

int main() {

  // Define the GEMM operation
  using Gemm = cutlass::gemm::device::Gemm<
    cutlass::half_t,                           // ElementA
    cutlass::layout::ColumnMajor,              // LayoutA
    cutlass::half_t,                           // ElementB
    cutlass::layout::ColumnMajor,              // LayoutB
    cutlass::half_t,                           // ElementOutput
    cutlass::layout::ColumnMajor,              // LayoutOutput
    float,                                     // ElementAccumulator
    cutlass::arch::OpClassTensorOp,            // tag indicating Tensor Cores
    cutlass::arch::Sm75                        // tag indicating target GPU compute architecture
  >;

  Gemm gemm_op;
  cutlass::Status status;

  //
  // Define the problem size
  //
  int M = 512;
  int N = 256;
  int K = 128;

  float alpha = 1.25f;
  float beta = -1.25f;

  //
  // Allocate device memory
  //

  cutlass::HostTensor<cutlass::half_t, cutlass::layout::ColumnMajor> A({M, K});
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::ColumnMajor> B({K, N});
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::ColumnMajor> C({M, N});

  cutlass::half_t const *ptrA = A.device_data();
  cutlass::half_t const *ptrB = B.device_data();
  cutlass::half_t const *ptrC = C.device_data();
  cutlass::half_t       *ptrD = C.device_data();

  int lda = A.device_ref().stride(0);
  int ldb = B.device_ref().stride(0);
  int ldc = C.device_ref().stride(0);
  int ldd = C.device_ref().stride(0);

  //
  // Launch GEMM on the device
  //

  status = gemm_op({
    {M, N, K},
    {ptrA, lda},            // TensorRef to A device tensor
    {ptrB, ldb},            // TensorRef to B device tensor
    {ptrC, ldc},            // TensorRef to C device tensor
    {ptrD, ldd},            // TensorRef to D device tensor - may be the same as C
    {alpha, beta}           // epilogue operation arguments
  });

  if (status != cutlass::Status::kSuccess) {
    return -1;
  }

  return 0;
}
```

Note, the above could be simplified as follows using helper methods defined in `HostTensor`.
```c++
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::ColumnMajor> A({M, K});
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::ColumnMajor> B({K, N});
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::ColumnMajor> C({M, N});

  //
  // Use the TensorRef returned by HostTensor::device_ref().
  //

  status = gemm_op({
    {M, N, K},
    A.device_ref(),            // TensorRef to A device tensor
    B.device_ref(),            // TensorRef to B device tensor
    C.device_ref(),            // TensorRef to C device tensor
    C.device_ref(),            // TensorRef to D device tensor - may be the same as C
    {alpha, beta}              // epilogue operation arguments
  });
```

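For readers new to the `{pointer, leading dimension}` arguments and the `{alpha, beta}` epilogue, the following host-only sketch spells out the computation assumed here, `D = alpha * A * B + beta * C`, for column-major operands. It does not use CUTLASS and is not how the device kernel is implemented; `col_major` and `reference_gemm` are hypothetical names, and plain `float` stands in for `cutlass::half_t`.

```cpp
#include <cstddef>
#include <vector>

// Column-major indexing: element (row, col) is stored at row + col * ld,
// where ld is the leading dimension (the stride between columns).
inline std::size_t col_major(int row, int col, int ld) {
  return static_cast<std::size_t>(row) + static_cast<std::size_t>(col) * ld;
}

// Naive host reference for D = alpha * A * B + beta * C.
// A is M-by-K, B is K-by-N, C and D are M-by-N, all column-major.
inline void reference_gemm(int M, int N, int K, float alpha,
                           const std::vector<float>& A, int lda,
                           const std::vector<float>& B, int ldb,
                           float beta,
                           const std::vector<float>& C, int ldc,
                           std::vector<float>& D, int ldd) {
  for (int col = 0; col < N; ++col) {
    for (int row = 0; row < M; ++row) {
      float accum = 0.0f;  // plays the role of ElementAccumulator
      for (int k = 0; k < K; ++k) {
        accum += A[col_major(row, k, lda)] * B[col_major(k, col, ldb)];
      }
      D[col_major(row, col, ldd)] = alpha * accum + beta * C[col_major(row, col, ldc)];
    }
  }
}
```

For a packed column-major `M`-by-`K` matrix the leading dimension equals `M`, which is the value the `stride(0)` calls in the example are expected to return.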
## Launching a GEMM kernel using CUTLASS 3.0 or newer

**Example:** launch a mixed-precision GEMM targeting Hopper Tensor Cores.

```c++
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"

using namespace cute;

int main(int argc, char const **args) {

  // A matrix configuration
  using         ElementA    = cutlass::half_t;                                // Element type for A matrix operand
  using         LayoutA     = cutlass::layout::RowMajor;                      // Layout type for A matrix operand
  constexpr int AlignmentA  = 128 / cutlass::sizeof_bits<ElementA>::value;    // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using         ElementB    = cutlass::half_t;                                // Element type for B matrix operand
  using         LayoutB     = cutlass::layout::ColumnMajor;                   // Layout type for B matrix operand
  constexpr int AlignmentB  = 128 / cutlass::sizeof_bits<ElementB>::value;    // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using         ElementC    = cutlass::half_t;                                // Element type for C and D matrix operands
  using         LayoutC     = cutlass::layout::ColumnMajor;                   // Layout type for C and D matrix operands

  // Core kernel configurations
  using ElementAccumulator  = float;                                          // Element type for internal accumulation
  using ArchTag             = cutlass::arch::Sm90;                            // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass       = cutlass::arch::OpClassTensorOp;                 // Operator class tag
  using TilesShape          = Shape<_128,_128,_64>;                           // Threadblock-level tile size
  using ClusterShape        = Shape<_1,_2,_1>;                                // Shape of the threadblocks in a cluster
  using StageCountType = cutlass::gemm::collective::StageCountAuto;           // Stage count maximized based on the tile size
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;       // Kernel to launch based on the default setting in the Collective Builder

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      ElementA, LayoutA, AlignmentA,
      ElementB, LayoutB, AlignmentB,
      ElementAccumulator,
      TilesShape, ClusterShape,
      StageCountType,
      KernelSchedule
    >::CollectiveOp;

  using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
      cutlass::gemm::TagToStrideC_t<LayoutC>,
      cutlass::gemm::TagToStrideC_t<LayoutC>,
      cutlass::epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>>;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int>, // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue
  >;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  Gemm gemm_op;
  cutlass::Status status;

  //
  // Define the problem size
  //

  int M = 512;
  int N = 256;
  int K = 128;

  float alpha = 1.25f;
  float beta = -1.25f;

  //
  // Allocate device memory
  //

  cutlass::DeviceAllocation<typename Gemm::ElementA> block_A;
  cutlass::DeviceAllocation<typename Gemm::ElementB> block_B;
  cutlass::DeviceAllocation<typename Gemm::ElementC> block_C;
  cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput> block_D;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  StrideA stride_A;
  StrideB stride_B;
  StrideC stride_C;
  StrideD stride_D;

  stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
  stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
  stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1});
  stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});

  block_A.reset(M * K);
  block_B.reset(K * N);
  block_C.reset(M * N);
  block_D.reset(M * N);

  //
  // Launch GEMM on the device
  //

  status = gemm_op({
    cutlass::gemm::GemmUniversalMode::kGemm,
    {M, N, K},
    block_A.get(),
    stride_A,
    block_B.get(),
    stride_B,
    {block_C.get(), stride_C, block_D.get(), stride_D, {alpha, beta}}
  });

  if (status != cutlass::Status::kSuccess) {
    return -1;
  }

  return 0;
}
```

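The `AlignmentA` and `AlignmentB` constants in the example are expressed in elements, not bytes: a 128-bit (16-byte) access holds `128 / bits-per-element` elements. The standalone sketch below reproduces that arithmetic without CUTLASS. The `is_aligned` helper is a hypothetical illustration of the usual requirement that the contiguous extent be a multiple of the alignment; it is an assumption here, not a statement of what CUTLASS checks.

```cpp
// Number of elements that fit in one memory access of access_bits bits.
constexpr int alignment_in_elements(int element_bits, int access_bits = 128) {
  return access_bits / element_bits;
}

// Hypothetical check: the contiguous extent must be a multiple of the alignment.
constexpr bool is_aligned(int contiguous_extent, int alignment) {
  return contiguous_extent % alignment == 0;
}

static_assert(alignment_in_elements(16) == 8, "16-bit half: 8 elements per 128-bit access");
static_assert(alignment_in_elements(32) == 4, "32-bit float: 4 elements per 128-bit access");
static_assert(alignment_in_elements(8) == 16, "8-bit integer: 16 elements per 128-bit access");
static_assert(is_aligned(128, alignment_in_elements(16)), "K = 128 is a multiple of 8");
```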
# CUTLASS Library

The [CUTLASS Library](/tools/library) defines an API for managing and executing collections of compiled
kernel instances and launching them from host code without template instantiations in client code.

The host-side launch API is designed to be analogous to BLAS implementations for convenience, though its
kernel selection procedure is intended only to be functionally sufficient. It may not launch the
optimal tile size for a given problem. It chooses the first available kernel whose data types,
layouts, and alignment constraints satisfy the given problem. Kernel instances and a data structure
describing them are completely available to client applications which may choose to implement their
own selection logic.
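The "first available kernel" policy described above can be pictured with the minimal sketch below. Every type and function in it (`KernelDescription`, `Problem`, `satisfies`, `select_first_match`) is hypothetical and invented for illustration; none is part of the CUTLASS Library API, and the alignment rule is an assumption.

```cpp
#include <cstddef>
#include <string>
#include <vector>

enum class NumericType { kF16, kF32 };
enum class Layout { kColumnMajor, kRowMajor };

// Hypothetical description of one compiled kernel instance.
struct KernelDescription {
  std::string name;
  NumericType element;
  Layout layout_a;
  Layout layout_b;
  int alignment;  // alignment the kernel requires, in elements
};

// Hypothetical description of the problem the caller wants to run.
struct Problem {
  NumericType element;
  Layout layout_a;
  Layout layout_b;
  int alignment;  // alignment the caller's operands guarantee, in elements
};

// A kernel qualifies when types and layouts match and it does not demand
// more alignment than the problem provides (an assumed rule).
inline bool satisfies(const KernelDescription& kernel, const Problem& problem) {
  return kernel.element == problem.element &&
         kernel.layout_a == problem.layout_a &&
         kernel.layout_b == problem.layout_b &&
         kernel.alignment <= problem.alignment;
}

// Returns the index of the first qualifying kernel, or -1 if none qualifies.
inline int select_first_match(const std::vector<KernelDescription>& kernels,
                              const Problem& problem) {
  for (std::size_t i = 0; i < kernels.size(); ++i) {
    if (satisfies(kernels[i], problem)) {
      return static_cast<int>(i);
    }
  }
  return -1;
}
```

A client with its own selection logic would replace the linear scan with a ranking, for example preferring the tile size that performed best in earlier profiling runs.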

[cuBLAS](https://developer.nvidia.com/cublas) offers the best performance and functional coverage
for dense matrix computations on NVIDIA GPUs.

The CUTLASS Library is used by the CUTLASS Profiler to manage kernel instances, and it is also used
by several SDK examples.

* [10_planar_complex](/examples/10_planar_complex/planar_complex.cu)
* [11_planar_complex_array](/examples/11_planar_complex_array/planar_complex_array.cu)

The CUTLASS Library defines enumerated types describing numeric data types, matrix and tensor
layouts, math operation classes, complex transformations, and more.

Client applications should specify [`tools/library/include`](/tools/library/include) in their
include paths and link against the CUTLASS Library (the `cutlass_lib` CMake target).

The CUTLASS SDK example [10_planar_complex](/examples/10_planar_complex/CMakeLists.txt) specifies
its dependency on the CUTLASS Library with the following CMake command.
```cmake
target_link_libraries(
  10_planar_complex
  PRIVATE
  cutlass_lib
  cutlass_tools_util_includes
)
```

A sample kernel launch from host-side C++ is shown as follows.

```c++
#include "cutlass/library/library.h"
#include "cutlass/library/handle.h"

#include "cutlass/util/host_tensor.h"   // HostTensor; requires tools/util/include on the include path

int main() {

  //
  // Define the problem size
  //
  int M = 512;
  int N = 256;
  int K = 128;

  float alpha = 1.25f;
  float beta = -1.25f;

  //
  // Allocate device memory
  //

  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> A({M, K});
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> B({K, N});
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> C({M, N});

  float const *ptrA = A.device_data();
  float const *ptrB = B.device_data();
  float const *ptrC = C.device_data();
  float       *ptrD = C.device_data();      // D aliases C in this example

  int lda = A.device_ref().stride(0);
  int ldb = B.device_ref().stride(0);
  int ldc = C.device_ref().stride(0);
  int ldd = C.device_ref().stride(0);       // D shares C's storage, so it shares C's stride

  //
  // CUTLASS Library call to execute device GEMM
  //

  cutlass::library::Handle handle;

  //
  // Launch GEMM on CUDA device.
  //

  cutlass::Status status = handle.gemm(
    M,
    N,
    K,

    cutlass::library::NumericTypeID::kF32,          // data type of internal accumulation
    cutlass::library::NumericTypeID::kF32,          // data type of alpha/beta scalars

    &alpha,                                         // pointer to alpha scalar

    cutlass::library::NumericTypeID::kF32,          // data type of A matrix
    cutlass::library::LayoutTypeID::kColumnMajor,   // layout of A matrix
    ptrA,                                           // pointer to A matrix in device memory
    lda,                                            // leading dimension of A matrix

    cutlass::library::NumericTypeID::kF32,          // data type of B matrix
    cutlass::library::LayoutTypeID::kColumnMajor,   // layout of B matrix
    ptrB,                                           // pointer to B matrix in device memory
    ldb,                                            // leading dimension of B matrix

    &beta,                                          // pointer to beta scalar

    cutlass::library::NumericTypeID::kF32,          // data type of C and D matrix

    ptrC,                                           // pointer to C matrix in device memory
    ldc,                                            // leading dimension of C matrix

    ptrD,                                           // pointer to D matrix in device memory
    ldd                                             // leading dimension of D matrix
  );

  if (status != cutlass::Status::kSuccess) {
    return -1;
  }

  return 0;
}
```

# Example CMake Commands

To instantiate all operations supporting all tile sizes, data types, and alignment constraints, specify
`-DCUTLASS_LIBRARY_KERNELS=all` when running `cmake`.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='70;75;80' -DCUTLASS_LIBRARY_KERNELS=all
```
The above command line generates about twenty thousand kernels targeting NVIDIA Ampere, Turing, and Volta architectures.
Compiling thousands of kernels for three different architectures is time-consuming. It also produces
a large binary and, on some platforms, can cause the linker to fail when building the library.

Enabling the "unity build" instantiates multiple kernel instances in each compilation unit, thereby reducing binary size
and avoiding linker limitations on some platforms.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=all -DCUTLASS_UNITY_BUILD_ENABLED=ON
```

We advise compiling CUTLASS kernels only for the NVIDIA architectures you plan to run on. Furthermore, kernels
can be selectively included in the CUTLASS Library by specifying filter strings and wildcard characters when executing CMake.

Several examples are defined below for convenience. They may be combined as a comma-delimited list.
Compiling only the kernels desired reduces compilation time.

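To make the filter strings in the examples easier to read, here is one plausible interpretation sketched in plain C++. It is an assumption, not the CUTLASS generator's actual code: a filter matches a kernel name when the pieces between `*` characters appear in the name in order, and a comma-delimited list matches when any one of its filters matches. The function names `matches_filter` and `matches_any` are hypothetical.

```cpp
#include <cstddef>
#include <string>

// Returns true when the pieces of `filter` separated by '*' all occur in
// `name`, in order and without overlap (assumed semantics, see lead-in).
inline bool matches_filter(const std::string& filter, const std::string& name) {
  std::size_t name_pos = 0;  // where the next piece may start in `name`
  std::size_t start = 0;     // start of the current piece in `filter`
  while (true) {
    const std::size_t star = filter.find('*', start);
    const std::string piece =
        filter.substr(start, star == std::string::npos ? std::string::npos : star - start);
    const std::size_t found = name.find(piece, name_pos);
    if (found == std::string::npos) {
      return false;
    }
    name_pos = found + piece.size();
    if (star == std::string::npos) {
      return true;
    }
    start = star + 1;
  }
}

// Returns true when any filter in a comma-delimited list matches `name`.
inline bool matches_any(const std::string& filters, const std::string& name) {
  std::size_t start = 0;
  while (true) {
    const std::size_t comma = filters.find(',', start);
    const std::string filter =
        filters.substr(start, comma == std::string::npos ? std::string::npos : comma - start);
    if (matches_filter(filter, name)) {
      return true;
    }
    if (comma == std::string::npos) {
      return false;
    }
    start = comma + 1;
  }
}
```

Under this reading, `gemm*nn,gemm*tt` selects `cutlass_simt_sgemm_128x128_nn` (the kernel name from the profiler output earlier), while `tensorop*gemm` does not.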
## GEMM CMake Examples
**Example.** All GEMM kernels targeting NVIDIA Ampere Tensor Cores.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=80 -DCUTLASS_LIBRARY_KERNELS=tensorop*gemm
```

**Example.** All GEMM kernels targeting NVIDIA Turing Tensor Cores.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=tensorop*gemm
```

**Example.** All GEMM kernels with FP32 accumulation targeting NVIDIA Ampere, Turing, and Volta architectures.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=s*gemm
```

**Example.** All kernels which expect A and B to be column-major or row-major targeting NVIDIA Ampere, Turing, and Volta architectures.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=gemm*nn,gemm*tt
```

**Example.** All planar complex GEMM variants targeting NVIDIA Ampere, Turing, and Volta architectures.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=planar_complex
```

## Convolution CMake Examples
**Example.** All convolution kernels targeting NVIDIA Ampere's 16816 Tensor Core operation.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='80' -DCUTLASS_LIBRARY_KERNELS=s16816fprop,s16816dgrad,s16816wgrad
```

**Example.** All forward propagation (fprop) convolution kernels targeting CUDA Cores for multiple NVIDIA architectures.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='50;60;61;70;75;80' -DCUTLASS_LIBRARY_KERNELS=sfprop
```

**Example.** All forward propagation (fprop) convolution kernels with FP32 accumulation and FP16 input targeting NVIDIA Ampere's 16816 Tensor Core operation.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='80' -DCUTLASS_LIBRARY_KERNELS=s16816fprop_*_f16
```

**Example.** All backward weight gradient (wgrad) convolution kernels with FP32 accumulation, FP16 input, and optimized global memory iterator
targeting NVIDIA Ampere, Turing, and Volta Tensor Core operations.
```bash
$ cmake .. -DCUTLASS_NVCC_ARCHS='70;75;80' -DCUTLASS_LIBRARY_KERNELS=tensorop*s*wgrad_optimized_f16
```

# Copyright

Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause

```
  Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions are met:

  1. Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

  3. Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```