This enables 128-bit vector memory accesses, which lead to efficient CUDA kernels. Smaller alignments are supported even on Tensor Cores by setting `AlignmentA` and `AlignmentB` in `conv::kernel::DefaultConv2dFprop`, but performance is lower than with 128-bit aligned tensors.
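For reference, the 128-bit requirement translates into an element-count constraint on the innermost (channel) dimension of the tensors. A minimal illustrative helper (not part of CUTLASS) showing the arithmetic:

```c++
// Number of elements covered by one 128-bit vector access for a given element width.
// For 4-bit integer data (cutlass::int4b_t), 128 / 4 = 32 elements, so the channel
// extent C of an NHWC tensor should be a multiple of 32 to keep accesses 128-bit aligned.
constexpr int alignment_in_elements(int bits_per_element) {
  return 128 / bits_per_element;
}

static_assert(alignment_in_elements(4)  == 32, "int4b_t: 32 elements per access");
static_assert(alignment_in_elements(16) == 8,  "half_t: 8 elements per access");
```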
CUTLASS defines the following CUDA C++ templates to implement Implicit GEMM Convolution, which are described in greater detail in subsequent sections.
**Activations tile iterators** load the activations tile into registers. Two implementations are provided:
- [conv2d_fprop_activation_tile_access_iterator_analytic.h](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h) computes pointer deltas and masks analytically
- [conv2d_fprop_activation_tile_access_iterator_optimized.h](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h) optimizes iterating over global memory and creating the GEMM-A tile in shared memory.
**Filter tile iterators** load filters into registers. Similarly, two implementations are provided:
- [conv2d_fprop_filter_tile_access_iterator_analytic.h](/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h) computes pointer deltas and masks analytically
- [conv2d_fprop_filter_tile_access_iterator_optimized.h](/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h) optimizes iterating over global memory and creating the GEMM-B tile in shared memory.
The improvements provided by the optimized iterators are:
- (a) Precomputing kernel-invariant pointer deltas on the host
- (b) Computing CTA-invariant mask predicates in the device-side iterator constructors
**Pipelined mainloop** loads the threadblock-scoped tiles of activations and filters and stores register-backed fragments to Shared Memory in permuted layouts, as described in a later section.
**Warp-level GEMM**, defined in [cutlass::gemm::warp::MmaTensorOp](/include/cutlass/gemm/warp/mma_tensor_op.h),
provides tile iterators to load from Shared Memory and issues math instructions to Turing Tensor Cores.
Further details are [described here](/media/docs/gemm_api.md#warp-level-matrix-multiply-api).
**Epilogue** reorders accumulator elements among threads within a threadblock to efficiently update
the output tensor. It is implemented in [epilogue::threadblock::Epilogue](/include/cutlass/epilogue/threadblock/epilogue.h).
### Loading Activations and Filters
The Implicit GEMM Convolution algorithm partitions the GEMM _K_ dimension (of extent _CRS_) into
threadblock tiles, assigning each threadblock tile to one filter position and an interval
of channels. After iterating over all filter positions, the convolution algorithm advances to the
next interval of channels and resumes at filter position `r=0, s=0`.
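This traversal order can be sketched as the loop nest below; a simplified illustration only, with hypothetical names (`C_tile` for the width of a channel interval), not the actual CUTLASS mainloop:

```c++
// Simplified sketch of the implicit GEMM K-dimension traversal (K = C * R * S).
// Each mainloop iteration consumes one threadblock tile: a single filter
// position (r, s) and an interval of C_tile channels.
void traverse_gemm_k(int C, int R, int S, int C_tile) {
  for (int c = 0; c < C; c += C_tile) {   // next interval of channels
    for (int r = 0; r < R; ++r) {         // all filter rows ...
      for (int s = 0; s < S; ++s) {       // ... and filter columns
        // Load the activation and filter threadblock tiles for channels
        // [c, c + C_tile) at filter position (r, s), then compute one
        // threadblock-scoped matrix product.
      }
    }
  }
}
```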
The matrix product of one threadblock tile is computed per iteration of
the mainloop, as described in the [CUTLASS GEMM implementation](/media/docs/efficient_gemm.md). To
summarize, the threadblock tiles of activations and filters are loaded from tensors in global memory
and stored to shared memory. Each thread within the threadblock loads one or more vectors, and
together the threads span the entire tile.
The following figure illustrates one particular iteration of the Implicit GEMM mainloop. Each
thread within the threadblock is mapped to several vectors of elements in the Activations and
Filters tensors. Each index in the GEMM _M_ dimension corresponds to a unique _(N,P,Q)_
index of the output tensor, and pointers may be computed based on this as well as
filter position _(r,s)_.

The CUTLASS component that embodies this functionality is [Conv2dFpropActivationTileAccessIteratorAnalytic](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h).
Its constructor computes the mapping of GEMM _M_ to _(N, P, Q)_, and the `at()` method computes the linear offset into the Activations
tensor for each memory access the thread is to perform. Additionally, the method `valid()` computes the validity of the access
for each filter position and for each memory access, indicating whether the memory access will be within the bounds of the
tensor or out of bounds.
`operator++()` iterates over the memory accesses performed by a thread in both the contiguous and strided dimensions.
Similar logic holds for [Conv2dFpropFilterTileAccessIteratorAnalytic](/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h).
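The address arithmetic performed by these analytic iterators can be sketched for NHWC activations as follows; a simplified, illustrative function only (parameter names are hypothetical), not the actual `at()` / `valid()` implementation:

```c++
#include <cstdint>

// Illustrative sketch of the analytic mapping used for Conv2d Fprop with NHWC
// activations: GEMM M -> (n, p, q), then (p, q, r, s) -> input coordinate (h, w).
struct FpropActivationAccess {
  int64_t offset;  // linear element offset into the NHWC Activations tensor
  bool valid;      // true if the access falls within the tensor bounds
};

FpropActivationAccess map_gemm_m_to_activation(
    int gemm_m, int c,                  // GEMM M index and channel index
    int P, int Q, int H, int W, int C,  // output and input extents
    int r, int s,                       // filter position
    int stride_h, int stride_w, int pad_h, int pad_w,
    int dilation_h, int dilation_w) {

  int n = gemm_m / (P * Q);             // GEMM M -> unique (n, p, q)
  int residual = gemm_m % (P * Q);
  int p = residual / Q;
  int q = residual % Q;

  int h = p * stride_h - pad_h + r * dilation_h;   // input row
  int w = q * stride_w - pad_w + s * dilation_w;   // input column

  FpropActivationAccess access;
  access.valid = (h >= 0 && h < H && w >= 0 && w < W);
  access.offset = ((int64_t(n) * H + h) * W + w) * C + c;  // NHWC linearization
  return access;
}
```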
To reduce computational overhead in the mainloop body, the pointer offsets may be precomputed
in host code and provided to the CUDA kernel as a lookup table in its `Params` structure.
As shown in [Conv2dFpropActivationTileAccessIteratorOptimized](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h),
the logic to compute offsets from the filter position has been extracted to the `Params` constructor.
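For NHWC activations, for example, the pointer increment between consecutive filter positions depends only on the layer parameters, so it can be tabulated once on the host. A hedged sketch of such a table (the function and names below are illustrative, not the actual CUTLASS `Params`):

```c++
#include <cstdint>
#include <vector>

// Illustrative host-side precomputation of pointer deltas (in elements) between
// consecutive filter positions for an NHWC activations tensor, traversed in
// (r, s) order. Advancing s -> s+1 moves the input column by dilation_w, i.e.
// dilation_w * C elements; wrapping s back to 0 while advancing r moves down one
// (dilated) row and rewinds the columns already traversed.
std::vector<int64_t> make_filter_position_deltas(
    int R, int S, int C, int W, int dilation_h, int dilation_w) {

  std::vector<int64_t> deltas;
  for (int r = 0; r < R; ++r) {
    for (int s = 0; s < S; ++s) {
      int64_t delta;
      if (s + 1 < S) {
        delta = int64_t(dilation_w) * C;            // s -> s + 1
      } else {
        delta = int64_t(dilation_h) * W * C         // r -> r + 1
              - int64_t(S - 1) * dilation_w * C;    // rewind s to 0
      }
      deltas.push_back(delta);  // looked up in the mainloop instead of recomputed
    }
  }
  return deltas;
}
```

Following the traversal order described earlier, the final entry would additionally rewind _r_ and advance to the next interval of channels.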
In the previous two sections, we have described how data may be loaded from activations and filters tensors
in global memory to compute convolution, and we have described a composition of `ldmatrix` and `mma.sync`
to fetch data from Shared Memory and issue Tensor Core operations.
To ensure this data movement is efficient, care must be taken to ensure bank conflicts are avoided. CUTLASS
uses a permuted Shared Memory layout to avoid bank conflicts when storing to Shared Memory and to efficiently
load from Shared Memory using `ldmatrix`. The following figure illustrates the thread mapping used for
loading the activations and filters threadblock tiles from global memory and the permuted layout in
Shared Memory.

In the illustration, one warp-wide memory access is highlighted in blue, with individual threads
loading one 128-bit vector. The tile in global memory could correspond either to the activations
or filters and is assumed to be 'strip-mined' with four threads loading consecutive channels.
Shared Memory is visualized as a 'row-major' matrix with eight columns representing
the eight 128-bit banks.
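Under this view, the 128-bit bank touched by a 16-byte-aligned access can be derived directly from its byte offset within Shared Memory; a small illustrative helper (not part of CUTLASS):

```c++
// Shared Memory consists of 32 four-byte banks; grouping them as eight 128-bit
// banks, a 16-byte-aligned access at byte_offset falls into the following bank.
int bank_128b(int byte_offset) {
  return (byte_offset / 16) % 8;
}
```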
As described in the CUTLASS GTC 2019 presentation [slides](https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9593-cutensor-high-performance-tensor-operations-in-cuda-v2.pdf),
[recording](https://developer.nvidia.com/gtc/2019/video/S9593), an access to Shared Memory will be conflict-free if
the following conditions are satisfied across each warp:
- {T0, T1, .., T7} do not access the same 128-bit bank
- {T8, T9, .., T15} do not access the same 128-bit bank
- {T16, T17, .., T23} do not access the same 128-bit bank
- {T24, T25, .., T31} do not access the same 128-bit bank
To achieve conflict-free stores, the Shared Memory layout remaps the strip-mined arrangement to transpose
the vectors and applies an XOR operation on the column index of each thread's pointer. Specifically,
```c++
int store_column = (lane_id % 8) ^ (lane_id / 8);
```
This transformation on the layout will be instrumental in reading slices of data from Shared Memory
to compute the warp-level matrix multiply using Tensor Cores.
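As a quick sanity check of the swizzle, the standalone snippet below (illustrative only, independent of CUTLASS) enumerates the store column of each lane and verifies the four conflict-free conditions listed above:

```c++
#include <cassert>
#include <cstdint>

int main() {
  // For each group of eight consecutive lanes, verify that the XOR-permuted
  // store columns are all distinct, i.e. each of the eight 128-bit banks is
  // touched exactly once and the store is conflict-free.
  for (int group = 0; group < 4; ++group) {
    uint32_t banks_touched = 0;
    for (int lane_in_group = 0; lane_in_group < 8; ++lane_in_group) {
      int lane_id = group * 8 + lane_in_group;
      int store_column = (lane_id % 8) ^ (lane_id / 8);  // permuted column
      banks_touched |= (1u << store_column);
    }
    assert(banks_touched == 0xFF);  // all eight banks hit, no repeats
  }
  return 0;
}
```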
The following figure shows how the first sixteen threads participating in an `ldmatrix` instruction
logically map to the c=0..31 slice of a matrix in Shared Memory. This slice is known as a "k-group"
within the code because it corresponds to the same K-index of a warp-level matrix multiply.

The lower half of the figure shows the physical arrangement in Shared Memory, with threads offset by row and column
according to the XOR function. By inspection, we can observe there are no bank conflicts, as _T0 ... T7_ each access unique
banks, as do _T8 ... T15_ and beyond.
To advance to the next "k-group" within Shared Memory, pointers are updated using an XOR operation according to
the following sequence:
- **^1** advances from _k=0_ to _k=1_
- **^3** advances from _k=1_ to _k=2_
- **^1** advances from _k=2_ to _k=3_
- **^3** advances from _k=3_ to _k=0_
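As a small standalone illustration (not the CUTLASS iterator itself), applying these four masks in sequence cycles a pointer's column through all four k-groups and back to its starting value:

```c++
#include <cassert>

int main() {
  // Alternating XOR masks advance a lane's Shared Memory column from one
  // k-group to the next: k = 0 -> 1 -> 2 -> 3 -> 0.
  int const masks[4] = {1, 3, 1, 3};

  int column = 5;           // example starting column for some lane
  int start = column;
  for (int k = 0; k < 4; ++k) {
    column ^= masks[k];     // advance to the next k-group
  }
  assert(column == start);  // after four steps the pointer has wrapped around
  return 0;
}
```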
The first of these transitions is shown below.

The [CUTLASS warp-level GEMM API](/media/docs/gemm_api.md#warp-level-matrix-multiply-api) defines templates for
loading slices of data from permuted Shared Memory and issuing operations to Tensor Cores.
### Updating the Output Tensor
After the mainloop terminates, the accumulator tile of the warp-level GEMM stores a warp's contribution to the output
tensor. However, the distribution of data among threads within the threadblock is specialized for efficient matrix multiply-accumulate
operations using Tensor Cores and is not conducive to efficient, coalesced operations to Global Memory. A data rearrangement is
needed.
The **Epilogue** is the component for exchanging accumulator elements through Shared Memory, loading slices of the output
matrix or tensor, applying an elementwise operation such as linear scaling or bias, and storing the result to the output tensor.
CUTLASS structures this as several components:
- [cutlass::epilogue::threadblock::Epilogue](/include/cutlass/epilogue/threadblock/epilogue.h) - the top-level component for looping over the entire threadblock tile
- [cutlass::epilogue::warp::TileIteratorTensorOp](/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h) - a specialized component for storing Tensor Core accumulators to Shared Memory
- [cutlass::epilogue::threadblock::SharedLoadIterator](/include/cutlass/epilogue/threadblock/shared_load_iterator.h) - a component for loading elements from a row-major arrangement in Shared Memory
- [cutlass::epilogue::threadblock::PredicatedTileIterator](/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h) - a component for loading or storing matrix fragments to Global Memory (with bounds checks)
- [cutlass::epilogue::thread::LinearCombination](/include/cutlass/epilogue/thread/linear_combination.h) - an element-wise function computing `alpha * AB + beta * C` to produce the final output (sketched below)
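The elementwise stage is conceptually a fused scale-and-add applied to each output element; a minimal sketch of the operation (plain C++ for illustration, not the CUTLASS functor itself):

```c++
// Per-element linear combination applied by the epilogue: D = alpha * AB + beta * C,
// where AB is an accumulator value and C is the corresponding source element.
// Plain float is shown for clarity; the CUTLASS functor is templated on element types.
float linear_combination(float alpha, float accumulator, float beta, float source) {
  return alpha * accumulator + beta * source;
}
```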
## Unit Tests
Unit tests verify the functional behavior of each of the above components in a standalone CUDA kernel. This provides a
convenient environment to (a.) inspect the template definition, (b.) showcase instantiation and use of these templates
in device code, and (c.) assert functional correctness.
Before building the example, first perform the prerequisite steps for building any CUTLASS component [described here](/media/docs/quickstart.md).
Compute capability 7.5 refers to the Turing architecture, and this work requires CUDA 10.2 Toolkit or later to target
Turing Tensor Cores using the native `mma` [PTX instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-8832).
```bash
$ mkdir build && cd build
$ cmake .. -DCUTLASS_NVCC_ARCHS=75
```
To build the example, execute `make 09_turing_tensorop_conv2dfprop` from the build directory.
```bash
$ make 09_turing_tensorop_conv2dfprop
$ ls examples/09_turing_tensorop_conv2dfprop
examples/09_turing_tensorop_conv2dfprop
```
This example provides a simple command line interface to specify the extents of 4D tensors of 4-bit integer elements (`cutlass::int4b_t`),
initialize them to random values, and compute the result of a convolutional layer. Optionally, the input and output
tensors may be saved to .csv files, and the CUTLASS host-side reference check may be executed to verify correctness.
The complete usage statement is visible by running with `--help`: