// Repository scrape metadata (not part of the original source file):
//   cutlass/cutlass/tile_iterator.h
//   Andrew Kerr 877bdcace6 — Cutlass 1.3 Release (#42)
//   CUTLASS 1.3 Release
//   - Efficient GEMM kernel targeting Volta Tensor Cores via mma.sync instruction added in CUDA 10.1.
//   2019-03-20 10:49:17 -07:00
//   1292 lines, 42 KiB, C++
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines the Tile Traits concept and iterators for loading and storing to tiles
efficiently.
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/fragment.h"
#include "cutlass/load_store.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/vector.h"
#include <cstdio>
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
/*!@defgroup tile_traits_concept Tile Traits Concept
@{
\ref tile_traits_concept is a type defining the shape of a tile and the distribution of accesses
by individual entities, either threads or other.
@par Tile Traits Concept
Types satisfying \ref tile_traits_concept define the following members
- <b>Tile</b> - a type satisfying \ref layout_concept describing the dimensions of the tile
- <b>Delta</b> - a type satisfying \ref layout_concept describing the increments between accesses
along each dimension
- <b>Iterations</b> - a type satisfying \ref layout_concept describing the number of accesses
along each dimension
- <b>Offset</b> - the type of a <i>functor</i> computing the offset of each participating entity
as a Coord<4>.
@}
*/
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Specifies dimension in which post-increment accesses advance
struct IteratorAdvance {
  // kD/kH/kW name the tile dimension along which post-increment accesses
  // advance (consumed as TileIteratorBase::kAdvance below).
  enum Kind { kD, kH, kW };
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/**
* @brief A template defining \ref tile_traits_concept
* @concept{tile_traits_concept}
*/
template <typename Tile_,
          typename Delta_,
          typename Iterations_,
          typename ThreadOffset_,
          int AccessSize>
struct TileTraits {
  /// Shape of the tile
  typedef Tile_ Tile;
  /// Number of steps between accesses along each dimension
  typedef Delta_ Delta;
  /// Number of accesses performed along each dimension
  typedef Iterations_ Iterations;
  /// Functor that returns the logical coordinate of each entity's initial offset in the tile
  //
  // ThreadOffset should be a functor defined like:
  //
  //    struct ThreadOffsetExample {
  //      CUTLASS_DEVICE
  //      Coord<4> operator()() const {
  //        return make_Coord(0, threadIdx.y, threadIdx.x, 0);
  //      }
  //    };
  //
  typedef ThreadOffset_ ThreadOffset;
  /// Strides for immediate offset computation.
  // All-zero default: per-access offsets contribute nothing unless a
  // specialization overrides this (see ComputeOffsetFromStrides usage below).
  typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
  /// Number of scalars accessed per memory operation
  static int const kAccessSize = AccessSize;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Functor computing a predicate given the logical position of an access
template <typename Delta_>
struct RegularTilePredicateFunctor {
  typedef Delta_ Delta;

  /// Extent of the valid region along each of the D, H, W axes
  Coord<3> bounds;

  /// Captures the bounding extents used to guard each access
  CUTLASS_HOST_DEVICE
  RegularTilePredicateFunctor(Coord<3> _bounds) : bounds(_bounds) {}

