Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
iterator_access.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
31 #include <cutlass/load_store.h>
33 #include <cutlass/shape.h>
34 
35 namespace cutlass {
36 
38 
/// Loads a fragment from an input iterator.
///
/// Visits every (d, h, w, c) in InputIterator::Iterations. Accesses for which
/// iterator.valid(d, h, w, c) is false are skipped; the others are loaded into
/// frag_iterator.at(d, h, w, c), passing iterator.data() and a per-access offset
/// to load(). The iterator is stepped with inc_w() / inc_h() / inc_d() between
/// consecutive iterations of the matching loop (never after the last one), and
/// inc_advance() is called once after all loops have finished.
40 template <typename InputIterator, typename Fragment>
41 CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment) {
42  typename InputIterator::FragmentIterator frag_iterator(fragment);
43  for (int d = 0; d < InputIterator::Iterations::kD; ++d) {
44  for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
45  for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
46  for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
47  if (iterator.valid(d, h, w, c)) {
48  int const offset =
  // NOTE(review): listing line 49 is absent from this extraction, so the callee
  // that turns these coordinates into `offset` is not visible here. Only w and c
  // are passed through; d and h are passed as 0.
50  0, 0, w, c);
  // NOTE(review): listing line 51 is absent; it presumably qualifies the `load`
  // call below - confirm against the real header.
52  load(reinterpret_cast<typename InputIterator::AccessType &>(
53  frag_iterator.at(d, h, w, c)),
54  iterator.data(),
55  offset);
56  }
57  }
  // Step along W only when another W iteration follows.
58  if (w < InputIterator::Iterations::kW - 1) {
59  iterator.inc_w();
60  }
61  }
  // Step along H only when another H iteration follows.
62  if (h < InputIterator::Iterations::kH - 1) {
63  iterator.inc_h();
64  }
65  }
  // Step along D only when another D iteration follows.
