# CUTLASS 2.0

_CUTLASS 2.0 - November 2019_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
It incorporates strategies for hierarchical decomposition and data movement similar
to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
reusable, modular software components abstracted by C++ template classes. These
thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
and tuned via custom tiling sizes, data types, and other algorithmic policy. The
resulting flexibility simplifies their use as building blocks within custom kernels
and applications.
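
The hierarchical decomposition described above can be illustrated on the host with a minimal sketch. This is not CUTLASS code: `TileShape` and `tiled_gemm` are hypothetical names introduced here, the matrices are assumed to be densely packed in column-major order, and the loops run sequentially on the CPU. The sketch only shows how a GEMM is split into output tiles that each step through the K dimension in chunks, with the tile shape exposed as a tunable policy.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical tile shape, for illustration only. CUTLASS expresses
// tile shapes as compile-time template parameters instead.
struct TileShape {
  int m;
  int n;
  int k;
};

// Computes C = A * B on the host by walking tiles of the output.
// All matrices are column-major and densely packed:
//   A is M x K (leading dimension M), B is K x N (leading dimension K),
//   C is M x N (leading dimension M) and must already hold M * N elements.
void tiled_gemm(int M, int N, int K,
                std::vector<float> const &A,
                std::vector<float> const &B,
                std::vector<float> &C,
                TileShape tile) {
  std::fill(C.begin(), C.end(), 0.0f);

  // Outer loops: one iteration per output tile (the unit of work a
  // threadblock would own on a GPU).
  for (int m0 = 0; m0 < M; m0 += tile.m) {
    for (int n0 = 0; n0 < N; n0 += tile.n) {
      // Mainloop: step through the K dimension one tile at a time.
      for (int k0 = 0; k0 < K; k0 += tile.k) {
        int m_end = std::min(m0 + tile.m, M);
        int n_end = std::min(n0 + tile.n, N);
        int k_end = std::min(k0 + tile.k, K);

        // Inner loops: multiply-accumulate within the current tile.
        for (int j = n0; j < n_end; ++j) {
          for (int i = m0; i < m_end; ++i) {
            float acc = C[i + j * M];
            for (int k = k0; k < k_end; ++k) {
              acc += A[i + k * M] * B[k + j * K];
            }
            C[i + j * M] = acc;
          }
        }
      }
    }
  }
}
```

Changing `TileShape` alters only how the work is partitioned, not the result, which is the property that lets tile sizes be tuned per problem and per architecture.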

To support a wide variety of applications, CUTLASS provides extensive support for
mixed-precision computations, with specialized data-movement and
multiply-accumulate abstractions for 8-bit integer, half-precision floating
point (FP16), single-precision floating point (FP32), and double-precision floating
point (FP64) types. Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply
operations for targeting the programmable, high-throughput _Tensor Cores_ implemented
by NVIDIA's Volta and Turing architectures.
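
As a rough illustration of what "mixed precision" means here, the host-side sketch below multiplies elements of one type and accumulates in a wider type. `mixed_dot` is a hypothetical helper written for this page, not a CUTLASS function; it only mirrors the idea of choosing the element types and the accumulator type independently.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, for illustration only: a dot product whose inputs
// and accumulator use different types. The accumulator type is listed
// first so that callers name it explicitly while the input element types
// are deduced from the arguments.
template <typename ElementAccumulator, typename ElementA, typename ElementB>
ElementAccumulator mixed_dot(std::vector<ElementA> const &a,
                             std::vector<ElementB> const &b) {
  ElementAccumulator acc = ElementAccumulator(0);
  std::size_t n = std::min(a.size(), b.size());
  for (std::size_t i = 0; i < n; ++i) {
    // Widen each operand before multiplying so that neither the product
    // nor the running sum is limited to the narrow input type.
    acc += static_cast<ElementAccumulator>(a[i]) *
           static_cast<ElementAccumulator>(b[i]);
  }
  return acc;
}
```

For example, `mixed_dot<std::int32_t>(a, b)` with two `std::vector<std::int8_t>` inputs accumulates 8-bit integer products in 32 bits, so a sum that exceeds the 8-bit range is still represented exactly.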

# What's New in CUTLASS 2.0

CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer:

- Better performance over 1.x, particularly for kernels targeting Turing Tensor Cores
- Robust and durable templates that reliably span the design space
- Encapsulated functionality that may be reusable in other contexts

# Example CUTLASS GEMM

The following function defines and launches a CUTLASS GEMM kernel
with single-precision inputs and outputs. This is an excerpt from the CUTLASS SDK
[basic_gemm example](https://github.com/NVIDIA/cutlass/tree/master/examples/00_basic_gemm/basic_gemm.cu).

~~~~~~~~~~~~~~~~~~~~~{.cpp}
//
// CUTLASS includes needed for single-precision GEMM kernel
//

// Declares cudaError_t, cudaSuccess, and cudaErrorUnknown.
#include <cuda_runtime.h>

// Defines cutlass::gemm::device::Gemm, the generic Gemm computation template class.
#include <cutlass/gemm/device/gemm.h>

/// Define a CUTLASS GEMM template and launch a GEMM kernel.
cudaError_t cutlass_sgemm_nn(
  int M,
  int N,
  int K,
  float alpha,
  float const *A,
  int lda,
  float const *B,
  int ldb,
  float beta,
  float *C,
  int ldc) {

  // Define the type of a single-precision CUTLASS GEMM with column-major
  // input matrices and 128x128x8 threadblock tile size (chosen by default).
  //
  // To keep the interface manageable, several helpers are defined for plausible compositions
  // including the following example for single-precision GEMM. Typical values are used as
  // default template arguments. See `cutlass/gemm/device/default_gemm_configuration.h` for more details.
  //
  // To view the full gemm device API interface, see `cutlass/gemm/device/gemm.h`

  using ColumnMajor = cutlass::layout::ColumnMajor;

  using CutlassGemm = cutlass::gemm::device::Gemm<float,        // Data-type of A matrix
                                                  ColumnMajor,  // Layout of A matrix
                                                  float,        // Data-type of B matrix
                                                  ColumnMajor,  // Layout of B matrix
                                                  float,        // Data-type of C matrix
                                                  ColumnMajor>; // Layout of C matrix

  // Instantiate the CUTLASS GEMM operator
  CutlassGemm gemm_operator;

  // Construct the CUTLASS GEMM arguments object.
  //
  // One of CUTLASS's design patterns is to define gemm argument objects that are constructible
  // in host code and passed to kernels by value. These may include pointers, strides, scalars,
  // and other arguments needed by Gemm and its components.
  //
  // The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible
  // arguments to kernels and (2.) minimized initialization overhead on kernel entry.
  //
  CutlassGemm::Arguments args({M, N, K},       // Gemm problem dimensions
                              {A, lda},        // Tensor-ref for source matrix A
                              {B, ldb},        // Tensor-ref for source matrix B
                              {C, ldc},        // Tensor-ref for source matrix C
                              {C, ldc},        // Tensor-ref for destination matrix D (may be different memory than source C matrix)
                              {alpha, beta});  // Scalars used in the Epilogue

  //
  // Launch the CUTLASS GEMM kernel.
  //
  cutlass::Status status = gemm_operator(args);

  //
  // Return a cudaError_t if the CUTLASS GEMM operator returned an error code.
  //
  if (status != cutlass::Status::kSuccess) {
    return cudaErrorUnknown;
  }

  // Return success, if no errors were encountered.
  return cudaSuccess;
}
~~~~~~~~~~~~~~~~~~~~~
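
When experimenting with a kernel like the one above, it helps to have a plain host-side reference to compare against. The sketch below is an assumption-laden illustration, not part of the excerpt: `reference_sgemm_nn` is a hypothetical name, it takes host pointers rather than device pointers, and it assumes the conventional GEMM semantics `C = alpha * A * B + beta * C` with column-major storage and leading dimensions `lda`, `ldb`, and `ldc`.

```cpp
// Hypothetical host-side reference, for illustration only.
// Computes C = alpha * A * B + beta * C in place, where
//   A is M x K, B is K x N, C is M x N, all column-major, and
//   element (row, col) of a matrix X lives at X[row + col * ldx].
void reference_sgemm_nn(int M, int N, int K,
                        float alpha,
                        float const *A, int lda,
                        float const *B, int ldb,
                        float beta,
                        float *C, int ldc) {
  for (int j = 0; j < N; ++j) {
    for (int i = 0; i < M; ++i) {
      // Inner product of row i of A with column j of B.
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[i + k * lda] * B[k + j * ldb];
      }
      // Scale the product and blend it with the existing value of C.
      C[i + j * ldc] = alpha * acc + beta * C[i + j * ldc];
    }
  }
}
```

One possible check is to copy the device result back to the host and compare it element by element with the output of this function, using a tolerance when the inputs are not exactly representable, since the two implementations may sum the products in a different order.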

# Copyright

Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.

```
  Redistribution and use in source and binary forms, with or without modification, are permitted
  provided that the following conditions are met:
      * Redistributions of source code must retain the above copyright notice, this list of
        conditions and the following disclaimer.
      * Redistributions in binary form must reproduce the above copyright notice, this list of
        conditions and the following disclaimer in the documentation and/or other materials
        provided with the distribution.
      * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
        to endorse or promote products derived from this software without specific prior written
        permission.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
  IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
  FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
  BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
  OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
  STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```