Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
fragment.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <assert.h>
32 #include <cutlass/shape.h>
34 #include <cutlass/vector.h>
35 
36 namespace cutlass {
37 
39 
56 
73 
/// Maps a byte alignment to the unsigned integer type used as one storage word.
/// Alignments 1, 2 and 4 select 8-, 16- and 32-bit words; every other
/// alignment (including 8 and 16) falls back to the 64-bit primary template.
template <int kAlignment_>
struct StorageType {
  using Type = uint64_t;
};

/// Byte-aligned storage.
template <>
struct StorageType<1> {
  using Type = uint8_t;
};

/// Half-word (2-byte) aligned storage.
template <>
struct StorageType<2> {
  using Type = uint16_t;
};

/// Word (4-byte) aligned storage.
template <>
struct StorageType<4> {
  using Type = uint32_t;
};
91 
93 
/// Fragment: a fixed-size array of kElements_ values of type Element_, derived
/// from AlignedStruct<kAlignment_> (kAlignment_ defaults to 16 bytes).
/// Elements live in a private word array and are accessed through operator[].
98 template <typename Element_, int kElements_, size_t kAlignment_ = 16>
99 struct Fragment : public AlignedStruct<kAlignment_> {
/// kAlignment_ must be the default 16 or at least as large as one element.
101  static_assert(kAlignment_ == 16 || kAlignment_ >= sizeof(Element_), "Alignment is too small");
/// kAlignment_ must be a power of two.
103  static_assert(is_pow2<kAlignment_>::value, "Alignment must be a power of two");
104 
/// NOTE(review): the `This_` typedef (on a line omitted from this listing) is
/// documented as `Fragment<Element_, kElements_>`. That drops kAlignment_, so it
/// names a different type whenever kAlignment_ != 16. Confirm against the header.
/// The element type.
108  typedef Element_ Element;
/// The number of elements held by the fragment.
110  static int const kElements = kElements_;
111 
113  CUTLASS_DEVICE void clear() {
114  // Avoid element-wise access for sub 32b element type
115  if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
116  uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
117  for (int i = 0; i < (kElements * sizeof(Element)) / 8; ++i) {
118  ptr[i] = uint64_t(0);
119  }
120  } else if (kAlignment_ >= 4 && (kElements * sizeof(Element)) % 4 == 0) {
121  uint32_t* ptr = reinterpret_cast<uint32_t*>(storage);
122  for (int i = 0; i < (kElements * sizeof(Element)) / 4; ++i) {
123  ptr[i] = uint32_t(0);
124  }
125  } else if (kAlignment_ >= 2 && (kElements * sizeof(Element)) % 2 == 0) {
126  uint16_t* ptr = reinterpret_cast<uint16_t*>(storage);
127  for (int i = 0; i < (kElements * sizeof(Element)) / 2; ++i) {
128  ptr[i] = uint16_t(0);
129  }
130  } else {
131  for (int i = 0; i < kElements; ++i) {
132  storage[i] = 0;
133  }
134  }
135  }
136 
138  CUTLASS_DEVICE Element& operator[](int i) {
139  assert(i < kElements_);
140  return reinterpret_cast<Element*>(storage)[i];
141  }
142 
144  CUTLASS_DEVICE Element const& operator[](int i) const {
145  assert(i < kElements_);
146  return reinterpret_cast<Element const*>(storage)[i];
147  }
148 
149  private:
152 
/// NOTE(review): `StorageType` below must name an in-class type, because the
/// namespace-scope StorageType is a class template that needs an argument.
/// Its typedef (presumably `typename StorageType<kAlignment_>::Type`) sits on a
/// line omitted from this listing. Confirm against the header.
/// Number of StorageType words backing the fragment: the payload of
/// sizeof(Element_) * kElements_ bytes, rounded up to a whole word.
154  static int const kStorageCount =
155  (sizeof(Element_) * kElements_ + sizeof(StorageType) - 1) / sizeof(StorageType);
/// Backing store; operator[] reinterprets it as an array of Element.
157  StorageType storage[kStorageCount];
158 
/// A storage word may not be wider than the requested alignment.
160  static_assert(sizeof(StorageType) <= kAlignment_, "StorageType is too big for given alignment");
161 };
162 
164 
/// FragmentIterator: a mutable view of a fragment's elements as a sequence of
/// AccessType vectors, each kElementsPerAccess elements wide.
/// NOTE(review): the `struct FragmentIterator {` line (fragment.h:170) and the
/// This_, FragmentShape and Strides typedefs (fragment.h:172, 185, 187) are
/// omitted from this listing.
169 template <typename Fragment_, typename Iterations_, typename AccessType_>
/// The fragment type being iterated.
174  typedef Fragment_ Fragment;
/// The number of iterations.
176  typedef Iterations_ Iterations;
/// The vector type returned by each access.
178  typedef AccessType_ AccessType;
179 
/// The element type of the underlying fragment.
181  typedef typename Fragment::Element Element;
/// Elements covered by one access. This uses integer division, so it assumes
/// sizeof(AccessType) is a multiple of sizeof(Element). TODO confirm callers
/// guarantee this.
183  static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
188 
190  template <typename OtherFragment_>
191  CUTLASS_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0)
192  : pointer(reinterpret_cast<Element*>(&fragment[offset])) {
193  static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
194  }
195 
197  CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
198  int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
199  return reinterpret_cast<AccessType const&>(pointer[imm]);
200  }
201 
203  CUTLASS_DEVICE AccessType& at(int d, int h, int w, int c = 0) {
204  int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
205  return reinterpret_cast<AccessType&>(pointer[imm]);
206  }
207 
209  CUTLASS_DEVICE AccessType const& operator[](int i) const {
210  return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
211  }
212 
214  CUTLASS_DEVICE AccessType& operator[](int i) {
215  return reinterpret_cast<AccessType&>(pointer[i * kElementsPerAccess]);
216  }
217 
/// Is the iterator valid? Always true; the coordinates are ignored.
219  CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
220 
/// NOTE(review): the `Element* pointer` member (fragment.h:222) is omitted from
/// this listing; it precedes the closing brace below.
223 };
224 
226 
/// FragmentConstIterator: a read-only view of a fragment's elements as a
/// sequence of AccessType vectors, each kElementsPerAccess elements wide.
/// NOTE(review): the struct line (fragment.h:228) and the This_, FragmentShape
/// and IterationsStrides typedefs (fragment.h:230, 243, 245) are omitted from
/// this listing. This_ at fragment.h:230 is documented as
/// `FragmentIterator<Fragment_, Iterations_, AccessType_>`. That is presumably a
/// copy-paste slip for FragmentConstIterator; confirm.
227 template <typename Fragment_, typename Iterations_, typename AccessType_>
/// The fragment type being iterated.
232  typedef Fragment_ Fragment;
/// The number of iterations.
234  typedef Iterations_ Iterations;
/// The vector type returned by each access.
236  typedef AccessType_ AccessType;
237 
/// The element type of the underlying fragment.
239  typedef typename Fragment::Element Element;
/// Elements covered by one access. This uses integer division, so it assumes
/// sizeof(AccessType) is a multiple of sizeof(Element). TODO confirm.
241  static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
246 
248  template <typename OtherFragment_>
249  CUTLASS_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0)
250  : pointer(reinterpret_cast<Element const*>(&fragment[offset])) {
251  static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
252  }
254  CUTLASS_DEVICE FragmentConstIterator(
256  : pointer(reinterpret_cast<Element const*>(rhs_.offset)) {}
257 
259  CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
260  int const imm = ComputeOffsetFromStrides<IterationsStrides>::get(d, h, w, c);
261  return reinterpret_cast<AccessType const&>(pointer[imm]);
262  }
263 
265  CUTLASS_DEVICE AccessType const& operator[](int i) const {
266  return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
267  }
268 
/// Is the iterator valid? Always true; the coordinates are ignored.
270  CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
271 
/// Read-only pointer to the first element of the iterated range, set by the
/// constructors.
273  Element const* pointer;
274 };
275 
277 
278 } // namespace cutlass
CUTLASS_DEVICE void clear()
Clear a fragment.
Definition: fragment.h:113
Definition: convert.h:33
CUTLASS_DEVICE Element & operator[](int i)
The accessor.
Definition: fragment.h:138
CUTLASS_DEVICE AccessType & at(int d, int h, int w, int c=0)
The accessor.
Definition: fragment.h:203
Definition: vector.h:41
Definition: fragment.h:228
CUTLASS_DEVICE AccessType const & operator[](int i) const
The accessor.
Definition: fragment.h:265
Shape< Shape_::kH *Shape_::kW *Shape_::kC, Shape_::kW *Shape_::kC, Shape_::kC, 1 > Shape
Definition: shape.h:155
A template defining Fragment Concept.
Definition: fragment.h:99
Fragment::Element Element
The element.
Definition: fragment.h:181
static int const kElementsPerAccess
The number of elements per access.
Definition: fragment.h:241
Fragment_ Fragment
The fragment.
Definition: fragment.h:174
Fragment_ Fragment
The fragment.
Definition: fragment.h:232
CUTLASS_DEVICE AccessType & operator[](int i)
The accessor.
Definition: fragment.h:214
Fragment::Element Element
The element.
Definition: fragment.h:239
ShapeStrides< FragmentShape >::Shape IterationsStrides
The linear strides for iterations.
Definition: fragment.h:245
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: fragment.h:270
CUTLASS_DEVICE FragmentIterator(OtherFragment_ &fragment, int offset=0)
Ctor.
Definition: fragment.h:191
Fragment< Element_, kElements_ > This_
Make sure the alignment makes sense wrt the size of elements.
Definition: fragment.h:101
FragmentIterator< Fragment_, Iterations_, AccessType_ > This_
This class.
Definition: fragment.h:172
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the fragment.
Definition: fragment.h:243
Math utilities.
Definition: fragment.h:76
uint32_t Type
Definition: fragment.h:81
uint8_t Type
Definition: fragment.h:89
static CUTLASS_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:211
Element * pointer
The pointer.
Definition: fragment.h:222
AccessType_ AccessType
The access type.
Definition: fragment.h:236
Definition: shape.h:118
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the fragment.
Definition: fragment.h:185
A template defining Fragment Iterator Concept.
Definition: fragment.h:170
static int const kElements
The number of elements.
Definition: fragment.h:110
CUTLASS_DEVICE Element const & operator[](int i) const
The accessor.
Definition: fragment.h:144
Iterations_ Iterations
The number of iterations.
Definition: fragment.h:234
#define static_assert(__e, __m)
Definition: platform.h:145
Iterations_ Iterations
The number of iterations.
Definition: fragment.h:176
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
CUTLASS_DEVICE AccessType const & at(int d, int h, int w, int c=0) const
The accessor.
Definition: fragment.h:259
Element_ Element
The element.
Definition: fragment.h:108
FragmentIterator< Fragment_, Iterations_, AccessType_ > This_
This class.
Definition: fragment.h:230
CUTLASS_DEVICE AccessType const & operator[](int i) const
The accessor.
Definition: fragment.h:209
uint16_t Type
Definition: fragment.h:85
Defines a 1D vector of elements held in the registers of each thread.
CUTLASS_DEVICE FragmentConstIterator(FragmentIterator< Fragment_, Iterations_, AccessType_ > const &rhs_)
Create from non-constant FragmentIterator.
Definition: fragment.h:254
static int const kElementsPerAccess
The number of elements per access.
Definition: fragment.h:183
ShapeStrides< FragmentShape >::Shape Strides
The linear strides for iterations.
Definition: fragment.h:187
Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.
AccessType_ AccessType
The access type.
Definition: fragment.h:178
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: fragment.h:219
uint64_t Type
Definition: fragment.h:77
Definition: cutlass_math.h:45
CUTLASS_DEVICE FragmentConstIterator(OtherFragment_ &fragment, int offset=0)
Ctor.
Definition: fragment.h:249
CUTLASS_DEVICE AccessType const & at(int d, int h, int w, int c=0) const
The accessor.
Definition: fragment.h:197
Element const * pointer
The pointer.
Definition: fragment.h:273