on either 2.x API implementing kernel layer GEMMs, or 3.x API
implementing kernel layer GEMMs (as detected by `gemm::detail::IsCutlass3GemmKernel`
discussed below). As a result, `GemmUniversalAdapter`'s behavior
might differ between the two specializations.
### Device API design differences
In CUTLASS 2.x, the Device API was more closely tied
to the Kernel API. In CUTLASS 3.0, the Device API
accepts any kernel type that meets the Kernel API
interface requirements. CUTLASS 3.0's Device API code is
parameterized by the kernel type, but this code
is *generic*; the same code works for any kernel type.
The device layer compatibility interface, `device::GemmUniversalAdapter`,
also provides reflective mappings from 3.0-specific types
back to the closest possible 2.x equivalent types. This is [discussed further in the section below](#conversions-between-2x-tags-and-30-types).
CUTLASS 3.0's `device::GemmUniversalAdapter` also exposes some new APIs that the 2.x `device::GemmUniversalAdapter` implementation does not. Most notably, this includes the ability to bypass the `GemmKernel::Arguments` to `GemmKernel::Params` lowering.
```c++
// Primary run() entry point API that is static allowing users to create and manage their own params.
static Status
run(Params& params, cudaStream_t stream = nullptr);
```
In CUTLASS 2.x, the Device API also specified the grid shape
used to launch the Kernel API.
In CUTLASS 3.0, the Kernel API controls its own grid shape,
while the device adapter simply queries the kernel for the grid shape with which it needs to be launched.
This change is required to support kernel schedules
that may need their own schedule-specific grid planning logic.
For example, persistent kernel schedules generally launch only
as many threadblocks as there are multiprocessors on the GPU.
All CUTLASS 3 `kernel::GemmUniversal` specializations expose the following (static) API:
```c++
// Returns true if the kernel can execute the provided GEMM arguments.
static bool
can_implement(Arguments const& args);
// Returns a dim3 representing the threadblock shape.
static constexpr dim3
get_block_shape();
// Returns a dim3 representing the grid shape in terms of threadblocks.
static constexpr dim3
get_grid_shape(Params const& params);
```
The device adapter calls these three functions to validate the arguments and obtain the launch configuration before launching the kernel on the device.
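The division of labor between kernel and adapter can be modeled on the host without CUDA. The sketch below is an illustration under stated assumptions, not CUTLASS code: `Dim3`, `Status`, `LaunchConfig`, `ToyDataParallelKernel`, `ToyPersistentKernel`, `ToyGemmAdapter`, and the toy lowering function `to_underlying_arguments` are all hypothetical stand-ins. Each toy kernel exposes the three static queries above and plans its own grid; the generic adapter only asks. Its static `run(Params&)` takes caller-owned params, so repeated launches skip the `Arguments` to `Params` lowering.

```cpp
#include <algorithm>

// Host-only stand-in for CUDA's dim3 so this sketch builds without nvcc (assumption).
struct Dim3 { unsigned x = 1, y = 1, z = 1; };

// Hypothetical status codes; not cutlass::Status.
enum class Status { kSuccess, kErrorInvalidProblem };

// Toy data-parallel kernel: one threadblock per 128 x 128 output tile.
struct ToyDataParallelKernel {
  struct Arguments { int m, n; };
  struct Params { int tiles_m, tiles_n; };
  static constexpr int kTile = 128;

  static bool can_implement(Arguments const& a) { return a.m > 0 && a.n > 0; }
  // Toy Arguments -> Params lowering (hypothetical).
  static Params to_underlying_arguments(Arguments const& a) {
    return {(a.m + kTile - 1) / kTile, (a.n + kTile - 1) / kTile};
  }
  static constexpr Dim3 get_block_shape() { return {128, 1, 1}; }
  static Dim3 get_grid_shape(Params const& p) {
    return {unsigned(p.tiles_m), unsigned(p.tiles_n), 1};
  }
};

// Toy persistent kernel: launches at most one threadblock per SM; on a real GPU
// each threadblock would then loop over the remaining tiles inside the kernel.
struct ToyPersistentKernel {
  struct Arguments { int m, n, sm_count; };
  struct Params { int tiles_m, tiles_n, sm_count; };
  static constexpr int kTile = 128;

  static bool can_implement(Arguments const& a) {
    return a.m > 0 && a.n > 0 && a.sm_count > 0;
  }
  static Params to_underlying_arguments(Arguments const& a) {
    return {(a.m + kTile - 1) / kTile, (a.n + kTile - 1) / kTile, a.sm_count};
  }
  static constexpr Dim3 get_block_shape() { return {128, 1, 1}; }
  static Dim3 get_grid_shape(Params const& p) {
    unsigned tiles = unsigned(p.tiles_m) * unsigned(p.tiles_n);
    return {std::min(tiles, unsigned(p.sm_count)), 1, 1};
  }
};

struct LaunchConfig { Dim3 grid, block; };

// Generic adapter: the same code works for any kernel exposing the static query API.
template <class GemmKernel>
class ToyGemmAdapter {
 public:
  using Arguments = typename GemmKernel::Arguments;
  using Params = typename GemmKernel::Params;

  Status initialize(Arguments const& args) {
    if (!GemmKernel::can_implement(args)) return Status::kErrorInvalidProblem;
    params_ = GemmKernel::to_underlying_arguments(args);
    return Status::kSuccess;
  }
  Params const& params() const { return params_; }

  // Static entry point: the caller owns Params, so no lowering happens here.
  static Status run(Params& params, LaunchConfig* launched = nullptr) {
    LaunchConfig cfg{GemmKernel::get_grid_shape(params), GemmKernel::get_block_shape()};
    // A real adapter would launch the device kernel with cfg.grid and cfg.block here.
    if (launched) *launched = cfg;
    return Status::kSuccess;
  }

 private:
  Params params_{};
};
```

For a 4096 x 4096 problem with 128 x 128 tiles there are 1024 tiles: the toy data-parallel kernel asks for a 32 x 32 grid, while the toy persistent kernel, given a hypothetical SM count of 132, asks for only 132 threadblocks. The adapter code is identical in both cases.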
CUTLASS 3.0 provides a meta-function to detect whether a `cutlass::gemm::kernel::*` implements
the 3.x API or 2.x API:
```c++
// include/cutlass/gemm/gemm.h
namespace cutlass::gemm::detail {
// The following metafunction is used to detect whether a
// `kernel::Gemm` or `kernel::GemmUniversal` implements the CUTLASS 3.x API,
// by checking whether the problem shape type is aliased within.
template <class GemmKernel, class = void>
struct IsCutlass3GemmKernel;
} // namespace cutlass::gemm::detail
```
Users can use this metafunction as a type trait for the kernel API version
to dispatch their generic code to 2.x or 3.x specializations.
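The detection idea can be illustrated with the standard `std::void_t` idiom. This is a minimal sketch, not the header's definition: the trait is deliberately named `IsCutlass3GemmKernelSketch`, it assumes the 3.x alias is called `ProblemShape`, the toy kernel types are hypothetical, and it needs C++17 for `std::void_t` and `if constexpr`.

```cpp
#include <string>
#include <type_traits>

// Primary template: by default a kernel is treated as a 2.x kernel.
template <class GemmKernel, class = void>
struct IsCutlass3GemmKernelSketch : std::false_type {};

// This partial specialization is selected only when GemmKernel::ProblemShape
// names a type (assumed name of the 3.x alias).
template <class GemmKernel>
struct IsCutlass3GemmKernelSketch<GemmKernel, std::void_t<typename GemmKernel::ProblemShape>>
    : std::true_type {};

// Hypothetical kernels used only to exercise the trait.
struct ToyProblemShape { int m, n, k, l; };
struct Toy3xKernel { using ProblemShape = ToyProblemShape; };
struct Toy2xKernel {};

// Compile-time dispatch of generic code on the kernel API version.
template <class GemmKernel>
std::string kernel_api_version() {
  if constexpr (IsCutlass3GemmKernelSketch<GemmKernel>::value) {
    return "3.x";  // e.g. call the 3.x static query API here
  } else {
    return "2.x";  // e.g. fall back to 2.x-specific handling
  }
}
```

Because the branch is resolved with `if constexpr`, the discarded branch is not instantiated, so each branch may use members that only exist in one API version.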
## Threadblock API and Inner Loops
Much of the CUTLASS 3 GEMM hierarchy for mainloops and inner loops diverges
from that of CUTLASS 2.x. With that also comes the introduction of the
`cutlass::gemm::collective` layer as a direct replacement and a superset
of the 2.x `cutlass::gemm::threadblock` layer. Going forward,
CUTLASS 3.x will discontinue new development in the following namespaces:
* `cutlass::*::threadblock::*`
* `cutlass::*::warp::*`
* `cutlass::gemm::thread::*`
* `cutlass::arch::*` (except `barrier.h`)
`cutlass::gemm::collective`s are a superset of the threadblock layer where
all new mainloops will be developed. Users should look to the `CollectiveMma` type
if they wish to author custom mainloop code in the 3.x API.
Similarly, for the GEMM inner loops, `cute::MMA_Atom`s replace the
`gemm::warp` and `gemm::thread` layer code. Going forward, all new PTX instructions
and associated metadata development will occur directly inside [`cute/arch/*.hpp`](/include/cute/arch/) and [`cute/atom/*.hpp`](/include/cute/atom/).
The desired inner loop MMA iteration order and tiling can be achieved through careful
selection of the atom layout, value layout, and permutations of the `cute::TiledMma`.
For epilogues, the `cutlass::epilogue::collective` layer replaces `cutlass::epilogue::threadblock`. However, the thread-level epilogue elementwise operations
in `cutlass::epilogue::thread` will continue to be used in 3.x kernels as well, albeit, with
shows how to use 2.x epilogue thread operators with 3.0 API kernels.
## Porting from 2.x to 3.0 API
### CUTLASS 2.x layout tags and CUTLASS 3.0 major modes
CUTLASS 2.x and CUTLASS 3.0 use both
different wording and different types
to describe the permitted layouts
of GEMM's input matrices A and B.
CUTLASS 3.0 does not use the terms "column major"
or "row major" to describe matrix layouts.
Starting with CUTLASS 3.0, adoption of CuTe allows us to decouple
* the coordinate mode order (logical shape) of layouts from
* the index space stride order of the backing storage.
In line with our switch to a conceptual GEMM hierarchy, we do not view the major modes from a BLAS-3 perspective.
Rather, we divide the modes into two categories.
* "Inner modes" or "K-modes" are contracted over during the GEMM.
Therefore, they are not present in the output tensor.
* "Outer modes" or "MN-modes" are preserved in the output.
Now, instead of `RowMajor` or `ColumnMajor`, whose major stride depends on whether we are referring to the
A or the B matrix, we uniformly employ the "K major" or "MN major" terminology and enforce the convention of all tensors having the shape `[M/N, K, L]` regardless of which mode is major. That is,
* the input matrix A has shape M x K,
* the input matrix B has shape N x K, and
* the input/output matrices C/D have shape M x N.
Note that this convention for B
differs from the BLAS's GEMM interface,
which specifies that B has shape K x N.
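A minimal host reference GEMM makes the convention concrete: every operand is indexed by its logical `[M/N, K, L]` coordinates, so B is read as B(n, k) rather than BLAS's B(k, n). The function name `reference_gemm` and the packed storage order chosen for each operand are assumptions for illustration, not CUTLASS's own reference implementation.

```cpp
#include <cstddef>
#include <vector>

// D(m,n,l) = alpha * sum_k A(m,k,l) * B(n,k,l) + beta * C(m,n,l)
// Hypothetical packed storage for this sketch:
//   A(m,k,l) at (l*M + m)*K + k,  B(n,k,l) at (l*N + n)*K + k,
//   C/D(m,n,l) at (l*M + m)*N + n.
std::vector<float> reference_gemm(int M, int N, int K, int L,
                                  std::vector<float> const& A,
                                  std::vector<float> const& B,
                                  std::vector<float> const& C,
                                  float alpha, float beta) {
  std::vector<float> D(std::size_t(M) * N * L);
  for (int l = 0; l < L; ++l)
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n) {
        float acc = 0.f;
        // K is the contracted ("inner") mode: it indexes A and B but not D.
        for (int k = 0; k < K; ++k)
          acc += A[(std::size_t(l) * M + m) * K + k] * B[(std::size_t(l) * N + n) * K + k];
        std::size_t idx = (std::size_t(l) * M + m) * N + n;
        D[idx] = alpha * acc + beta * C[idx];
      }
  return D;
}
```

A BLAS-style K x N matrix B therefore has to be transposed (or described with the appropriate strides) before it matches this N x K indexing.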
CUTLASS 3.0 uses these names of the modes
to specify which mode of a matrix has stride 1.
For the matrix A,
* "M major" means that the matrix is stride 1
in the M mode, and
* "K major" means that the matrix is stride 1
in the K mode.
For the matrix B,
* "N major" means that the matrix is stride 1
in the N mode (which for B is mode 0,
because the convention is that B is N x K); and
* "K major" means that the matrix is stride 1
in the K mode (which for B is mode 1).
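To make the stride-1 terminology concrete, here is a minimal sketch for a packed operand of logical shape MN x K (M x K for A, N x K for B). `Major`, `Strides`, `packed_strides`, `offset`, and `Layout2x` are hypothetical names, not CUTLASS or CuTe types. The mapping from 2.x tags follows from the BLAS shapes: a row-major A (M x K) has stride 1 in K, while a row-major B, stored in BLAS as K x N, has stride 1 in N.

```cpp
#include <cstdint>

// Which mode of an operand stored as (MN x K) has stride 1 (hypothetical enum).
enum class Major { MN, K };

// Strides of a packed MN x K matrix: stride along the MN mode, stride along K.
struct Strides { std::int64_t mn, k; };

constexpr Strides packed_strides(std::int64_t mn_extent, std::int64_t k_extent, Major major) {
  // K major: adjacent k are adjacent in memory, so each MN index skips a whole K row.
  // MN major: adjacent m (or n) are adjacent in memory, so each k skips a whole MN column.
  return major == Major::K ? Strides{k_extent, 1} : Strides{1, mn_extent};
}

// Memory offset of logical element (i, k); the logical shape is MN x K either way.
constexpr std::int64_t offset(Strides s, std::int64_t i, std::int64_t k) {
  return i * s.mn + k * s.k;
}

// Hypothetical 2.x-style tag, interpreted with BLAS shapes A: M x K and B: K x N.
enum class Layout2x { ColumnMajor, RowMajor };

constexpr Major major_of_A(Layout2x t) { return t == Layout2x::RowMajor ? Major::K : Major::MN; }
constexpr Major major_of_B(Layout2x t) { return t == Layout2x::RowMajor ? Major::MN : Major::K; }
```

Only the strides change between the two storage orders; code that indexes by logical coordinates (i, k) reads the same element either way.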
CUTLASS 2.x defines "layout tag" classes
`cutlass::layout::ColumnMajor` and `cutlass::layout::RowMajor`,