The simplest of the tutorial examples covers the basics of partitioning the global memory into tiles across the CTAs (also called threadblocks in CUDA), partitioning the data tiles across the threads of each CTA, and writing a mainloop using `cute::copy` and `cute::gemm`.
```cpp
// Represent the full tensors
Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
```
The appropriate modes of the `Shape` are selected to construct each of the tensors. The preconditions make sure that for every integer in the `Shape` there is a corresponding integer in the associated `Stride`.
Note that the comment after B says `(N,K)` rather than `(K,N)`.
This means that B is treated as an NxK matrix instead of a KxN matrix as is typical within BLAS and most other matrix-matrix multiplications.
CuTe follows the convention that the matrix modes are
`(M,K)` for `A`, `(N,K)` for `B`, and `(M,N)` for `C`, which we try to record in comments everywhere.
For each of the `(M,K)`, `(N,K)`, and `(M,N)` tensors, the `gemm_nt` and `gemm_tn` construct the strides those tensors will use. In `gemm_nt` the strides are defined as
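A sketch of those NT stride definitions (assuming BLAS-style leading-dimension parameters `ldA`, `ldB`, and `ldC`, which are not shown in this excerpt):

```cpp
// Define NT strides (mixed): A is M-major, B is N-major, C is M-major
auto dA = make_stride(Int<1>{}, ldA);   // (dM, dK)
auto dB = make_stride(Int<1>{}, ldB);   // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC);   // (dM, dN)
```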
We've found that the BLAS convention of using "non-transposed" (N) and "transposed" (T) flags in conjunction with the mode convention of `MxK * KxN` confuses the core issues of "what layout does this matrix use?" and "in which mode does my matrix have stride-1?". Indeed, the answer to those questions can always be found by inspecting the CuTe `Layout`.
Instead of row-major or column-major (or Transposed
and Not-Transposed), we have found it much more convenient to say that a matrix is "M-major" if it is stride-1 in the M-mode, "N-major" if it is stride-1 in the N-mode, or "K-major" if it is stride-1 in the K-mode. Furthermore, knowing that matrix multiply always performs a reduction in the K-mode, it is very convenient from a software perspective to always have the K-mode in the same place and adopt the mode convention `MxK * NxK`. Implementations always reduce over the second mode (the K-mode) of both input matrices, which leads to cases where implementations can treat both input matrices the same way.
At the highest level, the work is distributed across CTAs. In principle, each CTA's tile could come from the input tensors in many different ways. Many [CuTe `Tiler`s](./02_layout_algebra.md#composition-tilers) could be used to tile the data, but for these cases it is sufficient to simply use the shape of the desired CTA tile.
```cpp
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int<8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
```
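Each CTA then picks out its tile of work with a CTA coordinate; a sketch following the tutorial's conventions:

```cpp
// This CTA's coordinate into the grid of (BLK_M,BLK_N,BLK_K) tiles
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);    // (m, n, k)
```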
* The `m`-coordinate of this tile is given by `blockIdx.x`.
* The `n`-coordinate of this tile is given by `blockIdx.y`.
* The `k`-coordinate of this tile is unspecified -- we want all of the tiles in `K` so the coordinate is `_`, the `Underscore` value, to keep that mode.
This partitioning can be done in two steps:
1. apply the tiler to the full tensor, producing a tensor whose first mode is the tile itself and whose second mode, the "Rest" mode, indexes over the tiles, then
2. apply the coord to the second mode, the "Rest" mode, to extract out the correct tiles for this CTA.
```cpp
// (BLK_M,BLK_K,k)
Tensor gA = gA_mk(make_coord(_,_), select<0,2>(cta_coord));
```
Because the projections of the tiler and coord are symmetric and the two steps (apply a tiler and then slice into the rest-mode to produce a partition) are so common, they are wrapped together into the projective `local_tile` interface.
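A sketch of those `local_tile` calls (identifiers follow the tutorial's conventions; `X` marks a mode that the `Step` projects out):

```cpp
// Tile the full tensors for this CTA, projecting out the mode each tensor lacks
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)
```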
For tensor `A`, we are left with a rank-3 tensor of shape `(BLK_M,BLK_K,k)`. The first two modes are precisely the modes of the CTA tile and the last mode indexes over all of the tiles that will be reduced by this CTA. In the mainloop section below, this mode is iterated over via the `k_tile` loop.
The shared memory layouts that are used to hold the tiles of data for A and B are also passed in as the parameters `ASmemLayout sA_layout` and `BSmemLayout sB_layout`.
In `gemm_nt` these are defined as
```cpp
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major
auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major
```
which produces simple M-major and N-major layouts. In `gemm_tn` these are
```cpp
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM,bK), LayoutRight{}); // (m,k) -> smem_idx; k-major
auto sB = make_layout(make_shape(bN,bK), LayoutRight{}); // (n,k) -> smem_idx; k-major
```
which produces simple K-major layouts.
As is evident, these smem layouts can be almost anything. Inside the kernel, they are checked for only two properties: the shared memory layouts are static and they are the same top-level shape as the `CtaTiler`.
* Static layouts let us statically allocate shared memory as shown below.
* Static layouts are often more efficient and allow CuTe to dispatch to optimized implementations.
* Static layouts make it easier to prove correctness of the algorithm and to provide checks like those just described -- that the smem layout sizes match the CTA tile sizes (see the sketch below).
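A sketch of such checks inside the kernel (assuming the layout and tiler parameter names used in this tutorial):

```cpp
// The smem layouts must be static ...
static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
// ... and their sizes must match the CTA tile sizes
CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_K
CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_K
```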
As stated, the shared memory layouts can be anything that satisfies those conditions. Optimizing kernels like these often comes down to finding a shared memory layout that provides good access patterns for both the writes to and the reads from shared memory. This includes the ability to vectorize reads and writes as well as to avoid shared memory bank conflicts.
Note how the shared memory allocation depends only on the data type and the layout. What's a `cosize`? Because a `Layout` is a function, we can speak of its domain and codomain. The `size` of a layout is the size of its domain and the `cosize` of a layout is the size of its codomain. If we want to allocate an array for which all the offsets produced by a layout are valid, then we can use the `cosize` of the layout as the length of the array (in units of elements).
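In the kernel, that allocation looks like the following sketch (the element types `TA`/`TB` and the layout parameter names follow the tutorial's conventions):

```cpp
// Statically allocate smem sized by the cosize of each (static) layout
__shared__ TA smemA[cosize_v<ASmemLayout>];
__shared__ TB smemB[cosize_v<BSmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);  // (BLK_M,BLK_K)
Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);  // (BLK_N,BLK_K)
```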
The kernel now has tiles of global memory by applying the `CtaTiler` to the full tensors and it also has tiles of shared memory by allocating appropriately. We now want to create an efficient way to copy one tile of global memory to our tile of shared memory. A trivial way to do this would be to use a single thread and copy each element.
```cpp
if (thread0()) {
  Tensor gA0 = gA(_,_,0); // (BLK_M,BLK_K), the 0th tile
  // Copy every element of the tile with this single thread
  for (int i = 0; i < size(sA); ++i) {
    sA(i) = gA0(i);
  }
}
```
If we partition the two tiles of data across the threads in the CTA, then each thread can copy its own subtensor of data. There are lots of ways this partitioning could occur, however.
Both cases happen to use 32x8 threads, which will be used to partition a 128x8 tile of gmem and smem data into a 4x1 subtensor for each thread. The only difference here is that `gemm_nt` uses M-major and N-major threads to match the order of data in global memory and `gemm_tn` uses K-major threads to match the order of data in global memory.
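In `sgemm_1.cu`, this partitioning is written with `local_partition`; a sketch (shapes shown for the 32x8 thread layouts described above):

```cpp
// Partition the gmem and smem tiles of A across the threads with the tA layout
Tensor tAgA = local_partition(gA, tA, threadIdx.x);    // (THR_M,THR_K,k)
Tensor tAsA = local_partition(sA, tA, threadIdx.x);    // (THR_M,THR_K)

// Partition the gmem and smem tiles of B across the threads with the tB layout
Tensor tBgB = local_partition(gB, tB, threadIdx.x);    // (THR_N,THR_K,k)
Tensor tBsB = local_partition(sB, tB, threadIdx.x);    // (THR_N,THR_K)
```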
where `local_partition` is a lot like `local_tile`, except the coordinate slices into the tile-mode (the first mode) of the `zipped_divide` rather than the rest-mode (the second mode). That is, each thread gets one element of data assigned to it per thread tile and that thread tile is repeated to cover the entire data tile.
The naming convention `tAsA` is pretty typical across CuTe and CUTLASS. This is read as "Partitioning pattern `tA` applied to tensor `sA`". In the next section, we'll see a different partitioner applied to `sA` to produce `tCsA`. By applying the same partitioning pattern, `tA`, to tensors `sA` and `gA`, we preserve the *logical consistency* of those tensors (checked by the assertions above), where logical elements between the two tensors correspond despite any differences in their data layouts. When used in `cute::copy`, for example, this naming convention lets us lexically verify that the two tensors are using the same partitioning pattern.
This copy can be executed by all of the threads in the CTA at once, without conflicts, because every thread owns a different subtensor of the tile that will be copied.
### Math partitioning
The kernel now has tiles of shared memory copied in from global memory. We now want to create an efficient way to compute and accumulate the matrix product on that tile of shared memory. A trivial way to do this would be to use a single thread and compute directly.
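For example, a single thread could run the whole triple loop over the current k-tile held in shared memory (a sketch that omits the alpha/beta epilogue):

```cpp
if (thread0()) {
  // Accumulate this k-tile's contribution with one thread
  for (int m = 0; m < size<0>(gC); ++m) {
    for (int n = 0; n < size<1>(gC); ++n) {
      for (int k = 0; k < size<1>(sA); ++k) {
        gC(m,n) += sA(m,k) * sB(n,k);
      }
    }
  }
}
```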
If we partition the output tile `gC` across the threads in the CTA, then each thread can compute its own subtensor. There are lots of ways this partitioning could occur, however.
This is an M-major 16x16 layout of threads that will be used to partition a 128x128 tile of `C`-data, resulting in each thread computing its own 8x8 subtensor of `gC`.
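A sketch of that partitioning with `local_partition` and `Step` projections (identifiers follow the tutorial's conventions; `X` marks a projected-out mode):

```cpp
// Partition sA, sB, and gC across the threads with the tC layout
Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});   // (THR_M,BLK_K)
Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{});   // (THR_N,BLK_K)
Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});   // (THR_M,THR_N)

// An accumulator fragment shaped like this thread's tile of gC
Tensor tCrC = make_tensor_like(tCgC);                                // (THR_M,THR_N)
clear(tCrC);
```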
where we've used the same projection-style interface to avoid applying the `N`-mode of `tC` to the `(BLK_M,BLK_K)` shape of `sA` and avoid applying the `M`-mode of `tC` to the `(BLK_N,BLK_K)` shape of `sB`.
This diagram shows a `tC` layout, highlights two threads in green and blue, shows the projections of the `tC` layout, and finally highlights the subtensors within `sA`, `sB`, and `gC` that `tCsA`, `tCsB`, and `tCgC` represent.
The mainloop iterates over tiles of global memory, reads those tiles into shared memory, and then performs the matrix-multiply and accumulates into the accumulators.
We can see that `k_tile` iterates over each tile of data, the `cute::copy` is performed for the current `k_tile` using the `tA` and `tB` thread-partitioned tensors, and the `cute::gemm` is computed for that current `k_tile` using the `tC` thread-partitioned tensors. Synchronization is provided so that this kernel works on any architecture.
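A sketch of that mainloop (simplified from `sgemm_1.cu`, omitting the alpha/beta epilogue):

```cpp
auto K_TILE_MAX = size<2>(tAgA);
for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
{
  // Copy this k-tile from gmem to smem with the tA|tB partitionings
  copy(tAgA(_,_,k_tile), tAsA);    // (THR_M,THR_K) -> (THR_M,THR_K)
  copy(tBgB(_,_,k_tile), tBsB);    // (THR_N,THR_K) -> (THR_N,THR_K)
  __syncthreads();                 // Wait for all writes to smem

  // Accumulate the matrix product for this k-tile with the tC partitioning
  gemm(tCsA, tCsB, tCrC);          // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K)
  __syncthreads();                 // Wait for all reads from smem
}
```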
An example that uses more complex `TiledMMA` and `TiledCopy` to perform partitioning in place of the `tA`, `tB`, and `tC` thread layouts. With this example, we try to emphasize that the shared memory layouts, the partitioning patterns, and the PTX instruction to use in each stage can be specified independently.
First, we can replace the `tA` partitioning and `tB` partitioning with `TiledCopy` partitioning, which provides for more complex partitioning patterns and checked dispatch to specific copy instructions.
On the left is the source-tensor partitioning and on the right is the destination-tensor partitioning. The partition patterns are the same for this case, but there exist PTX instructions which require different patterns in the source and destination. The diagram shows that each thread reads 4x1 `TA` elements and there are 32x8 threads. The `UniversalCopy<uint128_t>` forces the instruction to use a 128-bit copy instruction. If the partition (of `sA` or `gA` in this case) does not result in 4 `TA` elements that can be vectorized into a single 128-bit load/store, then CuTe will fail statically with an error message to that effect.
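A sketch of how such a `TiledCopy` can be built and applied (identifiers follow the tutorial's conventions; the exact atom and layouts may differ):

```cpp
// A TiledCopy from a 128-bit copy atom, a 32x8 thread layout,
// and a 4x1 value layout per thread
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TA>{},
                                  Layout<Shape<_32,_8>>{},   // Thread layout: 32x8
                                  Layout<Shape< _4,_1>>{});  // Value  layout: 4x1 per thread

// Slice out this thread's work and partition the source and destination tensors
ThrCopy thr_copy_a = copyA.get_slice(threadIdx.x);
Tensor tAgA = thr_copy_a.partition_S(gA);                    // (CPY,CPY_M,CPY_K,k)
Tensor tAsA = thr_copy_a.partition_D(sA);                    // (CPY,CPY_M,CPY_K)
```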
which applies the source-tensor partitioning to `gA` via `partition_S` and applies the destination-tensor partitioning to `sA` via `partition_D`. The first mode, `CPY`, of the result tensors holds all of the elements that a single instruction will consume. In this case, that mode should have size-4 since there are four `TA = float` elements in a single 128-bit `uint128_t`.
Next, we can replace the `tC` partitioning with `TiledMMA` partitioning, which provides for more complex partitioning patterns and checked dispatch to specific MMA instructions.
On the left is the A-tensor partitioning, on the top is the B-tensor partitioning, and in the middle is the C-tensor partitioning. Because the `UniversalFMA` is a 1x1x1 MMA instruction, a 16x16x1 tiling of them results in a 16x16x1 `TiledMMA`. Other MMA instructions involve different numbers of threads and have different instruction sizes. In this case, each thread reads a single element of `A`, `B`, and `C` per instruction.
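A sketch of how such a `TiledMMA` can be built and applied (identifiers follow the tutorial's conventions):

```cpp
// A TiledMMA from a 1x1x1 UniversalFMA atom tiled across 16x16 threads
TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
                               Layout<Shape<_16,_16,_1>>{});  // 16x16x1 tiling of atoms

// Slice out this thread's work and partition A, B, and C
ThrMMA thr_mma = mmaC.get_slice(threadIdx.x);
Tensor tCsA = thr_mma.partition_A(sA);    // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(sB);    // (MMA,MMA_N,MMA_K)
Tensor tCgC = thr_mma.partition_C(gC);    // (MMA,MMA_M,MMA_N)
```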
which applies the A-tensor partitioning to `sA` via `partition_A`, applies the B-tensor partitioning to `sB` via `partition_B`, and applies the C-tensor partitioning to `gC` via `partition_C`. The first mode, `MMA`, of the result tensors holds all of the elements that a single instruction will consume. In this case, that mode should have size-1 since `UniversalFMA` is a 1x1x1 MMA, but in general the size of the first mode can vary and need not even be the same across `tCsA`, `tCsB`, and `tCgC`, depending on the MMA.
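This example can also change the NT shared memory layouts. One way to pad them is to give the k-mode a stride one larger than the tile extent (illustrative strides; the example's exact layouts may differ):

```cpp
// Define padded NT smem layouts (static)
auto sA = make_layout(make_shape(bM, bK), make_stride(Int<1>{}, bM + Int<1>{})); // (m,k) -> smem_idx; padded m-major
auto sB = make_layout(make_shape(bN, bK), make_stride(Int<1>{}, bN + Int<1>{})); // (n,k) -> smem_idx; padded n-major
```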
which produces M-major and N-major layouts, but they are padded to avoid shared memory bank conflicts. This simply improves the access pattern to and from shared memory and no other changes in the kernel are required.
An example that uses an optimized mainloop for Ampere SM80 architectures that explicitly pipelines shared memory using asynchronous reads from global memory.
A GETT (a GEMM-like tensor-tensor contraction) can be computed with our existing GEMM implementation. Included below is a launcher like `gemm_nt` that uses the same device kernel contained in `sgemm_1.cu` to compute a GETT with two m-modes.
```cpp
// Setup params for a GETT with two m-modes.
// The A and C tensors are assumed to be m0-major.
// Calls sgemm_1.cu's gemm_device<<<>>> without modification.
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gett(int m0, int m1, int n, int k,
Alpha alpha,
TA const* A, int ldAm1, int ldAk, // m0-major
TB const* B, int ldBk,
Beta beta,
TC * C, int ldCm1, int ldCn, // m0-major
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = make_shape(m0, m1); // (m0,m1)-multimode M
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define NT strides (mixed)
auto dA = make_stride(make_stride(Int<1>{}, ldAm1), ldAk); // (dM, dK)
auto dB = make_stride(Int<1>{}, ldBk); // (dN, dK)
auto dC = make_stride(make_stride(Int<1>{}, ldCm1), ldCn); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Shape<_64,_2>{}; // Take _64 elements from m0 and _2 elements from m1
auto bN = Int<128>{};
auto bK = Int<8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major
auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// Define the thread layouts (static)
auto tA = make_layout(make_shape(Int<32>{}, Int<8>{})); // (m,k) -> thr_idx
auto tB = make_layout(make_shape(Int<32>{}, Int<8>{})); // (n,k) -> thr_idx
auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); // (m,n) -> thr_idx
dim3 dimBlock(size(tC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid,dimBlock,0,stream>>>
(prob_shape, cta_tiler,
A, dA, sA, tA,
B, dB, sB, tB,
C, dC, sC, tC,
alpha, beta);
}
```
Note that the only changes are the definition of shape `M`, the definition of strides `dA` and `dC`, and the definition of the CTA Tiler `bM`. The above uses a multimodal problem shape `M = (m0,m1)` and a multimodal CTA Tiler `bM = Shape<_64,_2>{}` to change which portion of the global memory tensors `A` and `C` each CTA will be responsible for computing.
Similar examples can be found for CUTLASS 3.x kernels that are based on CuTe, such as [this Hopper GETT example](../../../examples/51_hopper_gett).