Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
tile_iterator.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/coord.h"
32 #include "cutlass/tensor_ref.h"
33 #include "cutlass/fragment.h"
34 #include "cutlass/load_store.h"
36 #include "cutlass/vector.h"
37 #include <cstdio>
38 
39 namespace cutlass {
40 
42 
61 
65  enum Kind { kD, kH, kW };
66 };
68 
73 template <typename Tile_,
74  typename Delta_,
75  typename Iterations_,
76  typename ThreadOffset_,
77  int AccessSize>
78 struct TileTraits {
80  typedef Tile_ Tile;
81 
83  typedef Delta_ Delta;
84 
86  typedef Iterations_ Iterations;
87 
89  //
90  // ThreadOffset should be a functor defined like:
91  //
92  // struct ThreadOffsetExample {
93  // CUTLASS_DEVICE
94  // Coord<4> operator()() const {
95  // return make_Coord(0, threadIdx.y, threadIdx.x, 0);
96  // }
97  // };
98  //
99  typedef ThreadOffset_ ThreadOffset;
100 
103 
105  static int const kAccessSize = AccessSize;
106 };
107 
109 
111 template <typename Delta_>
113  typedef Delta_ Delta;
114 
117 
121 
124  bool operator()(Coord<3> iteration, Coord<3> offset) const {
125  return (iteration[0] * Delta::kD + offset[0] < bounds[0]) &&
126  (iteration[1] * Delta::kH + offset[1] < bounds[1]) &&
127  (iteration[2] * Delta::kW + offset[2] < bounds[2]);
128  }
129 };
130 
132 
133 template <typename T>
134 struct DumpType {};
136 template <typename Traits_,
137  typename Scalar_,
140  typename Index_ = int,
141  typename FragmentElement_ = Scalar_,
143  typename Skew_ = Shape<0, 0, 0, 0> >
146  typedef Traits_ Traits;
147 
149  typedef Scalar_ Scalar;
150 
152  typedef FragmentElement_ FragmentElement;
153 
155  static IteratorAdvance::Kind const kAdvance = Advance_;
156 
158  static FragmentElementType::Kind const kFragmentElementType = FragmentElementType_;
159 
162 
164  typedef Index_ Index;
165 
167  typedef Skew_ Skew;
168 
170  typedef typename Traits::Tile Tile;
171 
173  typedef typename Traits::Delta Delta;
174 
176  typedef typename Traits::ImmediateOffsetStrides ImmediateOffsetStrides;
177 
179  typedef typename Traits::Iterations Iterations;
180 
182  typedef typename Traits::ThreadOffset ThreadOffset;
183 
185  static int const kAccessSize = Traits::kAccessSize;
186 
189 
191  static int const kFragmentSize =
197 
204 
207 
208  //
209  // Params struct
210  //
211 
213  struct Params {
214 
215  //
217  // Data members
217  //
218 
219  long long stride_d;
222 
223  long long inc_d;
226 
227  long long inc_advance;
228 
229  //
230  // Methods
231  //
232 
235  Params() : stride_d(0), stride_h(0), stride_w(0), inc_d(0), inc_h(0), inc_w(0) {}
236 
239  Params(long long _stride_d,
240  Index _stride_h,
241  Index _stride_w,
242  long long _inc_d,
243  Index _inc_h,
244  Index _inc_w,
245  long long _inc_advance)
246  : stride_d(_stride_d),
247  stride_h(_stride_h),
248  stride_w(_stride_w),
249  inc_d(_inc_d),
250  inc_h(_inc_h),
251  inc_w(_inc_w),
252  inc_advance(_inc_advance) {}
253 
256  Params(Coord<4> const &stride) {
257  initialize(stride);
258  }
259 
262  int initialize(long long _stride_d,
263  Index _stride_h,
264  Index _stride_w,
265  long long _inc_d,
266  Index _inc_h,
267  Index _inc_w,
268  long long _inc_advance) {
269  stride_d = _stride_d;
270  stride_h = _stride_h;
271  stride_w = _stride_w;
272 
273  inc_d = _inc_d;
274  inc_h = _inc_h;
275  inc_w = _inc_w;
276  inc_advance = _inc_advance;
277 
278  return 0;
279  }
280 
283  int initialize(Coord<4> const &stride) {
284  return initialize(stride[0], stride[1], stride[2]);
285  }
286 
289  int initialize(long long _stride_d, Index _stride_h, Index _stride_w) {
290  stride_d = _stride_d;
291  stride_h = _stride_h;
292  stride_w = _stride_w;
293 
294  inc_w = stride_w * Delta::kW;
295  inc_h = stride_h * Delta::kH - stride_w * Delta::kW * (Iterations::kW - 1);
296  inc_d = stride_d * Delta::kD - stride_h * Delta::kH * (Iterations::kH - 1) -
297  stride_w * Delta::kW * (Iterations::kW - 1);
298 
299  inc_advance = 0;
300 
301  if (kAdvance == IteratorAdvance::kH) {
302  // Advance in the H dimension.
303  inc_advance = Tile::kH * stride_h;
304  } else if (kAdvance == IteratorAdvance::kW) {
305  // Advance in the W dimension.
306  inc_advance = Tile::kW * stride_w;
307 
308  } else {
309  // Advance in the D dimension.
310  inc_advance = Tile::kD * stride_d;
311  }
312 
313  inc_advance -= stride_d * Delta::kD * (Iterations::kD - 1) +
314  stride_h * Delta::kH * (Iterations::kH - 1) +
315  stride_w * Delta::kW * (Iterations::kW - 1);
316 
317  return 0;
318  }
319 
322  stride_d = 0;
323  stride_h = 0;
324  stride_w = 1;
325 
326  inc_advance = 0;
327  inc_d = inc_h = inc_w = 0;
328 
329  return 0;
330  }
331  };
332 
334  CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
335 
336  //
337  // Static function members
338  //
339 
341  template <typename PredicateIterator, typename PredicateFunctor>
342  CUTLASS_HOST_DEVICE static void initialize_predicates(PredicateIterator predicate_it,
343  PredicateFunctor const &predicate_func,
344  Coord<3> const &offset) {
346  for (int d = 0; d < Iterations::kD; ++d) {
348  for (int h = 0; h < Iterations::kH; ++h) {
350  for (int w = 0; w < Iterations::kW; ++w) {
351  bool enable = predicate_func(make_Coord(d, h, w), offset);
352  predicate_it.set(enable);
353  ++predicate_it;
354  }
355  }
356  }
357  }
358 };
359 
361 
385 
391 template <typename Traits_,
392  typename Scalar_,
395  typename Index_ = int,
396  typename FragmentElement_ = Scalar_,
398  typename Skew_ = Shape<0, 0, 0, 0> >
399 struct TileLoadIterator : public TileIteratorBase<Traits_,
400  Scalar_,
401  Advance_,
402  MemorySpace,
403  Index_,
404  FragmentElement_,
405  FragmentElementType_,
406  Skew_> {
408  typedef TileIteratorBase<Traits_,
409  Scalar_,
410  Advance_,
411  MemorySpace,
412  Index_,
413  FragmentElement_,
414  FragmentElementType_,
415  Skew_>
417 
419  typedef typename Base::Traits Traits;
420 
422  typedef typename Base::Scalar Scalar;
423 
425  typedef FragmentElement_ FragmentElement;
426 
429 
431  static FragmentElementType::Kind const kFragmentElementType = FragmentElementType_;
432 
435 
437  typedef typename Base::Index Index;
438 
440  typedef typename Base::Skew Skew;
441 
443  typedef typename Base::Tile Tile;
444 
446  typedef typename Base::Delta Delta;
447 
449  typedef typename Base::Iterations Iterations;
450 
453 
456 
458  typedef typename Base::AccessType AccessType;
459 
461  static int const kAccessSize = Base::kAccessSize;
462 
464  typedef typename Base::Fragment Fragment;
465 
468 
471 
474 
476  typedef typename Base::Storage SharedStorage;
477 
479  typedef typename Base::Params BaseParams;
480 
482  enum { kRequiresLoadFence = Tile::kD == 1 };
483 
485  typedef Scalar const *Pointer;
486 
489 
491  struct Params : public BaseParams {
493  Scalar const *pointer;
494 
495  //
496  // Methods
497  //
498 
502 
505  Params(Scalar const *ptr) : pointer(ptr) { Base::Params::initialize(); }
506 
509  Params(TensorRef const &ref): pointer(ref.data()) {
510  Base::Params::initialize(ref.stride());
511  }
512 
515  Params(Scalar const *ptr,
516  long long _stride_d,
517  Index _stride_h,
518  Index _stride_w,
519  long long _inc_d,
520  Index _inc_h,
521  Index _inc_w,
522  Index _inc_advance)
523  : pointer(ptr) {
525  _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
526  }
527 
530  Params(Scalar const *ptr, long long stride_d, Index stride_h, Index stride_w)
531  : pointer(ptr) {
533  }
534 
537  int initialize(TensorRef const &ref) {
538  pointer = ref.data();
539  return Base::Params::initialize(ref.stride());
540  }
541 
544  int initialize(SharedStorage const &storage) {
545  pointer = &storage[0];
547  return 0;
548  }
549 
552  int initialize(Scalar const *ptr) {
553  pointer = ptr;
555  return 0;
556  }
557 
560  int initialize(Scalar const *ptr, long long stride_d, Index stride_h, Index stride_w) {
562  pointer = ptr;
563  return 0;
564  }
565 
568  int initialize(Scalar const *ptr,
569  long long _stride_d,
570  Index _stride_h,
571  Index _stride_w,
572  long long _inc_d,
573  Index _inc_h,
574  Index _inc_w,
575  Index _inc_advance) {
576  pointer = ptr;
578  _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
579  return 0;
580  }
581 
582  // Initializes params to default values
585  };
586 
587  //
588  // Data members
589  //
590 
592  Params params;
593 
596 
598  int stage;
599 
600  //
601  // Predicate initialization
602  //
603 
605  template <
607  typename PredicateIterator>
608  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
609  Coord<3> const &bounds,
610  Coord<3> const &block_offset = make_Coord(0,
611  0,
612  0)) {
614  predicate_it,
616  block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
617  }
618 
620  template <
622  typename PredicateIterator,
624  typename PredicateFunctor>
625  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
626  PredicateFunctor const &functor,
627  Coord<3> const &block_offset) {
629  predicate_it,
630  functor,
631  block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
632  }
633 
634  //
635  // Methods
636  //
637 
641 
644  TileLoadIterator(Params const &_params,
645  Coord<3> const &block_offset = make_Coord(0, 0, 0),
646  ThreadOffset thread_offset_func = ThreadOffset())
647  : params(_params), stage(0) {
648  thread_offset = thread_offset_func();
649 
650  Index pointer_offset = Index((block_offset[0] + thread_offset[0]) * params.stride_d) +
651  Index((block_offset[1] + thread_offset[1]) * params.stride_h) +
652  Index((block_offset[2] + thread_offset[2]) * params.stride_w);
653 
654  params.pointer += pointer_offset;
655  }
656 
659  TileLoadIterator(Params const &,
660  Scalar const *ptr,
661  Coord<3> const &block_offset = make_Coord(0, 0, 0),
662  ThreadOffset thread_offset_func = ThreadOffset())
663  : stage(0) {
664  params.pointer = ptr + thread_offset_func()[2];
665 
666  params.stride_d = 0;
667  params.stride_h = 0;
668  params.stride_w = 1;
669 
671  }
672 
675 
678 
681 
684 
686  CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const {
687  int const offset =
689  Load<Scalar,
690  kAccessSize,
691  kMemorySpace,
694  Tile::kW,
695  sizeof(FragmentElement) * kAccessSize>::load(value, params.pointer, offset);
696  }
697 
700  if (Tile::kD > 1) {
701  int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
702  if (stage == Tile::kD - 1) {
703  params.pointer -= (Tile::kD - 1) * kStageSize;
704  stage = 0;
705  } else {
706  params.pointer += kStageSize;
707  stage = stage + 1;
708  }
709  }
710  }
711 
714  long long _offset = offset.template dot<long long>(
716  );
717 
718  params.pointer += _offset;
719  return *this;
720  }
721 
724 
726  Index stride = params.stride_h;
727  if (kAdvance == IteratorAdvance::kW) {
728  stride = params.stride_w;
729  }
730  return stride;
731  }
732 
734  template <typename Fragment, typename PredicateIterator>
735  CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
736  FragmentIterator frag_iterator(fragment);
737 
738  for (int d = 0; d < Iterations::kD; ++d) {
739  for (int h = 0; h < Iterations::kH; ++h) {
740  for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
741  for (int c = 0; c < Iterations::kC; ++c) {
742  if (*pred_it) {
743  load_element(
744  reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
745  }
746  }
747  if (w < Iterations::kW - 1) {
748  inc_w();
749  }
750  }
751  if (h < Iterations::kH - 1) {
752  inc_h();
753  }
754  }
755  if (d < Iterations::kD - 1) {
756  inc_d();
757  }
758  }
759  inc_advance();
760  }
761 
763  template <typename Fragment>
765  typename PredicateVector::TrivialIterator pred_it;
766  load_post_increment(fragment, pred_it);
767  }
768 
770  template <typename Fragment, typename PredicateIterator>
771  CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const {
772  TileLoadIterator _load_it(*this);
773  _load_it.load_post_increment(fragment, pred_it);
774  }
775 
777  template <typename Fragment>
778  CUTLASS_HOST_DEVICE void load(Fragment &fragment) const {
779  typename PredicateVector::TrivialIterator pred_it;
780  load(fragment, pred_it);
781  }
782 
784  template <typename Fragment>
785  CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) {
786  FragmentIterator frag_iterator(fragment);
787  for (int h = 0; h < Iterations::kH; ++h) {
788  for (int w = 0; w < Iterations::kW; ++w) {
789  for (int c = 0; c < Iterations::kC; ++c) {
790  load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
791  }
792  }
793  }
794  }
795 };
796 
798 
822 
828 template <typename Traits_,
829  typename Scalar_,
832  typename Index_ = int,
833  typename FragmentElement_ = Scalar_,
835  typename Skew_ = Shape<0, 0, 0, 0> >
836 struct TileStoreIterator : public TileIteratorBase<Traits_,
837  Scalar_,
838  Advance_,
839  MemorySpace,
840  Index_,
841  FragmentElement_,
842  FragmentElementType_,
843  Skew_> {
845  typedef TileIteratorBase<Traits_,
846  Scalar_,
847  Advance_,
848  MemorySpace,
849  Index_,
850  FragmentElement_,
851  FragmentElementType_,
852  Skew_>
854 
856  typedef typename Base::Traits Traits;
857 
859  typedef typename Base::Scalar Scalar;
860 
863 
866 
869 
872 
874  static int const kAccessSize = Base::kAccessSize;
875 
877  typedef typename Base::Index Index;
878 
880  typedef typename Base::Skew Skew;
881 
883  typedef typename Base::Tile Tile;
884 
886  typedef typename Base::Delta Delta;
887 
889  typedef typename Base::Iterations Iterations;
890 
893 
896 
898  typedef typename Base::AccessType AccessType;
899 
901  typedef typename Base::Fragment Fragment;
902 
905 
908 
911 
913  typedef typename Base::Storage SharedStorage;
914 
916  typedef typename Base::Params BaseParams;
917 
919  typedef Scalar *Pointer;
920 
923 
925  struct Params : public BaseParams {
928 
929  //
930  // Methods
931  //
932 
933  // Default constructor
935  Params() : pointer(0) {}
936 
937  // Default constructor
940 
943  Params(TensorRef const &ref): pointer(ref.data()) {
944  Base::Params::initialize(ref.stride());
945  }
946 
947  // Constructs params from a pointer and strides
951  }
952 
953  // Constructs params from a pointer, strides, and increments
956  long long _stride_d,
957  Index _stride_h,
958  Index _stride_w,
959  long long _inc_d,
960  Index _inc_h,
961  Index _inc_w,
962  Index _inc_advance) {
963  initialize(ptr, _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
964  }
965 
968  int initialize(SharedStorage &storage) {
969  pointer = &storage[0];
970  return Base::Params::initialize();
971  }
972 
975  int initialize(Scalar *ptr) {
976  pointer = ptr;
977  return Base::Params::initialize();
978  }
979 
984  pointer = ptr;
985  return 0;
986  }
987 
990  int initialize(Scalar *ptr,
991  long long _stride_d,
992  Index _stride_h,
993  Index _stride_w,
994  long long _inc_d,
995  Index _inc_h,
996  Index _inc_w,
997  Index _inc_advance) {
998  pointer = ptr;
1000  _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
1001  return 0;
1002  }
1003 
1007  };
1008 
1009  //
1010  // Data members
1011  //
1012 
1015 
1018 
1020  int stage;
1021 
1022  //
1023  // Predicate initialization
1024  //
1025 
1027  template <
1029  typename PredicateIterator>
1030  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
1031  Coord<3> const &bounds,
1032  Coord<3> const &block_offset = make_Coord(0,
1033  0,
1034  0)) {
1036  predicate_it,
1038  block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
1039  }
1040 
1042  template <
1044  typename PredicateIterator,
1046  typename PredicateFunctor>
1047  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
1048  PredicateFunctor const &functor,
1049  Coord<3> const &block_offset) {
1051  predicate_it,
1052  functor,
1053  block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
1054  }
1055 
1056  //
1057  // Methods
1058  //
1059 
1063 
1066  TileStoreIterator(Params const &_params,
1067  Coord<3> const &block_offset = make_Coord(0, 0, 0),
1068  ThreadOffset thread_offset_func = ThreadOffset())
1069  : params(_params), stage(0) {
1070  thread_offset = thread_offset_func();
1071 
1072  params.pointer += (block_offset[0] + thread_offset[0]) * params.stride_d +
1073  (block_offset[1] + thread_offset[1]) * params.stride_h +
1074  (block_offset[2] + thread_offset[2]) * params.stride_w;
1075  }
1076 
1079  TileStoreIterator(Params const &, Scalar *ptr, ThreadOffset thread_offset_func = ThreadOffset())
1080  : stage(0) {
1081  params.pointer = ptr + thread_offset_func()[2];
1082  params.stride_d = 0;
1083  params.stride_h = 0;
1084  params.stride_w = 1;
1085 
1087  }
1088 
1091 
1094 
1097 
1100 
1103  if (Tile::kD > 1) {
1104  int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
1105  if (stage == Tile::kD - 1) {
1106  params.pointer -= (Tile::kD - 1) * kStageSize;
1107  stage = 0;
1108  } else {
1109  params.pointer += kStageSize;
1110  stage = stage + 1;
1111  }
1112  }
1113  }
1114 
1117  params.pointer += offset.template dot<long long>(
1119  );
1120  return *this;
1121  }
1122 
1125 
1127  CUTLASS_HOST_DEVICE void store_element(AccessType const &value, int d, int h, int w, int c) {
1128  int const offset =
1130  Store<Scalar,
1131  kAccessSize,
1132  kMemorySpace,
1135  Tile::kW,
1136  sizeof(FragmentElement) * kAccessSize>::store(value, params.pointer, offset);
1137  }
1138 
1140  template <typename Fragment, typename PredicateIterator>
1141  CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it) {
1142  FragmentConstIterator frag_iterator(fragment);
1143 
1144  for (int d = 0; d < Iterations::kD; ++d) {
1145  for (int h = 0; h < Iterations::kH; ++h) {
1146  for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
1147  for (int c = 0; c < Iterations::kC; ++c) {
1148  if (*pred_it) {
1149  store_element(
1150  reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
1151  }
1152  }
1153  if (w < Iterations::kW - 1) {
1154  inc_w();
1155  }
1156  }
1157  if (h < Iterations::kH - 1) {
1158  inc_h();
1159  }
1160  }
1161  if (d < Iterations::kD - 1) {
1162  inc_d();
1163  }
1164  }
1165  inc_advance();
1166  }
1167 
1169  template <typename Fragment>
1171  typename PredicateVector::TrivialIterator pred_it;
1172  store_post_increment(fragment, pred_it);
1173  }
1174 
1176  template <typename Fragment, typename PredicateIterator>
1177  CUTLASS_HOST_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const {
1178  TileStoreIterator _store_it(*this);
1179  _store_it.store_post_increment(fragment, pred_it);
1180  }
1181 
1183  template <typename Fragment>
1184  CUTLASS_HOST_DEVICE void store(Fragment const &fragment) const {
1185  typename PredicateVector::TrivialIterator pred_it;
1186  store(fragment, pred_it);
1187  }
1188 
1190  CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const {
1191  int const offset =
1193 
1194  Load<Scalar,
1195  kAccessSize,
1196  kMemorySpace,
1199  Tile::kW,
1200  sizeof(FragmentElement) * kAccessSize>::load(value, params.pointer, offset);
1201  }
1202 
1204  template <typename Fragment, typename PredicateIterator>
1205  CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
1206  FragmentIterator frag_iterator(fragment);
1207 
1208  for (int d = 0; d < Iterations::kD; ++d) {
1209  for (int h = 0; h < Iterations::kH; ++h) {
1210  for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
1211  for (int c = 0; c < Iterations::kC; ++c) {
1212  if (*pred_it) {
1213  load_element(
1214  reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
1215  }
1216  }
1217  if (w < Iterations::kW - 1) {
1218  inc_w();
1219  }
1220  }
1221  if (h < Iterations::kH - 1) {
1222  inc_h();
1223  }
1224  }
1225  if (d < Iterations::kD - 1) {
1226  inc_d();
1227  }
1228  }
1229  inc_advance();
1230  }
1231 
1233  template <typename Fragment>
1235  typename PredicateVector::TrivialIterator pred_it;
1236  load_post_increment(fragment, pred_it);
1237  }
1238 
1240  template <typename Fragment, typename PredicateIterator>
1241  CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const {
1242  TileStoreIterator _load_it(*this);
1243  _load_it.load_post_increment(fragment, pred_it);
1244  }
1245 
1247  template <typename Fragment>
1248  CUTLASS_HOST_DEVICE void load(Fragment &fragment) const {
1249  typename PredicateVector::TrivialIterator pred_it;
1250  load(fragment, pred_it);
1251  }
1252 
1254  template <typename Fragment>
1255  CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) {
1256  FragmentIterator frag_iterator(fragment);
1257  for (int h = 0; h < Iterations::kH; ++h) {
1258  for (int w = 0; w < Iterations::kW; ++w) {
1259  for (int c = 0; c < Iterations::kC; ++c) {
1260  load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
1261  }
1262  }
1263  }
1264  }
1265 };
1266 
1268 
1269 } // namespace cutlass
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr, long long _stride_d, Index _stride_h, Index _stride_w, long long _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:990
CUTLASS_HOST_DEVICE void inc_advance()
Increment in the next dimension.
Definition: tile_iterator.h:683
CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: tile_iterator.h:1190
Vectorize< FragmentElement, kAccessSize >::Type AccessType
The elements loaded/stored by one instruction.
Definition: tile_iterator.h:188
Base::Tile Tile
Tile shape.
Definition: tile_iterator.h:883
Delta_ Delta
Definition: tile_iterator.h:113
Base::PredicateVector PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:910
CUTLASS_HOST_DEVICE Params()
Default constructor.
Definition: tile_iterator.h:501
CUTLASS_HOST_DEVICE TileLoadIterator(Params const &_params, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile load iterator.
Definition: tile_iterator.h:644
CUTLASS_HOST_DEVICE int initialize(SharedStorage const &storage)
Initialize params to access storage object.
Definition: tile_iterator.h:544
Tile_ Tile
Shape of the tile.
Definition: tile_iterator.h:80
Index_ Index
Index type.
Definition: tile_iterator.h:164
Definition: convert.h:33
Defines a structure containing strides, bounds, and a pointer to tensor data.
Base::Iterations Iterations
Iterations.
Definition: tile_iterator.h:449
CUTLASS_HOST_DEVICE int initialize(Coord< 4 > const &stride)
Initializes the parameters object from a vector of strides.
Definition: tile_iterator.h:283
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: tile_iterator.h:334
Skew_ Skew
Skew quantity.
Definition: tile_iterator.h:167
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:461
Enum to specify which memory space data resides in.
Definition: load_store.h:38
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:1234
Base::Index Index
Index type.
Definition: tile_iterator.h:877
Base::Storage SharedStorage
Storage object that may be loaded from.
Definition: tile_iterator.h:476
int stage
The stage.
Definition: tile_iterator.h:1020
Base::Tile Tile
Tile shape.
Definition: tile_iterator.h:443
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, long long _stride_d, Index _stride_h, Index _stride_w, long long _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initializes params.
Definition: tile_iterator.h:568
FragmentIterator< Fragment, Iterations, AccessType > FragmentIterator
The fragment iterator.
Definition: tile_iterator.h:199
Scalar * Pointer
Pointer to underlying type.
Definition: tile_iterator.h:919
Traits::ThreadOffset ThreadOffset
Thread offset.
Definition: tile_iterator.h:182
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Kind
Definition: tile_iterator.h:65
CUTLASS_HOST_DEVICE int initialize(long long _stride_d, Index _stride_h, Index _stride_w, long long _inc_d, Index _inc_h, Index _inc_w, long long _inc_advance)
Initializes params.
Definition: tile_iterator.h:262
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:318
Shape< 0, 0, 0, 0 > ImmediateOffsetStrides
Strides for immediate offset computation.
Definition: tile_iterator.h:102
Base::FragmentConstIterator FragmentConstIterator
Fragment const iterator definition.
Definition: tile_iterator.h:470
A template defining Tile Traits Concept.
Definition: tile_iterator.h:78
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:425
TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ > Base
Base class.
Definition: tile_iterator.h:416
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Initializes a predicate vector using a RegularTilePredicateFunctor.
Definition: tile_iterator.h:608
Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
The fragment.
Definition: tile_iterator.h:196
Traits::Iterations Iterations
Iterations.
Definition: tile_iterator.h:179
CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d)
Loads the fragment elements for index d without advancing the iterator.
Definition: tile_iterator.h:785
Base::Delta Delta
Delta.
Definition: tile_iterator.h:446
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:735
CUTLASS_HOST_DEVICE int initialize()
Definition: tile_iterator.h:584
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:434
Traits::ImmediateOffsetStrides ImmediateOffsetStrides
The strides in each dimension between different loads/stores.
Definition: tile_iterator.h:176
long long inc_d
Definition: tile_iterator.h:223
Base::ThreadOffset ThreadOffset
ThreadOffset functor.
Definition: tile_iterator.h:452
CUTLASS_HOST_DEVICE int initialize()
Initializes params to default values.
Definition: tile_iterator.h:1006
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:464
Definition: tile_iterator.h:65
long long inc_advance
Definition: tile_iterator.h:227
Base::Storage SharedStorage
Storage object which may be stored to.
Definition: tile_iterator.h:913
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:871
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Initializes a predicate vector using a RegularTilePredicateFunctor.
Definition: tile_iterator.h:1030
CUTLASS_HOST_DEVICE void inc_d()
Increment in the D dimension.
Definition: tile_iterator.h:1090
ThreadOffset_ ThreadOffset
Functor that returns the logical coordinate of each entity's initial offset in the tile...
Definition: tile_iterator.h:99
Iterator that always returns true.
Definition: predicate_vector.h:309
CUTLASS_HOST_DEVICE Params(Coord< 4 > const &stride)
Constructs params with a stride vector.
Definition: tile_iterator.h:256
Scalar * pointer
Pointer to memory.
Definition: tile_iterator.h:927
CUTLASS_HOST_DEVICE Params(Scalar *ptr, long long stride_d, Index stride_h, Index stride_w)
Definition: tile_iterator.h:949
Base::FragmentShape FragmentShape
Fragment type.
Definition: tile_iterator.h:895
Kind
Definition: load_store.h:39
PredicateVector< ShapeCount< Iterations >::kCount > PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:206
CUTLASS_HOST_DEVICE Index stride_advance(void)
Definition: tile_iterator.h:725
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:901
TensorRef< Scalar, 4 > TensorRef
Tensor reference for the store iterator.
Definition: tile_iterator.h:922
Base::Traits Traits
concept TileTraits
Definition: tile_iterator.h:419
Definition: load_store.h:178
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:1248
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &functor, Coord< 3 > const &block_offset)
Initializes a predicate vector using an arbitrary predicate functor.
Definition: tile_iterator.h:625
CUTLASS_HOST_DEVICE void store_element(AccessType const &value, int d, int h, int w, int c)
Stores a single fragment element into memory.
Definition: tile_iterator.h:1127
FragmentIterator::FragmentShape FragmentShape
The shape of the fragment.
Definition: tile_iterator.h:203
Scalar const * Pointer
The pointer type.
Definition: tile_iterator.h:485
static IteratorAdvance::Kind const kAdvance
Specifies dimension in which post-increment accesses advance.
Definition: tile_iterator.h:155
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:62
Index stride_h
Definition: tile_iterator.h:220
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
Parameters.
Definition: tile_iterator.h:925
static FragmentElementType::Kind const kFragmentElementType
Specifies iterator storage fragment type (Scalar or WmmaMatrix)
Definition: tile_iterator.h:158
Traits_ Traits
concept TileTraits
Definition: tile_iterator.h:146
Params params
Parameters structure.
Definition: tile_iterator.h:1014
Base::FragmentConstIterator FragmentConstIterator
Fragment const iterator definition.
Definition: tile_iterator.h:907
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:399
Base::Params BaseParams
IteratorBase parameters.
Definition: tile_iterator.h:916
CUTLASS_HOST_DEVICE void inc_stage()
Increment the stage.
Definition: tile_iterator.h:699
CUTLASS_HOST_DEVICE int initialize()
No-argument initialization overload.
Definition: tile_iterator.h:321
Kind
Definition: load_store.h:48
CUTLASS_HOST_DEVICE RegularTilePredicateFunctor(Coord< 3 > _bounds)
Constructs a predicate functor given the bounds of a tensor.
Definition: tile_iterator.h:120
CUTLASS_HOST_DEVICE TileLoadIterator()
Default constructor.
Definition: tile_iterator.h:640
Definition: load_store.h:40
Base::Params BaseParams
IteratorBase parameters.
Definition: tile_iterator.h:479
Params params
Parameters structure.
Definition: tile_iterator.h:592
FragmentConstIterator< Fragment, Iterations, AccessType > FragmentConstIterator
The fragment const iterator.
Definition: tile_iterator.h:201
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr)
Initialize params to access storage object.
Definition: tile_iterator.h:552
CUTLASS_HOST_DEVICE TileLoadIterator & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: tile_iterator.h:713
CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: tile_iterator.h:686
Base::AccessType AccessType
Memory access type.
Definition: tile_iterator.h:898
Definition: tile_iterator.h:482
Definition: tile_iterator.h:134
Iterations_ Iterations
Number of accesses performed.
Definition: tile_iterator.h:86
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the fragment.
Definition: fragment.h:183
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:771
static int const kAccessSize
Access size.
Definition: tile_iterator.h:105
Fragment< Scalar, ShapeCount< Tile >::kCount, kFragmentSize > Storage
The storage.
Definition: tile_iterator.h:194
Base::Skew Skew
Skew quantity.
Definition: tile_iterator.h:880
Delta_ Delta
Number of steps between accesses along each dimension.
Definition: tile_iterator.h:83
Defines abstractions for efficiently loading and storing vectors to memory.
Base::Scalar Scalar
Scalar element.
Definition: tile_iterator.h:422
Base::FragmentShape FragmentShape
Fragment type.
Definition: tile_iterator.h:455
static FragmentElementType::Kind const kFragmentElementType
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:868
CUTLASS_HOST_DEVICE TileStoreIterator(Params const &, Scalar *ptr, ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile store iterator.
Definition: tile_iterator.h:1079
CUTLASS_HOST_DEVICE void inc_advance()
Increment in the next dimension.
Definition: tile_iterator.h:1099
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE Params(Scalar *ptr)
Definition: tile_iterator.h:939
CUTLASS_HOST_DEVICE TileStoreIterator(Params const &_params, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile store iterator.
Definition: tile_iterator.h:1066
Index inc_h
Definition: tile_iterator.h:224
CUTLASS_HOST_DEVICE Params()
Constructs params.
Definition: tile_iterator.h:235
Definition: vector.h:62
TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ > Base
Base class.
Definition: tile_iterator.h:853
Definition: load_store.h:60
Base::Traits Traits
concept TileTraits
Definition: tile_iterator.h:856
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
CUTLASS_HOST_DEVICE Params(Scalar const *ptr)
Initialize params to access storage object.
Definition: tile_iterator.h:505
long long stride_d
Definition: tile_iterator.h:219
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:904
Specifies dimension in which post-increment accesses advance.
Definition: tile_iterator.h:64
CUTLASS_HOST_DEVICE Params(TensorRef const &ref)
Constructs with a CompactTensorRef<>
Definition: tile_iterator.h:509
CUTLASS_HOST_DEVICE Params(Scalar const *ptr, long long _stride_d, Index _stride_h, Index _stride_w, long long _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Initialize params to access storage object.
Definition: tile_iterator.h:515
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:161
CUTLASS_HOST_DEVICE TileStoreIterator()
Default constructor.
Definition: tile_iterator.h:1062
Definition: load_store.h:48
CUTLASS_HOST_DEVICE int initialize(TensorRef const &ref)
Initializes params to access the storage referenced by a TensorRef.
Definition: tile_iterator.h:537
Base::Scalar Scalar
Scalar element.
Definition: tile_iterator.h:859
Defines a 1D vector of elements held in the registers of each thread.
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:144
CUTLASS_HOST_DEVICE Params(Scalar const *ptr, long long stride_d, Index stride_h, Index stride_w)
Initialize params to access storage object.
Definition: tile_iterator.h:530
Scalar const * pointer
Pointer to memory.
Definition: tile_iterator.h:493
CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it)
Stores a fragment and advances to the next tile.
Definition: tile_iterator.h:1141
Definition: tile_iterator.h:65
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:428
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:152
CUTLASS_HOST_DEVICE void store(Fragment const &fragment) const
Stores a fragment without advancing the iterator.
Definition: tile_iterator.h:1184
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: tile_iterator.h:1017
Base::Iterations Iterations
Iterations.
Definition: tile_iterator.h:889
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:874
Functor computing a predicate given the logical position of an access.
Definition: tile_iterator.h:112
Traits::Tile Tile
Tile shape.
Definition: tile_iterator.h:170
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset)
Adds a raw offset to the pointer.
Definition: tile_iterator.h:1124
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:185
Parameters.
Definition: tile_iterator.h:491
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:199
Base::PredicateVector PredicateVector
Default predicate mask type.
Definition: tile_iterator.h:473
CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d)
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:1255
Base::AccessType AccessType
Memory access type.
Definition: tile_iterator.h:458
Base::Skew Skew
Skew quantity.
Definition: tile_iterator.h:440
Base::Delta Delta
Delta.
Definition: tile_iterator.h:886
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, long long stride_d, Index stride_h, Index stride_w)
Initializes params to access a raw pointer.
Definition: tile_iterator.h:560
CUTLASS_HOST_DEVICE void inc_h()
Increment in the H dimension.
Definition: tile_iterator.h:1093
CUTLASS_HOST_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const
Stores a fragment without advancing the iterator.
Definition: tile_iterator.h:1177
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:680
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:778
Base::ThreadOffset ThreadOffset
ThreadOffset functor.
Definition: tile_iterator.h:892
Definition: tile_iterator.h:65
static FragmentElementType::Kind const kFragmentElementType
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:431
CUTLASS_HOST_DEVICE void inc_h()
Increment in the H dimension.
Definition: tile_iterator.h:677
static CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &predicate_func, Coord< 3 > const &offset)
Initializes a predicate vector.
Definition: tile_iterator.h:342
Scalar_ Scalar
Scalar element.
Definition: tile_iterator.h:149
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: tile_iterator.h:595
Coord< 3 > bounds
Dimensions of the bounding volume.
Definition: tile_iterator.h:116
Traits::Delta Delta
Distance along each dimension.
Definition: tile_iterator.h:173
static int const kFragmentSize
The size of storage needed per fragment.
Definition: tile_iterator.h:191
CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment)
Stores a fragment and advances to the next tile.
Definition: tile_iterator.h:1170
CUTLASS_HOST_DEVICE Params()
Definition: tile_iterator.h:935
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr)
Initialize params to access storage object.
Definition: tile_iterator.h:975
Base::FragmentElement FragmentElement
Fragment element.
Definition: tile_iterator.h:862
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:1241
CUTLASS_HOST_DEVICE Params(TensorRef const &ref)
Constructs with a CompactTensorRef<>
Definition: tile_iterator.h:943
CUTLASS_HOST_DEVICE void inc_d()
Increment in the D dimension.
Definition: tile_iterator.h:674
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:764
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
Parameters to the iterator.
Definition: tile_iterator.h:213
CUTLASS_HOST_DEVICE TileStoreIterator & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: tile_iterator.h:1116
CUTLASS_HOST_DEVICE void inc_stage()
Increment the stage.
Definition: tile_iterator.h:1102
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:467
CUTLASS_HOST_DEVICE bool operator()(Coord< 3 > iteration, Coord< 3 > offset) const
Computes the predicate given the logical position of an access.
Definition: tile_iterator.h:124
CUTLASS_HOST_DEVICE Params(Scalar *ptr, long long _stride_d, Index _stride_h, Index _stride_w, long long _inc_d, Index _inc_h, Index _inc_w, Index _inc_advance)
Definition: tile_iterator.h:955
Base::Index Index
Index type.
Definition: tile_iterator.h:437
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:1096
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:865
CUTLASS_HOST_DEVICE int initialize(Scalar *ptr, long long stride_d, Index stride_h, Index stride_w)
Initializes params to access a raw pointer.
Definition: tile_iterator.h:982
int stage
Stage argument enables wrapping after some number of tiles have been loaded.
Definition: tile_iterator.h:598
CUTLASS_HOST_DEVICE int initialize(long long _stride_d, Index _stride_h, Index _stride_w)
Initializes the parameters object from a vector of strides.
Definition: tile_iterator.h:289
CUTLASS_HOST_DEVICE TileLoadIterator(Params const &, Scalar const *ptr, Coord< 3 > const &block_offset=make_Coord(0, 0, 0), ThreadOffset thread_offset_func=ThreadOffset())
Constructs a tile load iterator.
Definition: tile_iterator.h:659
CUTLASS_HOST_DEVICE int initialize(SharedStorage &storage)
Initialize params to access storage object.
Definition: tile_iterator.h:968
Index inc_w
Definition: tile_iterator.h:225
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &functor, Coord< 3 > const &block_offset)
Initializes a predicate vector using an arbitrary predicate functor.
Definition: tile_iterator.h:1047
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:836
CUTLASS_HOST_DEVICE Params(long long _stride_d, Index _stride_h, Index _stride_w, long long _inc_d, Index _inc_h, Index _inc_w, long long _inc_advance)
Constructs params.
Definition: tile_iterator.h:239
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset)
Adds a raw offset to the pointer.
Definition: tile_iterator.h:723
Index stride_w
Definition: tile_iterator.h:221
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it)
Loads a fragment and advances the iterator to the next tile.
Definition: tile_iterator.h:1205