Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
shape.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include <cutlass/cutlass.h>
31 
32 namespace cutlass {
33 
35 
/// A Shape implementing the Layout Concept: the compile-time dimensions of a 4D cube.
///
/// The four extents are depth (kD), height (kH), width (kW) and scalars per element (kC).
/// Every extent defaults to 1, so lower-rank shapes can be written with fewer arguments.
template <int kD_ = 1, int kH_ = 1, int kW_ = 1, int kC_ = 1>
struct Shape {
  /// The depth of the cube.
  static const int kD = kD_;

  /// The height of the cube.
  static const int kH = kH_;

  /// The width of the cube.
  static const int kW = kW_;

  /// The number of scalars per element.
  static const int kC = kC_;
};
74 
/// Compute derived counts of a Layout Concept based class.
///
/// Shape must expose the static integer extents kD, kH, kW and kC. Every count below
/// is a compile-time product of those extents.
template <typename Shape>
struct ShapeCount {
  /// The number of elements per row (W * C).
  static int const kWc = Shape::kW * Shape::kC;

  /// The number of pixels per image (H * W).
  static int const kHw = Shape::kH * Shape::kW;

  /// The number of elements per image (H * W * C), built on top of kWc.
  static int const kHwc = Shape::kH * kWc;

  /// The number of pixels per cube (D * H * W), built on top of kHw.
  static int const kDhw = Shape::kD * kHw;

  /// The number of elements in the 4D space (D * H * W * C), built on top of kHwc.
  static int const kDhwc = Shape::kD * kHwc;

  /// The number of elements in the 4D space; an alias of kDhwc.
  static int const kCount = kDhwc;
};
93 
95 
/// Scales every extent of the shape A_ by the integer factor kScale_.
template <typename A_, int kScale_>
struct ShapeScale {
  /// The scaled shape: each of kD, kH, kW and kC of A_ multiplied by kScale_.
  typedef Shape<A_::kD * kScale_, A_::kH * kScale_, A_::kW * kScale_, A_::kC * kScale_> Shape;
};
100 
102 
/// Element-wise sum of two shapes.
template <typename A_, typename B_>
struct ShapeAdd {
  /// The resulting shape: each extent of A_ plus the matching extent of B_.
  typedef Shape<A_::kD + B_::kD, A_::kH + B_::kH, A_::kW + B_::kW, A_::kC + B_::kC> Shape;
};
107 
109 
/// Element-wise difference of two shapes.
template <typename A_, typename B_>
struct ShapeSub {
  /// The resulting shape: each extent of B_ subtracted from the matching extent of A_.
  typedef Shape<A_::kD - B_::kD,
                A_::kH - B_::kH,
                A_::kW - B_::kW,
                A_::kC - B_::kC>
      Shape;
};
114 
116 
/// Element-wise product of two shapes.
template <typename A_, typename B_>
struct ShapeMul {
  /// The resulting shape: each extent of A_ multiplied by the matching extent of B_.
  typedef Shape<A_::kD * B_::kD, A_::kH * B_::kH, A_::kW * B_::kW, A_::kC * B_::kC> Shape;
};
121 
123 
/// Element-wise quotient of two shapes.
///
/// Uses integer division, so each resulting extent is truncated toward zero.
template <typename A_, typename B_>
struct ShapeDiv {
  /// The resulting shape: each extent of A_ divided by the matching extent of B_.
  typedef Shape<A_::kD / B_::kD,
                A_::kH / B_::kH,
                A_::kW / B_::kW,
                A_::kC / B_::kC>
      Shape;
};
128 
130 
/// Element-wise maximum of two shapes.
template <typename A_, typename B_>
struct ShapeMax {
  /// The resulting shape: for each extent, the larger of the values in A_ and B_.
  typedef Shape<(A_::kD > B_::kD ? A_::kD : B_::kD),
                (A_::kH > B_::kH ? A_::kH : B_::kH),
                (A_::kW > B_::kW ? A_::kW : B_::kW),
                (A_::kC > B_::kC ? A_::kC : B_::kC)>
      Shape;
};
139 
141 
/// Element-wise minimum of two shapes.
template <typename A_, typename B_>
struct ShapeMin {
  /// The resulting shape: for each extent, the smaller of the values in A_ and B_.
  typedef Shape<(A_::kD < B_::kD ? A_::kD : B_::kD),
                (A_::kH < B_::kH ? A_::kH : B_::kH),
                (A_::kW < B_::kW ? A_::kW : B_::kW),
                (A_::kC < B_::kC ? A_::kC : B_::kC)>
      Shape;
};
150 
152 
/// Computes the strides of a densely packed cube with dimensions Shape_.
///
/// The C coordinate varies fastest (stride 1), followed by W (stride kC),
/// H (stride kW * kC) and D (stride kH * kW * kC).
template <typename Shape_>
struct ShapeStrides {
  /// The strides, expressed as a Shape whose extents are the per-dimension strides.
  typedef Shape<Shape_::kH * Shape_::kW * Shape_::kC, Shape_::kW * Shape_::kC, Shape_::kC, 1> Shape;
};
157 
159 
/// Compute the offset for the given coordinates in a cube.
///
/// The cube is treated as densely packed with dimensions Shape_, with c varying fastest,
/// then w, h and d. Shape_::kD is not used by the formula; the caller is responsible for
/// keeping the coordinates in range.
template <typename Shape_>
struct ComputeOffsetFromShape {
  static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
    // clang-format off
    return d * Shape_::kH * Shape_::kW * Shape_::kC +
           h * Shape_::kW * Shape_::kC +
           w * Shape_::kC +
           c;
    // clang-format on
  }
};
175 
177 
/// Compute the offset for the given coordinates in a cube of depth 1.
///
/// The d coordinate does not contribute to the result, which matches the primary template
/// whenever d == 0 (the only in-range depth index for a depth-1 cube).
template <int kSh_, int kSw_, int kSc_>
struct ComputeOffsetFromShape<Shape<1, kSh_, kSw_, kSc_> > {
  static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
    int const row_offset = h * kSw_ * kSc_;
    int const column_offset = w * kSc_;
    return row_offset + column_offset + c;
  }
};
190 
192 
/// Compute the offset for the given coordinates in a cube of depth 1 with one scalar per
/// element.
///
/// Only h and w contribute; d and c are ignored. NOTE(review): this agrees with the
/// primary template only when d == 0 and c == 0; presumably callers guarantee that.
template <int kSh_, int kSw_>
struct ComputeOffsetFromShape<Shape<1, kSh_, kSw_, 1> > {
  static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
    int const row_offset = h * kSw_;
    return row_offset + w;
  }
};
202 
204 
/// Compute the offset for the given coordinates in a cube, using explicit strides.
///
/// Strides_ is a Shape whose kD, kH, kW and kC members hold the stride of each dimension.
/// The offset is the dot product of (d, h, w, c) with those strides.
template <typename Strides_>
struct ComputeOffsetFromStrides {
  static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
    return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
  }
};
215 
217 
/// Compute the offset for the given coordinates when the depth stride is 1.
///
/// The d coordinate is ignored entirely. NOTE(review): the primary template would add
/// d * 1 here, so the two agree only when d == 0; presumably this specialization is
/// meant for depth-1 tiles — confirm against callers.
template <int S_h_, int S_w_, int S_c_>
struct ComputeOffsetFromStrides<Shape<1, S_h_, S_w_, S_c_> > {
  static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
    int const offset_h = h * S_h_;
    int const offset_w = w * S_w_;
    int const offset_c = c * S_c_;
    return offset_h + offset_w + offset_c;
  }
};
230 
232 
/// Compute the offset for the given coordinates when the depth and channel strides are 1.
///
/// Only h and w contribute; d and c are ignored. NOTE(review): the primary template would
/// add d + c here, so the two agree only when d == 0 and c == 0 — confirm against callers.
template <int S_h_, int S_w_>
struct ComputeOffsetFromStrides<Shape<1, S_h_, S_w_, 1> > {
  static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
    int const offset_h = h * S_h_;
    int const offset_w = w * S_w_;
    return offset_h + offset_w;
  }
};
242 
244 
/// Decompose threadIdx.x into the coordinates of a cube whose dimensions are specified by
/// Threads_, then compute the offset of those coordinates using Strides_.
///
/// Only threadIdx.x is read. c varies fastest, then w, then h. d is not reduced modulo
/// Threads_::kD, so any threads beyond kH * kW * kC land in higher d slices.
template <typename Threads_, typename Strides_>
struct ComputeThreadOffsetFromStrides {
  static CUTLASS_DEVICE int get() {
    // Decompose the thread index.
    int c = threadIdx.x % Threads_::kC;
    int w = threadIdx.x / Threads_::kC % Threads_::kW;
    int h = threadIdx.x / Threads_::kC / Threads_::kW % Threads_::kH;
    int d = threadIdx.x / Threads_::kC / Threads_::kW / Threads_::kH;

    // Compute the offset.
    return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
  }
};
264 
266 
/// Thread-offset computation for a depth-1 thread cube with depth-1 strides.
///
/// threadIdx.x is split into (h, w, c) with c fastest, and h wrapped modulo T_h_. No d
/// coordinate is formed, so threads beyond T_h_ * T_w_ * T_c_ alias lower threads.
template <int T_h_, int T_w_, int T_c_, int S_h_, int S_w_, int S_c_>
struct ComputeThreadOffsetFromStrides<Shape<1, T_h_, T_w_, T_c_>, Shape<1, S_h_, S_w_, S_c_> > {
  static CUTLASS_DEVICE int get() {
    // Peel coordinates off the (unsigned) thread index, innermost dimension first.
    unsigned const tid = threadIdx.x;
    unsigned const tid_over_c = tid / T_c_;

    int const c = tid % T_c_;
    int const w = tid_over_c % T_w_;
    int const h = tid_over_c / T_w_ % T_h_;

    // Weight each coordinate by its stride.
    return h * S_h_ + w * S_w_ + c * S_c_;
  }
};
281 
283 
/// Thread-offset computation for a 2D (h, w) thread layout: depth and channel extents of
/// both the thread cube and the strides are 1.
///
/// w is threadIdx.x modulo T_w_ and h is the quotient. NOTE(review): unlike the 3-extent
/// specialization, h is not reduced modulo T_h_ (T_h_ is unused), so threads beyond
/// T_h_ * T_w_ produce offsets past the last row — presumably blockDim.x is expected to
/// equal T_h_ * T_w_; confirm.
template <int T_h_, int T_w_, int S_h_, int S_w_>
struct ComputeThreadOffsetFromStrides<Shape<1, T_h_, T_w_, 1>, Shape<1, S_h_, S_w_, 1> > {
  static CUTLASS_DEVICE int get() {
    unsigned const tid = threadIdx.x;

    // Split the thread index into row (quotient) and column (remainder).
    int const column = tid % T_w_;
    int const row = tid / T_w_;

    return row * S_h_ + column * S_w_;
  }
};
298 
300 
301 } // namespace cutlass
Decompose threadIdx.x into the coordinates of a cube whose dimensions are specified by Threads_, then compute the offset of those coordinates using Strides_.
Definition: shape.h:252
static int const kWc
The number of elements per row.
Definition: shape.h:81
Definition: convert.h:33
Shape< A_::kD+B_::kD, A_::kH+B_::kH, A_::kW+B_::kW, A_::kC+B_::kC > Shape
Definition: shape.h:105
Shape< A_::kD *kScale_, A_::kH *kScale_, A_::kW *kScale_, A_::kC *kScale_ > Shape
Definition: shape.h:98
Shape< Shape_::kH *Shape_::kW *Shape_::kC, Shape_::kW *Shape_::kC, Shape_::kC, 1 > Shape
Definition: shape.h:155
Shape< A_::kD *B_::kD, A_::kH *B_::kH, A_::kW *B_::kW, A_::kC *B_::kC > Shape
Definition: shape.h:119
Shape< A_::kD - B_::kD, A_::kH - B_::kH, A_::kW - B_::kW, A_::kC - B_::kC > Shape
Definition: shape.h:112
Definition: shape.h:111
static int const kH
The height of the cube.
Definition: shape.h:68
static int const kC
The number of scalars per element.
Definition: shape.h:72
Definition: shape.h:97
Compute the offset for the given coordinates in a cube.
Definition: shape.h:165
Shape< A_::kD/B_::kD, A_::kH/B_::kH, A_::kW/B_::kW, A_::kC/B_::kC > Shape
Definition: shape.h:126
static int const kDhw
The number of pixels per cube.
Definition: shape.h:87
Definition: shape.h:118
Definition: shape.h:125
Compute the offset for the given coordinates in a cube.
Definition: shape.h:210
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Definition: shape.h:132
Definition: shape.h:104
static int const kCount
The number of elements in the 4D space.
Definition: shape.h:91
static int const kDhwc
The number of elements in the 4D space.
Definition: shape.h:89
static int const kW
The width of the cube.
Definition: shape.h:70
Definition: shape.h:143
static int const kHw
The number of pixels per image.
Definition: shape.h:83
static int const kD
The depth of the cube.
Definition: shape.h:66
Definition: shape.h:154
Shape<(A_::kD > B_::kD ? A_::kD : B_::kD), (A_::kH > B_::kH ? A_::kH : B_::kH), (A_::kW > B_::kW ? A_::kW : B_::kW), (A_::kC > B_::kC ? A_::kC : B_::kC)> Shape
Definition: shape.h:137
Basic include for CUTLASS macros.
Shape<(A_::kD < B_::kD ? A_::kD : B_::kD), (A_::kH < B_::kH ? A_::kH : B_::kH), (A_::kW < B_::kW ? A_::kW : B_::kW), (A_::kC < B_::kC ? A_::kC : B_::kC)> Shape
Definition: shape.h:148
Compute derived counts of a Layout Concept based class.
Definition: shape.h:79
static int const kHwc
The number of elements per image.
Definition: shape.h:85