  /// True when the access identified by (iteration, offset) lies inside
  /// bounds along all three axes; each axis position is the iteration
  /// index scaled by that axis's delta plus the base offset.
  CUTLASS_HOST_DEVICE
  bool operator()(Coord<3> iteration, Coord<3> offset) const {
    bool const within_d = iteration[0] * Delta::kD + offset[0] < bounds[0];
    bool const within_h = iteration[1] * Delta::kH + offset[1] < bounds[1];
    bool const within_w = iteration[2] * Delta::kW + offset[2] < bounds[2];
    return within_d && within_h && within_w;
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// NOTE(review): empty helper template — appears to be a compile-time
// type-inspection/debugging aid; no uses are visible in this file.
template <typename T>
struct DumpType {};
/// Iterator for accessing a stripmined tile in memory
template <typename Traits_,
          typename Scalar_,
          IteratorAdvance::Kind Advance_ = IteratorAdvance::kH,
          MemorySpace::Kind MemorySpace = MemorySpace::kGeneric,
          typename Index_ = int,
          typename FragmentElement_ = Scalar_,
          FragmentElementType::Kind FragmentElementType_ = FragmentElementType::kScalar,
          typename Skew_ = Shape<0, 0, 0, 0> >
struct TileIteratorBase {
  /// concept TileTraits
  typedef Traits_ Traits;
  /// Scalar element
  typedef Scalar_ Scalar;
  /// Fragment element
  typedef FragmentElement_ FragmentElement;
  /// Specifies dimension in which post-increment accesses advance.
  static IteratorAdvance::Kind const kAdvance = Advance_;
  /// Specifies iterator storage fragment type (Scalar or WmmaMatrix)
  static FragmentElementType::Kind const kFragmentElementType = FragmentElementType_;
  /// Source or destination memory space
  static MemorySpace::Kind const kMemorySpace = MemorySpace;
  /// Index type used for stride and increment arithmetic
  typedef Index_ Index;
  /// Long index type for raw pointer offsets
  typedef long long LongIndex;
  /// Skew quantity
  typedef Skew_ Skew;
  /// Tile shape
  typedef typename Traits::Tile Tile;
  /// Distance along each dimension between successive accesses
  typedef typename Traits::Delta Delta;
  /// The strides in each dimension between different loads/stores.
  typedef typename Traits::ImmediateOffsetStrides ImmediateOffsetStrides;
  /// Number of accesses performed along each dimension
  typedef typename Traits::Iterations Iterations;
  /// Functor mapping each thread to its initial offset within the tile
  typedef typename Traits::ThreadOffset ThreadOffset;
  /// The number of scalars accessed per load/store.
  static int const kAccessSize = Traits::kAccessSize;
  /// The elements loaded/stored by one instruction.
  typedef typename Vectorize<FragmentElement, kAccessSize>::Type AccessType;
  /// The size of storage needed per fragment element; WmmaMatrix fragments
  /// use a fixed 16-byte granularity.
  static int const kFragmentSize =
      (kFragmentElementType == FragmentElementType::kWmmaMatrix ? 16 : sizeof(AccessType));
  /// The storage.
  typedef Fragment<Scalar, ShapeCount<Tile>::kCount, kFragmentSize> Storage;
  /// The fragment.
  typedef Fragment<FragmentElement, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
  /// The fragment iterator.
  typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
  /// The fragment const iterator.
  typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
  /// The shape of the fragment.
  typedef typename FragmentIterator::FragmentShape FragmentShape;
  /// Default predicate mask type
  typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;

  //
  // Params struct
  //

  /// Parameters to the iterator: strides between elements and the pointer
  /// increments applied after each W/H/D iteration and after a full tile.
  struct Params {
    //
    // Data members
    //

    Index stride_d;
    Index stride_h;
    Index stride_w;
    Index inc_d;
    Index inc_h;
    Index inc_w;
    Index inc_advance;

    //
    // Methods
    //

    /// Constructs params. Every field — including inc_advance, which the
    /// original code left uninitialized — defaults to zero so that a
    /// default-constructed Params is fully defined.
    CUTLASS_HOST_DEVICE
    Params()
        : stride_d(0), stride_h(0), stride_w(0), inc_d(0), inc_h(0), inc_w(0), inc_advance(0) {}

    /// Constructs params from explicit strides and increments
    CUTLASS_HOST_DEVICE
    Params(Index _stride_d,
           Index _stride_h,
           Index _stride_w,
           Index _inc_d,
           Index _inc_h,
           Index _inc_w,
           Index _inc_advance)
        : stride_d(_stride_d),
          stride_h(_stride_h),
          stride_w(_stride_w),
          inc_d(_inc_d),
          inc_h(_inc_h),
          inc_w(_inc_w),
          inc_advance(_inc_advance) {}

    /// Constructs params with a stride vector
    CUTLASS_HOST_DEVICE
    Params(Coord<4> const &stride) {
      initialize(stride);
    }

    /// Initializes params from explicit strides and increments; returns 0 on success
    CUTLASS_HOST_DEVICE
    int initialize(Index _stride_d,
                   Index _stride_h,
                   Index _stride_w,
                   Index _inc_d,
                   Index _inc_h,
                   Index _inc_w,
                   Index _inc_advance) {
      stride_d = _stride_d;
      stride_h = _stride_h;
      stride_w = _stride_w;
      inc_d = _inc_d;
      inc_h = _inc_h;
      inc_w = _inc_w;
      inc_advance = _inc_advance;
      return 0;
    }

    /// Initializes the parameters object from a vector of strides.
    // NOTE(review): only the first three components of the Coord<4> are
    // consumed (D, H, W strides) — confirm callers pass strides in that order.
    CUTLASS_HOST_DEVICE
    int initialize(Coord<4> const &stride) {
      return initialize(stride[0], stride[1], stride[2]);
    }

    /// Initializes the parameters object from D/H/W strides, deriving the
    /// per-dimension pointer increments. Each inc_* advances by one delta in
    /// its own dimension while rewinding the increments already applied by
    /// the faster-varying dimensions.
    CUTLASS_HOST_DEVICE
    int initialize(Index _stride_d, Index _stride_h, Index _stride_w) {
      stride_d = _stride_d;
      stride_h = _stride_h;
      stride_w = _stride_w;
      inc_w = stride_w * Delta::kW;
      inc_h = stride_h * Delta::kH - stride_w * Delta::kW * (Iterations::kW - 1);
      // Fixed: the D-dimension step must be scaled by stride_d (the original
      // multiplied Delta::kD by stride_h, mixing the H stride into the D step).
      inc_d = stride_d * Delta::kD - stride_h * Delta::kH * (Iterations::kH - 1) -
              stride_w * Delta::kW * (Iterations::kW - 1);
      inc_advance = 0;
      if (kAdvance == IteratorAdvance::kH) {
        // Advance in the H dimension.
        inc_advance = Tile::kH * stride_h;
      } else if (kAdvance == IteratorAdvance::kW) {
        // Advance in the W dimension.
        inc_advance = Tile::kW * stride_w;
      } else {
        // Advance in the D dimension.
        inc_advance = Tile::kD * stride_d;
      }
      // Rewind the increments accumulated while traversing the tile.
      // Fixed: the D rewind term uses stride_d (was stride_h * Delta::kD).
      inc_advance -= stride_d * Delta::kD * (Iterations::kD - 1) +
                     stride_h * Delta::kH * (Iterations::kH - 1) +
                     stride_w * Delta::kW * (Iterations::kW - 1);
      return 0;
    }

    /// Initializes params to a packed, W-contiguous layout with no
    /// inter-tile advance.
    CUTLASS_HOST_DEVICE int initialize() {
      stride_d = 0;
      stride_h = 0;
      stride_w = 1;
      inc_advance = 0;
      inc_d = inc_h = inc_w = 0;
      return 0;
    }
  };

  /// Is the iterator valid? Always true in the base class; derived iterators
  /// may refine this with bounds checks.
  CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }

  //
  // Static function members
  //

  /// Initializes a predicate vector: evaluates predicate_func for every
  /// (d, h, w) iteration at the given thread/block offset and records the
  /// result, advancing predicate_it once per W access (matching the
  /// pred_it consumption order in load/store_post_increment).
  template <typename PredicateIterator, typename PredicateFunctor>
  CUTLASS_HOST_DEVICE static void initialize_predicates(PredicateIterator predicate_it,
                                                        PredicateFunctor const &predicate_func,
                                                        Coord<3> const &offset) {
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          bool enable = predicate_func(make_Coord(d, h, w), offset);
          predicate_it.set(enable);
          ++predicate_it;
        }
      }
    }
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/*!@defgroup tile_load_iterator_concept Tile Load Iterator Concept
@{
\ref tile_load_iterator_concept enables loading a tile from addressable memory into a fragment
@par Tile Load Iterator Concept
Types satisfying \ref tile_load_iterator_concept define the following members
- <b>PredicateVector</b> - a \ref predicate_vector_concept with sufficient predicate storage for
each access implied by the tile traits
- <b>Fragment</b> - the destination fragment type satisfying \ref fragment_concept
- <b>initialize_predicates(pred_it, bounds, block_offset)</b> - function initializing a predicate
vector according to externally specified bounds
- <b>load_post_increment(fragment, pred_it)</b> - a method that loads a fragment and increments
the iterator to the next tile, guarded by a \ref predicate_iterator_concept
- <b>load_post_increment(fragment)</b> - a method that loads a fragment and increments the
iterator to the next tile
- <b>load(fragment, pred_it)</b> - a const method that loads a fragment, guarded by a \ref
predicate_iterator_concept
- <b>load(fragment)</b> - a method that loads a fragment
@}
*/
///////////////////////////////////////////////////////////////////////////////////////////////////
/**
* @brief An iterator implementing \ref tile_load_iterator_concept for loading a tile from memory
* @concept{tile_load_iterator_concept}
*/
template <typename Traits_,
          typename Scalar_,
          IteratorAdvance::Kind Advance_ = IteratorAdvance::kH,
          MemorySpace::Kind MemorySpace = MemorySpace::kGeneric,
          typename Index_ = int,
          typename FragmentElement_ = Scalar_,
          FragmentElementType::Kind FragmentElementType_ = FragmentElementType::kScalar,
          typename Skew_ = Shape<0, 0, 0, 0> >
struct TileLoadIterator : public TileIteratorBase<Traits_,
                                                  Scalar_,
                                                  Advance_,
                                                  MemorySpace,
                                                  Index_,
                                                  FragmentElement_,
                                                  FragmentElementType_,
                                                  Skew_> {
  /// Base class
  typedef TileIteratorBase<Traits_,
                           Scalar_,
                           Advance_,
                           MemorySpace,
                           Index_,
                           FragmentElement_,
                           FragmentElementType_,
                           Skew_>
      Base;
  /// concept TileTraits
  typedef typename Base::Traits Traits;
  /// Scalar element
  typedef typename Base::Scalar Scalar;
  /// Fragment element
  typedef FragmentElement_ FragmentElement;
  /// Specifies in which dimension post-increment accesses advance.
  static IteratorAdvance::Kind const kAdvance = Base::kAdvance;
  /// Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
  static FragmentElementType::Kind const kFragmentElementType = FragmentElementType_;
  /// Source or destination memory space
  static MemorySpace::Kind const kMemorySpace = Base::kMemorySpace;
  /// Index type
  typedef typename Base::Index Index;
  /// Long index type for raw pointer offsets
  typedef typename Base::LongIndex LongIndex;
  /// Skew quantity
  typedef typename Base::Skew Skew;
  /// Tile shape
  typedef typename Base::Tile Tile;
  /// Delta
  typedef typename Base::Delta Delta;
  /// Iterations
  typedef typename Base::Iterations Iterations;
  /// ThreadOffset functor
  typedef typename Base::ThreadOffset ThreadOffset;
  /// Fragment type
  typedef typename Base::FragmentShape FragmentShape;
  /// Memory access type
  typedef typename Base::AccessType AccessType;
  /// The number of scalars accessed per load/store.
  static int const kAccessSize = Base::kAccessSize;
  /// Fragment definition
  typedef typename Base::Fragment Fragment;
  /// Fragment iterator definition
  typedef typename Base::FragmentIterator FragmentIterator;
  /// Fragment const iterator definition
  typedef typename Base::FragmentConstIterator FragmentConstIterator;
  /// Default predicate mask type
  typedef typename Base::PredicateVector PredicateVector;
  /// Storage object that may be loaded from
  typedef typename Base::Storage SharedStorage;
  /// IteratorBase parameters
  typedef typename Base::Params BaseParams;
  /// Do we require a fence?
  // NOTE(review): true exactly when the tile has a single stage (Tile::kD == 1);
  // presumably single-stage buffers are reused in place and need a fence
  // between load and store — confirm against kernel usage.
  enum { kRequiresLoadFence = Tile::kD == 1 };
  /// The pointer type
  typedef Scalar const *Pointer;
  /// Tensor reference for the load iterator
  typedef TensorRef<Scalar const, 4> TensorRef;

  /// Parameters: base strides/increments plus the source pointer.
  struct Params : public BaseParams {
    /// Pointer to memory the iterator loads from
    Scalar const *pointer;

    //
    // Methods
    //

    /// Default constructor: null pointer, packed default strides
    CUTLASS_HOST_DEVICE
    Params() : pointer(0) { Base::Params::initialize(); }

    /// Initialize params to access a raw pointer with default (packed) strides
    CUTLASS_HOST_DEVICE
    Params(Scalar const *ptr) : pointer(ptr) { Base::Params::initialize(); }

    /// Constructs with a CompactTensorRef<> — pointer and strides taken from ref
    CUTLASS_HOST_DEVICE
    Params(TensorRef const &ref) : pointer(ref.data()) {
      Base::Params::initialize(ref.stride());
    }

    /// Constructs from a pointer and fully explicit strides/increments
    CUTLASS_HOST_DEVICE
    Params(Scalar const *ptr,
           Index _stride_d,
           Index _stride_h,
           Index _stride_w,
           Index _inc_d,
           Index _inc_h,
           Index _inc_w,
           Index _inc_advance)
        : pointer(ptr) {
      Base::Params::initialize(
          _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
    }

    /// Constructs from a pointer and D/H/W strides; increments are derived
    CUTLASS_HOST_DEVICE
    Params(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w)
        : pointer(ptr) {
      Base::Params::initialize(stride_d, stride_h, stride_w);
    }

    /// Initializes params from a TensorRef (pointer and strides)
    CUTLASS_HOST_DEVICE
    int initialize(TensorRef const &ref) {
      pointer = ref.data();
      return Base::Params::initialize(ref.stride());
    }

    /// Initialize params to access a shared-storage object with default strides
    CUTLASS_HOST_DEVICE
    int initialize(SharedStorage const &storage) {
      pointer = &storage[0];
      Base::Params::initialize();
      return 0;
    }

    /// Initialize params to access a raw pointer with default strides
    CUTLASS_HOST_DEVICE
    int initialize(Scalar const *ptr) {
      pointer = ptr;
      Base::Params::initialize();
      return 0;
    }

    /// Initializes params from a raw pointer and D/H/W strides
    CUTLASS_HOST_DEVICE
    int initialize(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w) {
      Base::Params::initialize(stride_d, stride_h, stride_w);
      pointer = ptr;
      return 0;
    }

    /// Initializes params from a raw pointer and explicit strides/increments
    CUTLASS_HOST_DEVICE
    int initialize(Scalar const *ptr,
                   Index _stride_d,
                   Index _stride_h,
                   Index _stride_w,
                   Index _inc_d,
                   Index _inc_h,
                   Index _inc_w,
                   Index _inc_advance) {
      pointer = ptr;
      Base::Params::initialize(
          _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
      return 0;
    }

    // Initializes params to default values; pointer is left unchanged
    CUTLASS_HOST_DEVICE
    int initialize() { return Base::Params::initialize(); }
  };

  //
  // Data members
  //

  /// Parameters structure (holds the current pointer, mutated by inc_*)
  Params params;
  /// Offset of an individual lane from the start of the tile
  Coord<4> thread_offset;
  /// Stage argument enables wrapping after some number of tiles have been loaded.
  int stage;

  //
  // Predicate initialization
  //

  /// Initializes a predicate vector using a RegularTilePredicateFunctor
  /// bound to the given bounds; the functor is evaluated at the block
  /// offset shifted by this thread's D/H/W offset.
  template <
      /// Predicate iterator
      typename PredicateIterator>
  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
                                                 Coord<3> const &bounds,
                                                 Coord<3> const &block_offset = make_Coord(0,
                                                                                           0,
                                                                                           0)) {
    Base::initialize_predicates(
        predicate_it,
        RegularTilePredicateFunctor<typename Traits::Delta>(bounds),
        block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
  }

  /// Initializes a predicate vector using an arbitrary predicate functor,
  /// evaluated at the block offset shifted by this thread's D/H/W offset.
  template <
      /// Predicate iterator
      typename PredicateIterator,
      /// Functor computing predicates
      typename PredicateFunctor>
  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
                                                 PredicateFunctor const &functor,
                                                 Coord<3> const &block_offset) {
    Base::initialize_predicates(
        predicate_it,
        functor,
        block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
  }

  //
  // Methods
  //

  /// Default constructor (members left default-initialized)
  CUTLASS_HOST_DEVICE
  TileLoadIterator() {}

  /// Constructs a tile load iterator: computes this thread's offset and
  /// advances the pointer by (block_offset + thread_offset) scaled by the
  /// D/H/W strides.
  CUTLASS_HOST_DEVICE
  TileLoadIterator(Params const &_params,
                   Coord<3> const &block_offset = make_Coord(0, 0, 0),
                   ThreadOffset thread_offset_func = ThreadOffset())
      : params(_params), stage(0) {
    thread_offset = thread_offset_func();
    Index pointer_offset = Index((block_offset[0] + thread_offset[0]) * params.stride_d) +
                           Index((block_offset[1] + thread_offset[1]) * params.stride_h) +
                           Index((block_offset[2] + thread_offset[2]) * params.stride_w);
    params.pointer += pointer_offset;
  }

  /// Constructs a tile load iterator over a packed (W-contiguous) region.
  // NOTE(review): the Params argument is intentionally unused and replaced
  // by packed strides; block_offset is accepted but ignored — only the
  // thread's W offset displaces the pointer. thread_offset itself is not
  // stored, so the predicate initializers above read an uninitialized
  // Coord<4> if used with this constructor — confirm intended.
  CUTLASS_HOST_DEVICE
  TileLoadIterator(Params const &,
                   Scalar const *ptr,
                   Coord<3> const &block_offset = make_Coord(0, 0, 0),
                   ThreadOffset thread_offset_func = ThreadOffset())
      : stage(0) {
    params.pointer = ptr + thread_offset_func()[2];
    params.stride_d = 0;
    params.stride_h = 0;
    params.stride_w = 1;
    params.inc_d = params.inc_h = params.inc_w = params.inc_advance = 0;
  }

  /// Increment in the D dimension
  CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
  /// Increment in the H dimension
  CUTLASS_HOST_DEVICE void inc_h() { params.pointer += params.inc_h; }
  /// Increment in the W dimension
  CUTLASS_HOST_DEVICE void inc_w() { params.pointer += params.inc_w; }
  /// Increment to the next tile along the kAdvance dimension
  CUTLASS_HOST_DEVICE void inc_advance() { params.pointer += params.inc_advance; }

  /// Loads a single fragment element from memory at logical position
  /// (d, h, w, c); the immediate offset is computed from the traits'
  /// ImmediateOffsetStrides (zero by default).
  CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const {
    int const offset =
        ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
    Load<Scalar,
         kAccessSize,
         kMemorySpace,
         kFragmentElementType,
         FragmentElement,
         Tile::kW,
         sizeof(FragmentElement) * kAccessSize>::load(value, params.pointer, offset);
  }

  /// Increment the stage: steps the pointer forward by one stage and wraps
  /// back to stage 0 after Tile::kD stages (no-op for single-stage tiles).
  CUTLASS_HOST_DEVICE void inc_stage() {
    if (Tile::kD > 1) {
      int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
      if (stage == Tile::kD - 1) {
        params.pointer -= (Tile::kD - 1) * kStageSize;
        stage = 0;
      } else {
        params.pointer += kStageSize;
        stage = stage + 1;
      }
    }
  }

  /// Adds a logical (D, H, W) offset to the iterator, scaled by the strides
  CUTLASS_HOST_DEVICE TileLoadIterator & operator+=(Coord<3> const &offset) {
    long long _offset = offset.template dot<long long>(
      make_Coord(params.stride_d, params.stride_h, params.stride_w)
    );
    params.pointer += _offset;
    return *this;
  }

  /// Adds a raw element offset to the pointer (no stride scaling)
  CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex offset) { params.pointer += offset; }

  /// Returns the stride along the kAdvance dimension (H by default, W when
  /// advancing in W).
  CUTLASS_HOST_DEVICE Index stride_advance(void) {
    Index stride = params.stride_h;
    if (kAdvance == IteratorAdvance::kW) {
      stride = params.stride_w;
    }
    return stride;
  }

  /// Loads a fragment and advances the iterator to the next tile.
  /// pred_it is consumed once per W access; the final W/H/D increments are
  /// skipped because inc_advance already accounts for the rewind.
  template <typename Fragment, typename PredicateIterator>
  CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
    FragmentIterator frag_iterator(fragment);
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
          CUTLASS_PRAGMA_UNROLL
          for (int c = 0; c < Iterations::kC; ++c) {
            if (*pred_it) {
              load_element(
                  reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
            }
          }
          if (w < Iterations::kW - 1) {
            inc_w();
          }
        }
        if (h < Iterations::kH - 1) {
          inc_h();
        }
      }
      if (d < Iterations::kD - 1) {
        inc_d();
      }
    }
    inc_advance();
  }

