
# Efficient GEMM in CUDA
CUTLASS implements the hierarchically blocked structure described in
[CUTLASS: Fast Linear Algebra in CUDA C++](https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/)
and the [CUTLASS GTC2018 talk](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
## Hierarchical Structure
The basic triple loop nest computing matrix multiply may be blocked and tiled to match
concurrency in hardware, memory locality, and parallel programming models. In CUTLASS,
GEMM is mapped to NVIDIA GPUs with the structure illustrated by the following loop nest.
```c++
for (int cta_n = 0; cta_n < GemmN; cta_n += CtaTileN) {                    // for each threadblock_y         } threadblock-level concurrency
  for (int cta_m = 0; cta_m < GemmM; cta_m += CtaTileM) {                  //    for each threadblock_x      }

    for (int cta_k = 0; cta_k < GemmK; cta_k += CtaTileK) {                //       "GEMM mainloop" - no unrolling
                                                                           //       - one iteration of this loop is one "stage"
      //
      for (int warp_n = 0; warp_n < CtaTileN; warp_n += WarpTileN) {       // for each warp_y                } warp-level parallelism
        for (int warp_m = 0; warp_m < CtaTileM; warp_m += WarpTileM) {     //    for each warp_x             }
          //
          for (int warp_k = 0; warp_k < CtaTileK; warp_k += WarpTileK) {   //       fully unroll across CtaTileK
                                                                           //       - one iteration of this loop is one "k Group"
            //
            for (int mma_k = 0; mma_k < WarpTileK; mma_k += MmaK) {        // for each mma instruction       } instruction-level parallelism
              for (int mma_n = 0; mma_n < WarpTileN; mma_n += MmaN) {      //    for each mma instruction    }
                for (int mma_m = 0; mma_m < WarpTileM; mma_m += MmaM) {    //       for each mma instruction }
                  //
                  mma_instruction(d, a, b, c);                             //       TensorCore matrix computation

                }   // for mma_m
              }   // for mma_n
            }   // for mma_k

          }   // for warp_k
        }   // for warp_m
      }   // for warp_n

    }   // for cta_k
  }   // for cta_m
}   // for cta_n
```
This tiled loop nest targets concurrency among

- threadblocks
- warps
- CUDA and Tensor Cores

and takes advantage of memory locality within

- shared memory
- registers
Data flows through this structure from global memory into shared memory, from shared memory into registers, and from registers into the CUDA and Tensor Cores, with results written back out in reverse.
This is the hierarchical GEMM computation embodied by CUTLASS: each nested level of tiling corresponds to a layer of concurrency within the CUDA execution model and to a level within the memory hierarchy, becoming progressively finer moving from threadblocks down to individual instructions.

### Threadblock-level GEMM
Each threadblock computes its portion of the output GEMM by iteratively loading tiles of the input
matrices and computing an accumulated matrix product. At the threadblock level, data is loaded from
global memory, and the choice of blocking strategy is key to achieving efficiency. The programmer
must strike a compromise among conflicting goals: a larger threadblock tile means fewer fetches from
global memory per unit of computation, helping to ensure that DRAM bandwidth does not become a
bottleneck.
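To make this concrete, the following back-of-the-envelope sketch compares the compute-to-load ratio of two tile shapes; the shapes are illustrative examples, not CUTLASS defaults.
```c++
#include <cstdio>

// Estimate the compute-to-load ratio of one threadblock tile per mainloop
// iteration: each iteration loads an (M x K) tile of A and a (K x N) tile of
// B from global memory and performs M * N * K multiply-accumulates.
double flops_per_element_loaded(int m, int n, int k) {
  double flops  = 2.0 * m * n * k;                // one multiply + one add per MAC
  double loaded = double(m) * k + double(k) * n;  // elements of A and B fetched
  return flops / loaded;
}

int main() {
  // Illustrative tile shapes: doubling M and N doubles the arithmetic
  // intensity, easing pressure on DRAM bandwidth.
  std::printf("64x64x32 tile:   %.1f flops/element\n",
              flops_per_element_loaded(64, 64, 32));    // prints 128.0
  std::printf("128x128x32 tile: %.1f flops/element\n",
              flops_per_element_loaded(128, 128, 32));  // prints 256.0
  return 0;
}
```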
However, large threadblock tiles may not match the dimensions of the problem well. If either the
GEMM _M_ or _N_ dimension is small, some threads within the threadblock may not perform meaningful
work because the threadblock falls partially outside the bounds of the problem. If both _M_ and _N_
are small while _K_ is large, this scheme may launch relatively few threadblocks and fail to
fully utilize all multiprocessors within the GPU. Strategies to optimize performance for this case
are described in the section [Parallelized Reductions](efficient_gemm.md#parallelized-reductions);
they partition the GEMM _K_ dimension across multiple threadblocks or multiple warps, which compute
partial matrix products in parallel that are then reduced to produce the result.
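As a sketch of the split-K idea (the names here are illustrative, not the CUTLASS API), each participating threadblock can be assigned a contiguous slice of the _K_ dimension, with the partial products reduced afterwards:
```c++
#include <cstdio>

// Illustrative split-K partitioning (not the CUTLASS API): each of
// `split_k_slices` threadblocks accumulates a partial product over its own
// contiguous slice of the K dimension; the partials are reduced afterwards.
struct KRange { int begin; int end; };

KRange k_slice_for(int slice_idx, int split_k_slices, int GemmK) {
  // Divide K as evenly as possible; earlier slices absorb the remainder.
  int base  = GemmK / split_k_slices;
  int rem   = GemmK % split_k_slices;
  int begin = slice_idx * base + (slice_idx < rem ? slice_idx : rem);
  int size  = base + (slice_idx < rem ? 1 : 0);
  return {begin, begin + size};
}

int main() {
  // Example: K = 4096 split across 4 slices -> [0,1024), [1024,2048), ...
  for (int s = 0; s < 4; ++s) {
    KRange r = k_slice_for(s, 4, 4096);
    std::printf("slice %d: K range [%d, %d)\n", s, r.begin, r.end);
  }
  return 0;
}
```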
In CUTLASS, the dimensions of the threadblock tile are specified as `ThreadblockShape::{kM, kN, kK}`
and may be tuned to specialize the GEMM computation for the target processor and dimensions of
the GEMM problem.
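For example, a device-level GEMM can be specialized with explicit tile shapes roughly as follows. This is a minimal sketch in the style of the CUTLASS 2.x device API; the element types, architecture tag, and tile sizes are illustrative choices rather than recommended settings, and available defaults vary across CUTLASS versions.
```c++
#include "cutlass/gemm/device/gemm.h"

// Illustrative tiling: a 128x128x32 threadblock tile, partitioned into
// 64x64x32 warp tiles, each computed with 16x8x8 Tensor Core instructions.
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape        = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

using Gemm = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::ColumnMajor,  // element and layout of A
    cutlass::half_t, cutlass::layout::ColumnMajor,  // element and layout of B
    float,           cutlass::layout::ColumnMajor,  // element and layout of C
    float,                                          // accumulator element
    cutlass::arch::OpClassTensorOp,                 // use Tensor Cores
    cutlass::arch::Sm80,                            // target architecture
    ThreadblockShape, WarpShape, InstructionShape>;
```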
### Warp-level GEMM
The warp-level GEMM maps to the warp-level parallelism within the CUDA execution model. Multiple
warps within a threadblock fetch data from shared memory into registers and perform computations.
Warp-level GEMMs may be implemented either by Tensor Cores issuing warp-wide `mma` instructions or
by thread-level matrix computations issued to CUDA cores.
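CUDA's `wmma` API (from `<mma.h>`) shows the essential shape of a warp-level Tensor Core computation: all 32 threads of a warp cooperatively hold operand fragments in their registers and issue the matrix instruction together. The kernel below is a minimal sketch computing a single 16x16x16 product; CUTLASS's warp-level implementations are considerably more elaborate.
```c++
#include <mma.h>
using namespace nvcuda;

// One warp computes a single 16x16x16 product D = A * B, with A and B in
// half precision and accumulation in float. Every thread in the warp
// participates in each wmma operation; the fragments live in registers.
__global__ void warp_gemm_16x16x16(half const *A, half const *B, float *D) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag;

  wmma::fill_fragment(acc_frag, 0.0f);       // zero the accumulators
  wmma::load_matrix_sync(a_frag, A, 16);     // operands -> registers
  wmma::load_matrix_sync(b_frag, B, 16);
  wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);  // Tensor Core math
  wmma::store_matrix_sync(D, acc_frag, 16, wmma::mem_row_major);
}
```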
## References
- [CUTLASS: Fast Linear Algebra in CUDA C++](https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/)
- [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://on-demand-gtc.gputechconf.com/gtcnew/sessionview.php?sessionName=s8854-cutlass%3a+software+primitives+for+dense+linear+algebra+at+all+levels+and+scales+within+cuda)
- [Programming Tensor Cores: Native Volta Tensor Cores with CUTLASS](https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9593-cutensor-high-performance-tensor-operations-in-cuda-v2.pdf)