# CuTe Layouts

## Layout

This document describes `Layout`, CuTe's core abstraction.
A `Layout` maps from a logical coordinate space
to an index space.

`Layout`s present a common interface to multidimensional array access
that abstracts away the details of how the array's elements are organized in memory.
This lets users write algorithms that access multidimensional arrays generically,
so that layouts can change, without users' code needing to change.

CuTe also provides an "algebra of `Layout`s."
`Layout`s can be combined and manipulated
to construct more complicated layouts
and to partition them across other layouts.
This can help users do things like partition layouts of data over layouts of threads.

## Layouts and Tensors

Any of the `Layout`s discussed in this section can be composed with data -- e.g., a pointer or an array -- to create a `Tensor`.
The `Layout`'s logical coordinate space represents the logical "shape" of the data,
e.g., the modes of the `Tensor` and their extents.
The `Layout` maps a logical coordinate into an index,
which is an offset to be used to index into the array of data.

For details on `Tensor`, please refer to the
[`Tensor` section of the tutorial](./03_tensor.md).

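To make the idea of composing a layout with data concrete, here is a minimal standalone sketch in plain C++. It does not use CuTe: `Layout2D`, `View2D`, and `sum_all` are hypothetical names introduced only for illustration, and CuTe's actual `Tensor` is considerably more general.

```c++
#include <cassert>

// Hypothetical stand-in for a rank-2 layout with shape (rows, cols)
// and strides (stride_m, stride_n). This is not CuTe's Layout type.
struct Layout2D {
  int rows, cols;
  int stride_m, stride_n;

  // Map a logical coordinate (m, n) to an index (an offset into the data).
  int operator()(int m, int n) const {
    assert(0 <= m && m < rows && 0 <= n && n < cols);
    return m * stride_m + n * stride_n;
  }
};

// Hypothetical stand-in for a tensor: a pointer composed with a layout.
template <class T>
struct View2D {
  T* data;
  Layout2D layout;

  // Access by logical coordinate; the layout decides where the element lives.
  T& operator()(int m, int n) const { return data[layout(m, n)]; }
};

// Sum all elements by logical coordinate. The loop does not mention strides,
// so it works unchanged for any strides the layout happens to use.
inline int sum_all(View2D<const int> v) {
  int total = 0;
  for (int m = 0; m < v.layout.rows; ++m)
    for (int n = 0; n < v.layout.cols; ++n)
      total += v(m, n);
  return total;
}
```

Wrapping the same 8-element buffer once with strides `(1,4)` and once with strides `(2,1)` changes which element `(1,1)` refers to, while `sum_all` stays the same. That is the sense in which the algorithm is independent of the layout.
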
## Shapes and Strides

A `Layout` is a pair of `Shape` and `Stride`.
Both `Shape` and `Stride` are `IntTuple` types.

### IntTuple

An `IntTuple` is defined recursively as either a single integer, or a tuple of `IntTuple`s.
This means that `IntTuple`s can be arbitrarily nested.
Operations defined on `IntTuple`s include the following.

* `get<I>(IntTuple)`: The `I`th element of the `IntTuple`. For an `IntTuple` consisting of a single integer, `get<0>` is just that integer.

* `rank(IntTuple)`: The number of elements in an `IntTuple`. A single integer has rank 1, and a tuple has rank `tuple_size`.

* `depth(IntTuple)`: The number of hierarchical `IntTuple`s. A single integer has depth 0, a tuple of integers has depth 1, a tuple that contains a tuple of integers has depth 2, etc.

* `size(IntTuple)`: The product of all elements of the `IntTuple`.

We write `IntTuple`s with parentheses to denote the hierarchy. For example, `6`, `(2)`, `(4,3)`, `(3,(6,2),8)` are all `IntTuple`s.

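The recursive definitions of `rank`, `depth`, and `size` can be sketched in standalone C++17 by modeling an `IntTuple` as either an `int` or a `std::tuple` of such things. This is only a model of the definitions above, not CuTe's implementation; the names `is_tuple`, `rank_of`, `depth_of`, `size_of`, and `int_tuple_examples` are hypothetical.

```c++
#include <algorithm>
#include <cassert>
#include <tuple>
#include <type_traits>

// Hypothetical model: an "IntTuple" is an int or a std::tuple of IntTuples.
template <class T> struct is_tuple : std::false_type {};
template <class... Ts> struct is_tuple<std::tuple<Ts...>> : std::true_type {};

// rank: 1 for a single integer, tuple_size for a tuple.
template <class T>
constexpr int rank_of(T const&) {
  if constexpr (is_tuple<T>::value) return int(std::tuple_size<T>::value);
  else return 1;
}

// depth: 0 for an integer, 1 + the maximum depth of the elements for a tuple.
template <class T>
constexpr int depth_of(T const& t) {
  if constexpr (is_tuple<T>::value) {
    return 1 + std::apply([](auto const&... e) {
      int d = 0;
      ((d = std::max(d, depth_of(e))), ...);
      return d;
    }, t);
  } else {
    return 0;
  }
}

// size: the product of all integers, at every level of nesting.
template <class T>
constexpr int size_of(T const& t) {
  if constexpr (is_tuple<T>::value) {
    return std::apply([](auto const&... e) { return (1 * ... * size_of(e)); }, t);
  } else {
    return t;
  }
}

// The example (3,(6,2),8) from the text.
inline void int_tuple_examples() {
  auto t = std::make_tuple(3, std::make_tuple(6, 2), 8);
  assert(rank_of(t) == 3);              // three top-level elements
  assert(depth_of(t) == 2);             // a tuple containing a tuple
  assert(size_of(t) == 3 * 6 * 2 * 8);  // product of all integers
  assert(rank_of(6) == 1 && depth_of(6) == 0 && size_of(6) == 6);
}
```
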
## Layout

A `Layout` is then a pair of `IntTuple`s. The first element defines the abstract *shape* of the `Layout`, and the second element defines the *strides*, which map from coordinates within the shape to the index space.

Since a `Layout` is just a pair of `IntTuple`s, we can define operations on `Layout`s analogous to those defined on `IntTuple`.

* `get<I>(Layout)`: The `I`th sub-layout of the `Layout`.

* `rank(Layout)`: The number of modes in a `Layout`.

* `depth(Layout)`: The number of hierarchical `Layout`s. A single integer has depth 0, a tuple of integers has depth 1, a tuple that contains a tuple of integers has depth 2, etc.

* `shape(Layout)`: The shape of the `Layout`.

* `stride(Layout)`: The stride of the `Layout`.

* `size(Layout)`: The logical extent of the `Layout`. Equivalent to `size(shape(Layout))`.

### Hierarchical access functions

`IntTuple`s and thus `Layout`s can be arbitrarily nested.
For convenience, we define versions of some of the above functions
that take a sequence of integers, instead of just one integer.
This makes it possible to access elements
inside a nested `IntTuple` or `Layout`.
For example, we permit `get<I...>(x)`, where `I...` here
and throughout this section is a "C++ parameter pack"
that denotes zero or more (integer) template arguments.
That is, `get<I0,I1,...,IN>(x)` is equivalent to
`get<IN>(` $\dots$ `(get<I1>(get<I0>(x)))` $\dots$ `)`,
where the ellipses are pseudocode and not actual C++ syntax.
These hierarchical access functions include the following.

* `rank<I...>(x) := rank(get<I...>(x))`. The rank of the `I...`th element of `x`.

* `depth<I...>(x) := depth(get<I...>(x))`. The depth of the `I...`th element of `x`.

* `size<I...>(x) := size(get<I...>(x))`. The size of the `I...`th element of `x`.

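The nested-`get` equivalence can be sketched in standalone C++17 over `std::tuple`. This is an illustration of the definition, not CuTe's code: `get_nested`, `rank_nested`, and `hierarchical_examples` are hypothetical names. For brevity the sketch requires at least one index and assumes every indexed element is a tuple, so the empty-pack and single-integer cases described above are not modeled.

```c++
#include <cassert>
#include <cstddef>
#include <tuple>
#include <type_traits>

// get_nested<I0, I1, ..., IN>(x) is get<IN>( ... get<I1>(get<I0>(x)) ... ).
// The leftmost index is applied first, then the remaining indices recurse.
template <std::size_t I0, std::size_t... Is, class T>
constexpr decltype(auto) get_nested(T const& x) {
  if constexpr (sizeof...(Is) == 0) {
    return std::get<I0>(x);
  } else {
    return get_nested<Is...>(std::get<I0>(x));
  }
}

// rank_nested<I...>(x) := rank(get_nested<I...>(x)), for tuple-valued elements.
template <std::size_t... Is, class T>
constexpr std::size_t rank_nested(T const& x) {
  using Elem = std::decay_t<decltype(get_nested<Is...>(x))>;
  return std::tuple_size<Elem>::value;
}

inline void hierarchical_examples() {
  auto x = std::make_tuple(3, std::make_tuple(6, 2), 8);  // (3,(6,2),8)
  assert((get_nested<0>(x) == 3));
  assert((get_nested<1, 0>(x) == 6));  // get<0>(get<1>(x))
  assert((get_nested<1, 1>(x) == 2));  // get<1>(get<1>(x))
  assert((rank_nested<1>(x) == 2));    // (6,2) has two elements
}
```
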
### Vector examples

We define a vector as any `Shape` and `Stride` pair with `rank == 1`.
For example, the `Layout`

```
Shape:  (8)
Stride: (1)
```

defines a contiguous 8-element vector.
For a vector with the same Shape but a Stride of `(2)`,
the interpretation is that the eight elements
are stored at positions 0, 2, 4, $\dots$, 14.

By the above definition, we *also* interpret

```
Shape:  ((4,2))
Stride: ((1,4))
```

as a vector, since its shape is rank 1. The inner shape describes a 4x2 layout of data in column-major order, but the extra pair of parentheses suggests we can interpret those two modes as a single 1-D 8-element vector instead. Due to the strides, the elements are also contiguous.

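The two vector interpretations above can be checked with a few lines of standalone C++. The function names are hypothetical, and the split of the 1-D coordinate into `(i0, i1)` assumes the leftmost mode varies fastest, which is consistent with the column-major reading in the text.

```c++
#include <cassert>

// Index of element i for Shape (8) with a rank-1 Stride (stride).
inline int strided_index(int i, int stride) {
  assert(0 <= i && i < 8);
  return i * stride;
}

// Index of element i for Shape ((4,2)), Stride ((1,4)).
// The 1-D coordinate i is split into (i0, i1) with i0 in [0,4), i1 in [0,2),
// and each part is multiplied by its own stride.
inline int nested_vector_index(int i) {
  assert(0 <= i && i < 8);
  int i0 = i % 4;
  int i1 = i / 4;
  return i0 * 1 + i1 * 4;
}
```

With a stride of 2, `strided_index(7, 2)` is 14, the last of the positions 0, 2, 4, ..., 14. For the nested layout, `nested_vector_index(i)` equals `i` for every `i` from 0 to 7, which is what "contiguous" means here.
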
### Matrix examples

Generalizing, we define a matrix as any `Shape` and `Stride` pair with rank 2. For example,

```
Shape:  (4,2)
Stride: (1,4)
  0   4
  1   5
  2   6
  3   7
```

is a 4x2 column-major matrix, and

```
Shape:  (4,2)
Stride: (2,1)
  0   1
  2   3
  4   5
  6   7
```

is a 4x2 row-major matrix.

Each of the modes of the matrix can also be split into *multi-indices* like the vector example.
This lets us express more layouts beyond just row major and column major. For example,

```
Shape:  ((2,2),2)
Stride: ((4,1),2)
  0   2
  4   6
  1   3
  5   7
```

is also logically 4x2, with a stride of 2 across the rows but a multi-stride down the columns.
Since this layout is logically 4x2,
like the column-major and row-major examples above,
we can _still_ use 2-D coordinates to index into it.

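As a hedged illustration, the three tables above can be reproduced with plain index arithmetic in standalone C++. The function names are hypothetical. For the third layout, the row coordinate `m` is split into `(m0, m1)` with `m0` varying fastest, and each part uses its own stride from `(4,1)`.

```c++
#include <cassert>

// Shape (4,2), Stride (1,4): column-major.
inline int col_major_index(int m, int n) {
  assert(0 <= m && m < 4 && 0 <= n && n < 2);
  return m * 1 + n * 4;
}

// Shape (4,2), Stride (2,1): row-major.
inline int row_major_index(int m, int n) {
  assert(0 <= m && m < 4 && 0 <= n && n < 2);
  return m * 2 + n * 1;
}

// Shape ((2,2),2), Stride ((4,1),2), still indexed with a 2-D coordinate (m, n).
inline int multi_stride_index(int m, int n) {
  assert(0 <= m && m < 4 && 0 <= n && n < 2);
  int m0 = m % 2;  // fastest-varying part of the row coordinate, stride 4
  int m1 = m / 2;  // slower part of the row coordinate, stride 1
  return m0 * 4 + m1 * 1 + n * 2;
}
```

For instance, `multi_stride_index(1, 0)` is 4 and `multi_stride_index(2, 1)` is 3, matching the second and third rows of the last table.
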
## Constructing a `Layout`

A `Layout` can be constructed in many different ways.
It can include any combination of compile-time (static) integers
or run-time (dynamic) integers.

```c++
// Rank-1 layouts: a static extent and a dynamic extent.
auto layout_8s = make_layout(Int<8>{});
auto layout_8d = make_layout(8);

// Rank-2 layouts built from a shape only: all-static, and mixed static/dynamic.
auto layout_2sx4s = make_layout(make_shape(Int<2>{},Int<4>{}));
auto layout_2sx4d = make_layout(make_shape(Int<2>{},4));

// Rank-2 layout with a nested second mode and explicit strides.
auto layout_2x4 = make_layout(make_shape (2, make_shape (2,2)),
                              make_stride(4, make_stride(2,1)));
```

The `make_layout` function returns a `Layout`.
It deduces the returned `Layout`'s template arguments from the function's arguments.
Similarly, the `make_shape` and `make_stride` functions
return a `Shape` and a `Stride`, respectively.
CuTe often uses these `make_*` functions,
because class template argument deduction (CTAD)
does not work for `cute::tuple` as it works for `std::tuple`.

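The difference between static and dynamic integers, and the role of a `make_*` factory function, can be sketched in standalone C++17. Nothing here is CuTe code: `StaticInt`, `Layout1D`, `make_layout_1d`, and `demo_static_dynamic` are hypothetical names, and `std::integral_constant` is used only as a rough analogue of a static integer whose value is carried by its type.

```c++
#include <cassert>
#include <type_traits>

// Hypothetical analogue of a static integer: the value is part of the type.
template <int N>
using StaticInt = std::integral_constant<int, N>;

// Hypothetical rank-1 layout whose extent is either a StaticInt<N> or an int.
template <class Extent>
struct Layout1D {
  Extent extent;
  constexpr int size() const { return extent; }  // both kinds convert to int
};

// Factory in the style of make_layout: the template argument is deduced
// from the function argument, so callers never spell out Layout1D<...>.
template <class Extent>
constexpr Layout1D<Extent> make_layout_1d(Extent e) { return {e}; }

// A static extent is usable in constant expressions...
static_assert(make_layout_1d(StaticInt<8>{}).size() == 8, "known at compile time");

// ...and the two kinds of extent produce different layout types.
static_assert(std::is_same<decltype(make_layout_1d(StaticInt<8>{})),
                           Layout1D<StaticInt<8>>>::value, "static extent");
static_assert(std::is_same<decltype(make_layout_1d(8)), Layout1D<int>>::value,
              "dynamic extent");

// A dynamic extent is an ordinary run-time value.
inline void demo_static_dynamic() {
  int n = 8;
  auto d = make_layout_1d(n);
  assert(d.size() == 8);
}
```

The design point this is meant to illustrate is that the same factory call works for both kinds of integer, and the distinction is recorded in the returned type rather than in the call syntax.
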
## Using a `Layout`

The fundamental use of a `Layout` is to map between logical coordinate space(s) and an index space. For example, to print an arbitrary rank-2 layout, we can write the function

```c++
#include <cstdio>            // printf
#include <cute/layout.hpp>   // Layout, size

using namespace cute;

template <class Shape, class Stride>
void print2D(Layout<Shape,Stride> const& layout)
{
  // Visit every 2-D coordinate (m,n) and print the index it maps to.
  for (int m = 0; m < size<0>(layout); ++m) {
    for (int n = 0; n < size<1>(layout); ++n) {
      printf("%3d  ", layout(m,n));
    }
    printf("\n");
  }
}
```

which produces the following output for the above examples.

```
> print2D(layout_2sx4s)
  0    2    4    6
  1    3    5    7
> print2D(layout_2sx4d)
  0    2    4    6
  1    3    5    7
> print2D(layout_2x4)
  0    2    1    3
  4    6    5    7
```

The multi-indices within the `layout_2x4` example are handled as expected and interpreted as a rank-2 layout.

Note that for `layout_2x4`, we're using a 1-D coordinate for a 2-D multi-index in the second mode. In fact, we can generalize this and treat all of the above layouts as 1-D layouts. For instance, the following `print1D` function

```c++
template <class Shape, class Stride>
void print1D(Layout<Shape,Stride> const& layout)
{
  // Visit every 1-D coordinate i and print the index it maps to.
  for (int i = 0; i < size(layout); ++i) {
    printf("%3d  ", layout(i));
  }
  printf("\n");
}
```

produces the following output for the above examples.

```
> print1D(layout_8s)
  0    1    2    3    4    5    6    7
> print1D(layout_8d)
  0    1    2    3    4    5    6    7
> print1D(layout_2sx4s)
  0    1    2    3    4    5    6    7
> print1D(layout_2sx4d)
  0    1    2    3    4    5    6    7
> print1D(layout_2x4)
  0    4    2    6    1    5    3    7
```

This shows explicitly that all of the layouts are simply folded views of an 8-element array.

## Summary

* The `Shape` of a `Layout` defines its coordinate space(s).

    * Every `Layout` has a 1-D coordinate space.
      This can be used to iterate in a "generalized-column-major" order.

    * Every `Layout` has an R-D coordinate space,
      where R is the rank of the layout.
      These spaces are ordered _colexicographically_
      (reading right to left, instead of "lexicographically,"
      which reads left to right).
      The enumeration of that order
      corresponds to the 1-D coordinates above.

    * Every `Layout` has an h-D coordinate space where h is "hierarchical." These are ordered colexicographically and the enumeration of that order corresponds to the 1-D coordinates above. An h-D coordinate is congruent to the `Shape` so that each element of the coordinate has a corresponding element of the `Shape`.

* The `Stride` of a `Layout` maps coordinates to indices.

    * In general, this could be any function from 1-D coordinates (integers) to indices (integers).

    * In `CuTe` we use an inner product of the h-D coordinates with the `Stride` elements.

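The summary can be condensed into a short standalone C++ sketch: unpack a 1-D coordinate colexicographically (leftmost mode fastest), then take the inner product with the strides. This is a model under stated assumptions, not CuTe's implementation. The names `index_of_1d` and `enumerate_indices` are hypothetical, and the shape and stride are assumed to have been flattened left to right, so that shape `(2,(2,2))` with stride `(4,(2,1))` becomes `{2,2,2}` and `{4,2,1}`.

```c++
#include <cassert>
#include <cstddef>
#include <vector>

// Map a 1-D coordinate to an index for a flattened shape/stride pair.
// The coordinate is unpacked colexicographically (leftmost mode varies
// fastest), and the index is the inner product of the unpacked coordinate
// with the strides.
inline int index_of_1d(int i, std::vector<int> const& shape,
                       std::vector<int> const& stride) {
  assert(shape.size() == stride.size());
  int index = 0;
  for (std::size_t k = 0; k < shape.size(); ++k) {
    int c = i % shape[k];  // coordinate within mode k
    i /= shape[k];         // remaining coordinate for the slower modes
    index += c * stride[k];
  }
  return index;
}

// Enumerate the indices for all 1-D coordinates 0, 1, ..., size-1.
inline std::vector<int> enumerate_indices(std::vector<int> const& shape,
                                          std::vector<int> const& stride) {
  int size = 1;
  for (int extent : shape) size *= extent;
  std::vector<int> out;
  for (int i = 0; i < size; ++i) out.push_back(index_of_1d(i, shape, stride));
  return out;
}
```

Under these assumptions, `enumerate_indices({2, 2, 2}, {4, 2, 1})` yields `0 4 2 6 1 5 3 7`, the same sequence as the `print1D(layout_2x4)` output shown earlier, and `enumerate_indices({2, 4}, {1, 2})` yields `0` through `7` in order.
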