  /// Loads a fragment and advances the iterator to the next tile (all
  /// accesses enabled via a trivial predicate iterator).
  template <typename Fragment>
  CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment) {
    typename PredicateVector::TrivialIterator pred_it;
    load_post_increment(fragment, pred_it);
  }

  /// Loads a fragment without advancing the iterator (operates on a copy).
  template <typename Fragment, typename PredicateIterator>
  CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const {
    TileLoadIterator _load_it(*this);
    _load_it.load_post_increment(fragment, pred_it);
  }

  /// Loads a fragment without advancing the iterator, all accesses enabled.
  template <typename Fragment>
  CUTLASS_HOST_DEVICE void load(Fragment &fragment) const {
    typename PredicateVector::TrivialIterator pred_it;
    load(fragment, pred_it);
  }

  /// Loads the d-th slice of a tile into the fragment's 0-th slice without
  /// moving the pointer (no inc_* calls; relies on immediate offsets).
  template <typename Fragment>
  CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) {
    FragmentIterator frag_iterator(fragment);
    CUTLASS_PRAGMA_UNROLL
    for (int h = 0; h < Iterations::kH; ++h) {
      CUTLASS_PRAGMA_UNROLL
      for (int w = 0; w < Iterations::kW; ++w) {
        CUTLASS_PRAGMA_UNROLL
        for (int c = 0; c < Iterations::kC; ++c) {
          load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
        }
      }
    }
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/*!@defgroup tile_store_iterator_concept Tile Store Iterator Concept
@{
\ref tile_store_iterator_concept enables storing a tile to addressable memory
@par Tile Store Iterator Concept
Types satisfying \ref tile_store_iterator_concept define the following members
- <b>PredicateVector</b> - a \ref predicate_vector_concept with sufficient predicate storage for
each access implied by the tile traits
- <b>Fragment</b> - the destination fragment type satisfying \ref fragment_concept
- <b>initialize_predicates(pred_it, bounds, block_offset)</b> - function initializing a predicate
vector according to externally specified bounds
- <b>store_post_increment(fragment, pred_it)</b> - a method that stores a fragment and increments
the iterator to the next tile, guarded by a \ref predicate_iterator_concept
- <b>store_post_increment(fragment)</b> - a method that stores a fragment and increments the
iterator to the next tile
- <b>store(fragment, pred_it)</b> - a const method that stores a fragment, guarded by a \ref
predicate_iterator_concept
- <b>store(fragment)</b> - a method that stores a fragment
@}
*/
///////////////////////////////////////////////////////////////////////////////////////////////////
/**
* @brief An iterator implementing \ref tile_store_iterator_concept for storing a tile to memory
* @concept{tile_store_iterator_concept}
*/
template <typename Traits_,
typename Scalar_,
IteratorAdvance::Kind Advance_ = IteratorAdvance::kH,
MemorySpace::Kind MemorySpace = MemorySpace::kGeneric,
typename Index_ = int,
typename FragmentElement_ = Scalar_,
FragmentElementType::Kind FragmentElementType_ = FragmentElementType::kScalar,
typename Skew_ = Shape<0, 0, 0, 0> >
struct TileStoreIterator : public TileIteratorBase<Traits_,
Scalar_,
Advance_,
MemorySpace,
Index_,
FragmentElement_,
FragmentElementType_,
Skew_> {
/// Base class
typedef TileIteratorBase<Traits_,
Scalar_,
Advance_,
MemorySpace,
Index_,
FragmentElement_,
FragmentElementType_,
Skew_>
Base;
/// concept TileTraits
typedef typename Base::Traits Traits;
/// Scalar element
typedef typename Base::Scalar Scalar;
/// Fragment element
typedef typename Base::FragmentElement FragmentElement;
/// Specifies in which dimension post-increment accesses advance.
static IteratorAdvance::Kind const kAdvance = Base::kAdvance;
/// Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
static FragmentElementType::Kind const kFragmentElementType = Base::kFragmentElementType;
/// Source or destination memory space
static MemorySpace::Kind const kMemorySpace = Base::kMemorySpace;
/// The number of scalars accessed per load/store.
static int const kAccessSize = Base::kAccessSize;
/// Index type
typedef typename Base::Index Index;
/// Long index type
typedef typename Base::LongIndex LongIndex;
/// Skew quantity
typedef typename Base::Skew Skew;
/// Tile shape
typedef typename Base::Tile Tile;
/// Delta
typedef typename Base::Delta Delta;
/// Iterations
typedef typename Base::Iterations Iterations;
/// ThreadOffset functor
typedef typename Base::ThreadOffset ThreadOffset;
/// Fragment type
typedef typename Base::FragmentShape FragmentShape;
/// Memory access type
typedef typename Base::AccessType AccessType;
/// Fragment definition
typedef typename Base::Fragment Fragment;
/// Fragment iterator definition
typedef typename Base::FragmentIterator FragmentIterator;
/// Fragment const iterator definition
typedef typename Base::FragmentConstIterator FragmentConstIterator;
/// Default predicate mask type
typedef typename Base::PredicateVector PredicateVector;
/// Storage object which may be stored to
typedef typename Base::Storage SharedStorage;
/// IteratorBase parameters
typedef typename Base::Params BaseParams;
/// Pointer to underlying type
typedef Scalar *Pointer;
/// Tensor reference for the store iterator
typedef TensorRef<Scalar, 4> TensorRef;
/// Parameters
struct Params : public BaseParams {
  /// Pointer to memory the iterator stores to
  Scalar *pointer;

  //
  // Methods
  //

  /// Default constructor: null pointer, and base strides/increments given
  /// their defined defaults. (The original left the inherited fields
  /// untouched — i.e. partially indeterminate — unlike the matching
  /// TileLoadIterator::Params default constructor, which initializes them.)
  CUTLASS_HOST_DEVICE
  Params() : pointer(0) { Base::Params::initialize(); }

  /// Constructs from a raw pointer with default (packed) strides
  CUTLASS_HOST_DEVICE
  Params(Scalar *ptr) : pointer(ptr) { Base::Params::initialize(); }

  /// Constructs with a CompactTensorRef<> — pointer and strides taken from ref
  CUTLASS_HOST_DEVICE
  Params(TensorRef const &ref): pointer(ref.data()) {
    Base::Params::initialize(ref.stride());
  }

  /// Constructs from a pointer and D/H/W strides; increments are derived.
  // stride_d is declared Index (not long long) for consistency with
  // initialize() below and with TileLoadIterator::Params — the previous
  // long long parameter was immediately narrowed to Index anyway.
  CUTLASS_HOST_DEVICE
  Params(Scalar *ptr, Index stride_d, Index stride_h, Index stride_w) {
    initialize(ptr, stride_d, stride_h, stride_w);
  }

  /// Constructs from a pointer and fully explicit strides/increments
  CUTLASS_HOST_DEVICE
  Params(Scalar *ptr,
         Index _stride_d,
         Index _stride_h,
         Index _stride_w,
         Index _inc_d,
         Index _inc_h,
         Index _inc_w,
         Index _inc_advance) {
    initialize(ptr, _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
  }

  /// Initialize params to access a shared-storage object with default strides
  CUTLASS_HOST_DEVICE
  int initialize(SharedStorage &storage) {
    pointer = &storage[0];
    return Base::Params::initialize();
  }

  /// Initialize params to access a raw pointer with default strides
  CUTLASS_HOST_DEVICE
  int initialize(Scalar *ptr) {
    pointer = ptr;
    return Base::Params::initialize();
  }

  /// Initializes params from a raw pointer and D/H/W strides
  CUTLASS_HOST_DEVICE
  int initialize(Scalar *ptr, Index stride_d, Index stride_h, Index stride_w) {
    Base::Params::initialize(stride_d, stride_h, stride_w);
    pointer = ptr;
    return 0;
  }

  /// Initializes params from a raw pointer and explicit strides/increments
  CUTLASS_HOST_DEVICE
  int initialize(Scalar *ptr,
                 Index _stride_d,
                 Index _stride_h,
                 Index _stride_w,
                 Index _inc_d,
                 Index _inc_h,
                 Index _inc_w,
                 Index _inc_advance) {
    pointer = ptr;
    Base::Params::initialize(
        _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
    return 0;
  }

  /// Initializes params to default values; pointer is left unchanged
  CUTLASS_HOST_DEVICE
  int initialize() { return Base::Params::initialize(); }
};
//
// Data members
//
/// Parameters structure
Params params;
/// Offset of an individual lane from the start of the tile
Coord<4> thread_offset;
/// The stage.
int stage;
//
// Predicate initialization
//
/// Initializes a predicate vector using a RegularTilePredicateFunctor
template <
/// Predicate iterator
typename PredicateIterator>
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
Coord<3> const &bounds,
Coord<3> const &block_offset = make_Coord(0,
0,
0)) {
Base::initialize_predicates(
predicate_it,
RegularTilePredicateFunctor<typename Traits::Delta>(bounds),
block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
}
/// Initializes a predicate vector using an arbitrary predicate functor
template <
/// Predicate iterator
typename PredicateIterator,
/// Functor computing predicates
typename PredicateFunctor>
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
PredicateFunctor const &functor,
Coord<3> const &block_offset) {
Base::initialize_predicates(
predicate_it,
functor,
block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
}
//
// Methods
//
/// Default constructor (members left default-initialized)
CUTLASS_HOST_DEVICE
TileStoreIterator() {}
/// Constructs a tile store iterator: computes this thread's offset and
/// advances the pointer by (block_offset + thread_offset) scaled by the
/// D/H/W strides.
CUTLASS_HOST_DEVICE
TileStoreIterator(Params const &_params,
                  Coord<3> const &block_offset = make_Coord(0, 0, 0),
                  ThreadOffset thread_offset_func = ThreadOffset())
    : params(_params), stage(0) {
  thread_offset = thread_offset_func();
  params.pointer += (block_offset[0] + thread_offset[0]) * params.stride_d +
                    (block_offset[1] + thread_offset[1]) * params.stride_h +
                    (block_offset[2] + thread_offset[2]) * params.stride_w;
}
/// Constructs a tile store iterator over a packed (W-contiguous) region.
// NOTE(review): the Params argument is intentionally unused; only the
// thread's W offset displaces the pointer, and thread_offset is not
// stored (so the predicate initializers above would read an uninitialized
// Coord<4> if used with this constructor) — confirm intended.
CUTLASS_HOST_DEVICE
TileStoreIterator(Params const &, Scalar *ptr, ThreadOffset thread_offset_func = ThreadOffset())
    : stage(0) {
  params.pointer = ptr + thread_offset_func()[2];
  params.stride_d = 0;
  params.stride_h = 0;
  params.stride_w = 1;
  params.inc_d = params.inc_h = params.inc_w = params.inc_advance = 0;
}
/// Increment in the D dimension
CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
/// Increment in the H dimension
CUTLASS_HOST_DEVICE void inc_h() { params.pointer += params.inc_h; }
/// Increment in the W dimension
CUTLASS_HOST_DEVICE void inc_w() { params.pointer += params.inc_w; }
/// Increment to the next tile.
// NOTE(review): a no-op here, unlike TileLoadIterator::inc_advance which
// applies params.inc_advance — presumably inter-tile advance for stores is
// handled by inc_stage; confirm intended.
CUTLASS_HOST_DEVICE void inc_advance() {}
/// Increment the stage: steps the pointer forward by one stage and wraps
/// back to stage 0 after Tile::kD stages (no-op for single-stage tiles).
CUTLASS_HOST_DEVICE void inc_stage() {
  if (Tile::kD > 1) {
    int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
    if (stage == Tile::kD - 1) {
      params.pointer -= (Tile::kD - 1) * kStageSize;
      stage = 0;
    } else {
      params.pointer += kStageSize;
      stage = stage + 1;
    }
  }
}
/// Advances the iterator by a logical (D, H, W) offset, scaled by the
/// corresponding strides.
CUTLASS_HOST_DEVICE TileStoreIterator & operator+=(Coord<3> const &offset) {
  Coord<3> strides = make_Coord(params.stride_d, params.stride_h, params.stride_w);
  long long displacement = offset.template dot<long long>(strides);
  params.pointer += displacement;
  return *this;
}
/// Displaces the pointer by a raw element count (no stride scaling)
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex offset) { params.pointer += offset; }
/// Stores a single fragment element into memory at logical position
/// (d, h, w, c); the immediate offset is computed from the traits'
/// ImmediateOffsetStrides (zero by default).
CUTLASS_HOST_DEVICE void store_element(AccessType const &value, int d, int h, int w, int c) {
  int const offset =
      ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
  Store<Scalar,
        kAccessSize,
        kMemorySpace,
        kFragmentElementType,
        FragmentElement,
        Tile::kW,
        sizeof(FragmentElement) * kAccessSize>::store(value, params.pointer, offset);
}
/// Stores a fragment, masked by a predicate iterator, and advances the iterator
/// to the next tile. The predicate iterator is consumed once per W iteration.
template <typename Fragment, typename PredicateIterator>
CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it) {
  FragmentConstIterator src(fragment);
  CUTLASS_PRAGMA_UNROLL
  for (int d = 0; d < Iterations::kD; ++d) {
    CUTLASS_PRAGMA_UNROLL
    for (int h = 0; h < Iterations::kH; ++h) {
      CUTLASS_PRAGMA_UNROLL
      for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
        CUTLASS_PRAGMA_UNROLL
        for (int c = 0; c < Iterations::kC; ++c) {
          if (*pred_it) {
            store_element(
                reinterpret_cast<AccessType const &>(src.at(d, h, w, c)), d, h, w, c);
          }
        }
        // No pointer bump after the final W iteration: the carry into the next
        // dimension is performed by inc_h() / inc_d() below instead.
        if (w + 1 < Iterations::kW) {
          inc_w();
        }
      }
      if (h + 1 < Iterations::kH) {
        inc_h();
      }
    }
    if (d + 1 < Iterations::kD) {
      inc_d();
    }
  }
  inc_advance();
}
/// Stores a fragment unconditionally (trivial, always-true predicates) and
/// advances to the next tile.
template <typename Fragment>
CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment) {
  store_post_increment(fragment, typename PredicateVector::TrivialIterator());
}
/// Stores a fragment, masked by predicates, without advancing this iterator;
/// the traversal runs on a local copy.
template <typename Fragment, typename PredicateIterator>
CUTLASS_HOST_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const {
  TileStoreIterator scratch(*this);
  scratch.store_post_increment(fragment, pred_it);
}
/// Stores a fragment unconditionally without advancing this iterator.
template <typename Fragment>
CUTLASS_HOST_DEVICE void store(Fragment const &fragment) const {
  store(fragment, typename PredicateVector::TrivialIterator());
}
/// Loads a single access-wide fragment element at logical coordinates (d, h, w, c).
/// Mirrors store_element(): the offset is computed from the tile's immediate-offset
/// strides and the read is dispatched through the Load<> functor.
CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const {
  int const offset =
      ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
  Load<Scalar,
       kAccessSize,
       kMemorySpace,
       kFragmentElementType,
       FragmentElement,
       Tile::kW,
       sizeof(FragmentElement) * kAccessSize>::load(value, params.pointer, offset);
}
/// Loads a fragment, masked by a predicate iterator, and advances the iterator
/// to the next tile. The predicate iterator is consumed once per W iteration.
template <typename Fragment, typename PredicateIterator>
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
  FragmentIterator dst(fragment);
  CUTLASS_PRAGMA_UNROLL
  for (int d = 0; d < Iterations::kD; ++d) {
    CUTLASS_PRAGMA_UNROLL
    for (int h = 0; h < Iterations::kH; ++h) {
      CUTLASS_PRAGMA_UNROLL
      for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
        CUTLASS_PRAGMA_UNROLL
        for (int c = 0; c < Iterations::kC; ++c) {
          if (*pred_it) {
            load_element(
                reinterpret_cast<AccessType &>(dst.at(d, h, w, c)), d, h, w, c);
          }
        }
        // No pointer bump after the final W iteration: the carry into the next
        // dimension is performed by inc_h() / inc_d() below instead.
        if (w + 1 < Iterations::kW) {
          inc_w();
        }
      }
      if (h + 1 < Iterations::kH) {
        inc_h();
      }
    }
    if (d + 1 < Iterations::kD) {
      inc_d();
    }
  }
  inc_advance();
}
/// Loads a fragment unconditionally (trivial, always-true predicates) and
/// advances to the next tile.
template <typename Fragment>
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment) {
  load_post_increment(fragment, typename PredicateVector::TrivialIterator());
}
/// Loads a fragment, masked by predicates, without advancing this iterator;
/// the traversal runs on a local copy.
template <typename Fragment, typename PredicateIterator>
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const {
  TileStoreIterator scratch(*this);
  scratch.load_post_increment(fragment, pred_it);
}
/// Loads a fragment unconditionally without advancing this iterator.
template <typename Fragment>
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const {
  load(fragment, typename PredicateVector::TrivialIterator());
}
/// Loads the single D-slice `d` of the tile into plane 0 of the fragment.
/// The iterator's pointer is not advanced (elements are addressed via
/// load_element's immediate offsets only).
template <typename Fragment>
CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) {
  FragmentIterator dst(fragment);
  CUTLASS_PRAGMA_UNROLL
  for (int h = 0; h < Iterations::kH; ++h) {
    CUTLASS_PRAGMA_UNROLL
    for (int w = 0; w < Iterations::kW; ++w) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < Iterations::kC; ++c) {
        // Fragment is indexed at d=0; memory is indexed at the requested d.
        load_element(reinterpret_cast<AccessType &>(dst.at(0, h, w, c)), d, h, w, c);
      }
    }
  }
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass