40 template <
typename InputIterator,
typename Fragment>
42 typename InputIterator::FragmentIterator frag_iterator(fragment);
43 for (
int d = 0; d < InputIterator::Iterations::kD; ++d) {
44 for (
int h = 0; h < InputIterator::Iterations::kH; ++h) {
45 for (
int w = 0; w < InputIterator::Iterations::kW; ++w) {
46 for (
int c = 0; c < InputIterator::Iterations::kC; ++c) {
47 if (iterator.valid(d, h, w, c)) {
52 load(reinterpret_cast<typename InputIterator::AccessType &>(
53 frag_iterator.at(d, h, w, c)),
58 if (w < InputIterator::Iterations::kW - 1) {
62 if (h < InputIterator::Iterations::kH - 1) {
66 if (d < InputIterator::Iterations::kD - 1) {
70 iterator.inc_advance();
74 template <
typename InputIterator,
typename Fragment>
76 typename InputIterator::FragmentIterator frag_iterator(fragment);
77 for (
int d = 0; d < InputIterator::Iterations::kD; ++d) {
78 for (
int h = 0; h < InputIterator::Iterations::kH; ++h) {
79 for (
int w = 0; w < InputIterator::Iterations::kW; ++w) {
80 for (
int c = 0; c < InputIterator::Iterations::kC; ++c) {
86 InputIterator::Tile::kC,
87 typename InputIterator::Scalar,
88 InputIterator::kMemorySpace,
89 typename InputIterator::FragmentElement,
90 InputIterator::Tile::kW>::load(frag_iterator.at(d, h, w, c),
100 template <
typename InputIterator,
typename Fragment>
102 typename InputIterator::FragmentIterator frag_iterator(fragment);
103 for (
int h = 0; h < InputIterator::Iterations::kH; ++h) {
104 for (
int w = 0; w < InputIterator::Iterations::kW; ++w) {
105 for (
int c = 0; c < InputIterator::Iterations::kC; ++c) {
111 InputIterator::Tile::kC,
112 typename InputIterator::Scalar,
113 InputIterator::kMemorySpace,
114 typename InputIterator::FragmentElement,
115 InputIterator::Tile::kW>::load(frag_iterator.at(0, h, w, c),
124 template <
typename InputIterator,
typename Fragment,
typename ConstPredicateAdapter>
127 typename InputIterator::Index offset,
128 ConstPredicateAdapter predicate_adapter) {
129 for (
int d = 0; d < InputIterator::Iterations::kD; ++d, iterator.inc_d()) {
130 for (
int h = 0; h < InputIterator::Iterations::kH; ++h, iterator.inc_h()) {
131 for (
int w = 0; w < InputIterator::Iterations::kW; ++w, iterator.inc_w()) {
132 if (predicate_adapter.at(d, h, w, 0)) {
133 int idx = InputIterator::Tile::kC *
134 (w + InputIterator::Iterations::kW * (h + InputIterator::Iterations::kH * d));
137 load(reinterpret_cast<typename InputIterator::AccessType &>(fragment[idx]),
147 template <
typename InputIterator,
typename Fragment>
150 typename InputIterator::Index offset = 0) {
156 template <
typename InputIterator,
typename Fragment,
typename ConstPredicateAdapter>
159 ConstPredicateAdapter pred_it) {
163 template <
typename InputIterator,
typename Fragment,
typename ConstPredicateAdapter>
166 typename InputIterator::Index offset,
167 ConstPredicateAdapter predicate_adapter) {
168 InputIterator iterator(_iterator);
173 template <
typename InputIterator,
typename Fragment>
176 typename InputIterator::Index offset = 0) {
182 template <
typename InputIterator,
typename Fragment,
typename ConstPredicateAdapter>
185 ConstPredicateAdapter pred_it) {
192 template <
typename OutputIterator,
typename Fragment>
194 typename OutputIterator::FragmentIterator frag_iterator(fragment);
195 for (
int d = 0; d < OutputIterator::Iterations::kD; ++d) {
196 for (
int h = 0; h < OutputIterator::Iterations::kH; ++h) {
197 for (
int w = 0; w < OutputIterator::Iterations::kW; ++w) {
198 if (iterator.valid(d, h, w, 0)) {
204 OutputIterator::Tile::kC,
205 OutputIterator::kMemorySpace>::
206 store(reinterpret_cast<typename OutputIterator::AccessType &>(
207 frag_iterator.at(d, h, w, 0)),
211 if (w < OutputIterator::Iterations::kW - 1) {
215 if (h < OutputIterator::Iterations::kH - 1) {
219 if (d < OutputIterator::Iterations::kD - 1) {
223 iterator.inc_advance();
227 template <
typename OutputIterator,
typename Fragment>
229 typename OutputIterator::FragmentConstIterator frag_iterator(fragment);
230 for (
int d = 0; d < OutputIterator::Iterations::kD; ++d) {
231 for (
int h = 0; h < OutputIterator::Iterations::kH; ++h) {
232 for (
int w = 0; w < OutputIterator::Iterations::kW; ++w) {
233 for (
int c = 0; c < OutputIterator::Iterations::kC; ++c) {
239 OutputIterator::Tile::kC,
240 typename OutputIterator::Scalar,
241 OutputIterator::kMemorySpace,
242 typename OutputIterator::FragmentElement,
243 OutputIterator::Tile::kW>::store(frag_iterator.at(d, h, w, c),
255 template <
typename OutputIterator,
typename Fragment,
typename ConstPredicateAdapter>
258 typename OutputIterator::Index offset,
259 ConstPredicateAdapter predicate_adapter) {
260 for (
int d = 0; d < OutputIterator::Iterations::kD; ++d, iterator.inc_d()) {
261 for (
int h = 0; h < OutputIterator::Iterations::kH; ++h, iterator.inc_h()) {
262 for (
int w = 0; w < OutputIterator::Iterations::kW; ++w, iterator.inc_w()) {
263 if (predicate_adapter.at(d, h, w, 0)) {
264 int idx = OutputIterator::Tile::kC *
265 (w + OutputIterator::Iterations::kW * (h + OutputIterator::Iterations::kH * d));
268 OutputIterator::Tile::kC,
269 OutputIterator::kMemorySpace>::
270 store(reinterpret_cast<typename OutputIterator::AccessType const &>(fragment[idx]),
280 template <
typename OutputIterator,
typename Fragment>
283 typename OutputIterator::Index offset = 0) {
289 template <
typename OutputIterator,
typename Fragment,
typename ConstPredicateAdapter>
292 ConstPredicateAdapter pred_it) {
297 template <
typename OutputIterator,
typename Fragment,
typename ConstPredicateAdapter>
300 typename OutputIterator::Index offset,
301 ConstPredicateAdapter predicate_adapter) {
302 OutputIterator iterator(_iterator);
307 template <
typename OutputIterator,
typename Fragment>
310 typename OutputIterator::Index offset = 0) {
316 template <
typename OutputIterator,
typename Fragment,
typename ConstPredicateAdapter>
319 ConstPredicateAdapter pred_it) {
Definition: fragment_load_store.h:43
CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment)
Loads a fragment from a shared memory input iterator.
Definition: iterator_access.h:75
CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator, Fragment const &fragment, typename OutputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
Stores a fragment to an output iterator, masked by a predicate iterator.
Definition: iterator_access.h:256
Defines accessors for loading and storing fragments to memory efficiently.
static CUTLASS_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
The load function.
Definition: load_store.h:59
A template defining Fragment Concept.
Definition: fragment.h:99
Definition: load_store.h:131
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
static CUTLASS_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:211
CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator, Fragment &fragment, typename InputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
Loads a fragment from an input iterator, masked by a predicate iterator.
Definition: iterator_access.h:125
Defines abstractions for efficiently loading and storing vectors to memory.
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_DEVICE void shared_iterator_store(OutputIterator &iterator, Fragment const &fragment)
Stores a fragment to a shared memory output iterator.
Definition: iterator_access.h:228
Element_ Element
The element.
Definition: fragment.h:108
Always returns true predicate.
Definition: predicate_vector.h:426
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment)
Stores a fragment to an output iterator.
Definition: iterator_access.h:193
Definition: fragment_load_store.h:91
CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment)
Loads a fragment from an input iterator.
Definition: iterator_access.h:41
Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.