66  if (d < InputIterator::Iterations::kD - 1) {
67  iterator.inc_d();
68  }
69  }
70  iterator.inc_advance();
71 }
72 
/// Loads a fragment from a shared memory input iterator.
///
/// For every (d, h, w, c) in InputIterator::Iterations, computes an offset from
/// those four coordinates and calls FragmentLoad<...>::load to fill
/// frag_iterator.at(d, h, w, c) from iterator.data(). No validity check is made
/// and no inc_*() method of the iterator is called in this function.
74 template <typename InputIterator, typename Fragment>
75 CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment) {
76  typename InputIterator::FragmentIterator frag_iterator(fragment);
77  for (int d = 0; d < InputIterator::Iterations::kD; ++d) {
78  for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
79  for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
80  for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
81  int const offset =
  // NOTE(review): listing line 82 is absent from this extraction; the callee that
  // maps (d, h, w, c) to `offset` is not visible here.
83  d, h, w, c);
84 
85  FragmentLoad<InputIterator::kIteratorFragment,
86  InputIterator::Tile::kC,
87  typename InputIterator::Scalar,
88  InputIterator::kMemorySpace,
89  typename InputIterator::FragmentElement,
90  InputIterator::Tile::kW>::load(frag_iterator.at(d, h, w, c),
91  iterator.data(),
92  offset);
93  }
94  }
95  }
96  }
97 }
98 
/// Overload of shared_iterator_load restricted to one caller-supplied `d`.
///
/// Iterates over H, W and C only. The memory offset is computed from
/// (d, h, w, c), but the destination inside the fragment is
/// frag_iterator.at(0, h, w, c), i.e. the fragment is always indexed with a
/// D coordinate of 0 regardless of `d`.
///
/// \param iterator  Source iterator; only data() is used here.
/// \param fragment  Destination fragment.
/// \param d         D coordinate used when computing the memory offset.
100 template <typename InputIterator, typename Fragment>
101 CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment, int d) {
102  typename InputIterator::FragmentIterator frag_iterator(fragment);
103  for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
104  for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
105  for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
106  int const offset =
  // NOTE(review): listing line 107 is absent from this extraction; the callee that
  // maps (d, h, w, c) to `offset` is not visible here.
108  d, h, w, c);
109 
110  FragmentLoad<InputIterator::kIteratorFragment,
111  InputIterator::Tile::kC,
112  typename InputIterator::Scalar,
113  InputIterator::kMemorySpace,
114  typename InputIterator::FragmentElement,
115  InputIterator::Tile::kW>::load(frag_iterator.at(0, h, w, c),
116  iterator.data(),
117  offset);
118  }
119  }
120  }
121 }
122 
/// Loads a fragment from an input iterator, masked by a predicate iterator.
///
/// Iterates over D, H and W of InputIterator::Iterations. An access is loaded
/// only when predicate_adapter.at(d, h, w, 0) is true. The destination is
/// fragment[idx], where idx is the linearised (d, h, w) position scaled by
/// InputIterator::Tile::kC. inc_w() / inc_h() / inc_d() sit in the loop
/// increment expressions, so they run after every iteration of their loop,
/// including the last one.
///
/// NOTE(review): listing line 125 (function name and first parameter) is absent
/// from this extraction; the index at the end of this page gives the declaration
/// as iterator_load_post_increment(InputIterator &iterator, Fragment &fragment,
/// typename InputIterator::Index offset, ConstPredicateAdapter predicate_adapter).
124 template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
126  Fragment &fragment,
127  typename InputIterator::Index offset,
128  ConstPredicateAdapter predicate_adapter) {
129  for (int d = 0; d < InputIterator::Iterations::kD; ++d, iterator.inc_d()) {
130  for (int h = 0; h < InputIterator::Iterations::kH; ++h, iterator.inc_h()) {
131  for (int w = 0; w < InputIterator::Iterations::kW; ++w, iterator.inc_w()) {
132  if (predicate_adapter.at(d, h, w, 0)) {
133  int idx = InputIterator::Tile::kC *
134  (w + InputIterator::Iterations::kW * (h + InputIterator::Iterations::kH * d));
135 
  // NOTE(review): listing line 136 is absent; it presumably qualifies the `load`
  // call below - confirm against the real header.
137  load(reinterpret_cast<typename InputIterator::AccessType &>(fragment[idx]),
138  iterator.data(),
139  offset);
140  }
141  }
142  }
143  }
144 }
145 
/// Convenience form taking an optional `offset` (default 0) that forwards to the
/// four-argument iterator_load_post_increment, passing `pred` as the predicate.
///
/// NOTE(review): listing lines 148 (function name / first parameter) and 151
/// (declaration of `pred`) are absent from this extraction. This is presumably an
/// overload of iterator_load_post_increment, and `pred` is presumably the
/// always-true predicate adapter listed in the index at the end of this page -
/// confirm against the real header.
147 template <typename InputIterator, typename Fragment>
149  Fragment &fragment,
150  typename InputIterator::Index offset = 0) {
152  iterator_load_post_increment(iterator, fragment, offset, pred);
153 }
154 
/// Convenience form taking only a predicate adapter: forwards to the
/// four-argument iterator_load_post_increment with an offset of 0.
///
/// NOTE(review): listing line 157 (function name / first parameter) is absent
/// from this extraction; this is presumably an overload of
/// iterator_load_post_increment - confirm against the real header.
156 template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
158  Fragment &fragment,
159  ConstPredicateAdapter pred_it) {
160  iterator_load_post_increment(iterator, fragment, 0, pred_it);
161 }
162 
/// Predicated load that takes the iterator by const reference.
///
/// Copy-constructs a local InputIterator from `_iterator` and hands that copy to
/// iterator_load_post_increment, so the increments performed there are applied
/// to the copy rather than to the caller's object.
///
/// \param _iterator          Iterator to copy from.
/// \param fragment           Destination fragment.
/// \param offset             Offset forwarded unchanged.
/// \param predicate_adapter  Predicate forwarded unchanged.
163 template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
164 CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &_iterator,
165  Fragment &fragment,
166  typename InputIterator::Index offset,
167  ConstPredicateAdapter predicate_adapter) {
168  InputIterator iterator(_iterator);
169  iterator_load_post_increment(iterator, fragment, offset, predicate_adapter);
170 }
171 
/// Const-iterator load with an optional `offset` (default 0): forwards to the
/// four-argument iterator_load, passing `pred` as the predicate.
///
/// NOTE(review): listing line 177 (declaration of `pred`) is absent from this
/// extraction; `pred` is presumably the always-true predicate adapter listed in
/// the index at the end of this page - confirm against the real header.
173 template <typename InputIterator, typename Fragment>
174 CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &iterator,
175  Fragment &fragment,
176  typename InputIterator::Index offset = 0) {
178  iterator_load(iterator, fragment, offset, pred);
179 }
180 
/// Const-iterator load taking only a predicate adapter: forwards to the
/// four-argument iterator_load with an offset of 0.
///
/// \param iterator  Source iterator, taken by const reference.
/// \param fragment  Destination fragment.
/// \param pred_it   Predicate forwarded unchanged.
182 template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
183 CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &iterator,
184  Fragment &fragment,
185  ConstPredicateAdapter pred_it) {
186  iterator_load(iterator, fragment, 0, pred_it);
187 }
188 
190 
/// Stores a fragment to an output iterator.
///
/// Visits every (d, h, w) in OutputIterator::Iterations; there is no loop over C
/// and the C coordinate is always passed as 0. Accesses for which
/// iterator.valid(d, h, w, 0) is false are skipped; the others are written with
/// Store<...>::store from frag_iterator.at(d, h, w, 0), passing iterator.data()
/// and a per-access offset. The iterator is stepped with inc_w() / inc_h() /
/// inc_d() between consecutive iterations of the matching loop (never after the
/// last one), and inc_advance() is called once after all loops have finished.
192 template <typename OutputIterator, typename Fragment>
193 CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment) {
194  typename OutputIterator::FragmentIterator frag_iterator(fragment);
195  for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
196  for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
197  for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
198  if (iterator.valid(d, h, w, 0)) {
199  int const offset =
  // NOTE(review): listing line 200 is absent from this extraction; the callee that
  // maps (d, h, w, 0) to `offset` is not visible here.
201  d, h, w, 0);
202 
203  Store<typename Fragment::Element,
204  OutputIterator::Tile::kC,
205  OutputIterator::kMemorySpace>::
206  store(reinterpret_cast<typename OutputIterator::AccessType &>(
207  frag_iterator.at(d, h, w, 0)),
208  iterator.data(),
209  offset);
210  }
  // Step along W only when another W iteration follows.
211  if (w < OutputIterator::Iterations::kW - 1) {
212  iterator.inc_w();
213  }
214  }
  // Step along H only when another H iteration follows.
215  if (h < OutputIterator::Iterations::kH - 1) {
216  iterator.inc_h();
217  }
218  }
  // Step along D only when another D iteration follows.
219  if (d < OutputIterator::Iterations::kD - 1) {
220  iterator.inc_d();
221  }
222  }
223  iterator.inc_advance();
224 }
225 
/// Stores a fragment to a shared memory output iterator.
///
/// For every (d, h, w, c) in OutputIterator::Iterations, computes an offset from
/// those four coordinates and calls FragmentStore<...>::store to write
/// frag_iterator.at(d, h, w, c) through iterator.data(). The fragment is read
/// through a FragmentConstIterator. No validity check is made and no inc_*()
/// method of the iterator is called in this function.
227 template <typename OutputIterator, typename Fragment>
228 CUTLASS_DEVICE void shared_iterator_store(OutputIterator &iterator, Fragment const &fragment) {
229  typename OutputIterator::FragmentConstIterator frag_iterator(fragment);
230  for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
231  for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
232  for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
233  for (int c = 0; c < OutputIterator::Iterations::kC; ++c) {
234  int const offset =
  // NOTE(review): listing line 235 is absent from this extraction; the callee that
  // maps (d, h, w, c) to `offset` is not visible here.
236  d, h, w, c);
237 
238  FragmentStore<OutputIterator::kIteratorFragment,
239  OutputIterator::Tile::kC,
240  typename OutputIterator::Scalar,
241  OutputIterator::kMemorySpace,
242  typename OutputIterator::FragmentElement,
243  OutputIterator::Tile::kW>::store(frag_iterator.at(d, h, w, c),
244  iterator.data(),
245  offset);
246  }
247  }
248  }
249  }
250 }
251 
253 
/// Stores a fragment to an output iterator, masked by a predicate iterator.
///
/// Iterates over D, H and W of OutputIterator::Iterations. An access is stored
/// only when predicate_adapter.at(d, h, w, 0) is true. The source is
/// fragment[idx], where idx is the linearised (d, h, w) position scaled by
/// OutputIterator::Tile::kC. inc_w() / inc_h() / inc_d() sit in the loop
/// increment expressions, so they run after every iteration of their loop,
/// including the last one.
///
/// NOTE(review): listing line 256 (function name and first parameter) is absent
/// from this extraction; the index at the end of this page gives the declaration
/// as iterator_store_post_increment(OutputIterator &iterator,
/// Fragment const &fragment, typename OutputIterator::Index offset,
/// ConstPredicateAdapter predicate_adapter).
255 template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
257  Fragment const &fragment,
258  typename OutputIterator::Index offset,
259  ConstPredicateAdapter predicate_adapter) {
260  for (int d = 0; d < OutputIterator::Iterations::kD; ++d, iterator.inc_d()) {
261  for (int h = 0; h < OutputIterator::Iterations::kH; ++h, iterator.inc_h()) {
262  for (int w = 0; w < OutputIterator::Iterations::kW; ++w, iterator.inc_w()) {
263  if (predicate_adapter.at(d, h, w, 0)) {
264  int idx = OutputIterator::Tile::kC *
265  (w + OutputIterator::Iterations::kW * (h + OutputIterator::Iterations::kH * d));
266 
267  Store<typename Fragment::Element,
268  OutputIterator::Tile::kC,
269  OutputIterator::kMemorySpace>::
270  store(reinterpret_cast<typename OutputIterator::AccessType const &>(fragment[idx]),
271  iterator.data(),
272  offset);
273  }
274  }
275  }
276  }
277 }
278 
/// Convenience form taking an optional `offset` (default 0) that forwards to the
/// four-argument iterator_store_post_increment, passing `pred` as the predicate.
///
/// NOTE(review): listing lines 281 (function name / first parameter) and 284
/// (declaration of `pred`) are absent from this extraction. This is presumably an
/// overload of iterator_store_post_increment, and `pred` is presumably the
/// always-true predicate adapter listed in the index at the end of this page -
/// confirm against the real header.
280 template <typename OutputIterator, typename Fragment>
282  Fragment const &fragment,
283  typename OutputIterator::Index offset = 0) {
285  iterator_store_post_increment(iterator, fragment, offset, pred);
286 }
287 
/// Convenience form taking only a predicate adapter: forwards to the
/// four-argument iterator_store_post_increment with an offset of 0.
///
/// NOTE(review): listing line 290 (function name / first parameter) is absent
/// from this extraction; this is presumably an overload of
/// iterator_store_post_increment - confirm against the real header.
289 template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
291  Fragment const &fragment,
292  ConstPredicateAdapter pred_it) {
293  iterator_store_post_increment(iterator, fragment, 0, pred_it);
294 }
295 
/// Predicated store that takes the iterator by const reference.
///
/// Copy-constructs a local OutputIterator from `_iterator` and hands that copy to
/// iterator_store_post_increment, so the increments performed there are applied
/// to the copy rather than to the caller's object.
///
/// \param _iterator          Iterator to copy from.
/// \param fragment           Source fragment.
/// \param offset             Offset forwarded unchanged.
/// \param predicate_adapter  Predicate forwarded unchanged.
297 template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
298 CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &_iterator,
299  Fragment const &fragment,
300  typename OutputIterator::Index offset,
301  ConstPredicateAdapter predicate_adapter) {
302  OutputIterator iterator(_iterator);
303  iterator_store_post_increment(iterator, fragment, offset, predicate_adapter);
304 }
305 
/// Const-iterator store with an optional `offset` (default 0): forwards to the
/// four-argument iterator_store, passing `pred` as the predicate.
///
/// NOTE(review): listing line 311 (declaration of `pred`) is absent from this
/// extraction; `pred` is presumably the always-true predicate adapter listed in
/// the index at the end of this page - confirm against the real header.
307 template <typename OutputIterator, typename Fragment>
308 CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &iterator,
309  Fragment const &fragment,
310  typename OutputIterator::Index offset = 0) {
312  iterator_store(iterator, fragment, offset, pred);
313 }
314 
/// Const-iterator store taking only a predicate adapter: forwards to the
/// four-argument iterator_store with an offset of 0.
///
/// \param iterator  Destination iterator, taken by const reference.
/// \param fragment  Source fragment.
/// \param pred_it   Predicate forwarded unchanged.
316 template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
317 CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &iterator,
318  Fragment const &fragment,
319  ConstPredicateAdapter pred_it) {
320  iterator_store(iterator, fragment, 0, pred_it);
321 }
322 
324 
325 } // namespace cutlass
Definition: fragment_load_store.h:43
Definition: convert.h:33
CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment)
Loads a fragment from a shared memory input iterator.
Definition: iterator_access.h:75
CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator, Fragment const &fragment, typename OutputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
Stores a fragment to an output iterator, masked by a predicate iterator.
Definition: iterator_access.h:256
Defines accessors for loading and storing fragments to memory efficiently.
static CUTLASS_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
The load function.
Definition: load_store.h:59
A template defining Fragment Concept.
Definition: fragment.h:99
Definition: load_store.h:131
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
static CUTLASS_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:211
CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator, Fragment &fragment, typename InputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
Loads a fragment from an input iterator, masked by a predicate iterator.
Definition: iterator_access.h:125
Defines abstractions for efficiently loading and storing vectors to memory.
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_DEVICE void shared_iterator_store(OutputIterator &iterator, Fragment const &fragment)
Stores a fragment to a shared memory output iterator.
Definition: iterator_access.h:228
Element_ Element
The element.
Definition: fragment.h:108
Always returns true predicate.
Definition: predicate_vector.h:426
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment)
Stores a fragment to an output iterator.
Definition: iterator_access.h:193
Definition: fragment_load_store.h:91
CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment)
Loads a fragment from an input iterator.
Definition: iterator_access.h:41
Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.