63 template <
int kD_ = 1,
int kH_ = 1,
int kW_ = 1,
int kC_ = 1>
66 static int const kD = kD_;
68 static int const kH = kH_;
70 static int const kW = kW_;
72 static int const kC = kC_;
78 template <
typename Shape>
96 template <
typename A_,
int kScale_>
103 template <
typename A_,
typename B_>
110 template <
typename A_,
typename B_>
112 typedef Shape<A_::kD - B_::kD, A_::kH - B_::kH, A_::kW - B_::kW, A_::kC - B_::kC>
Shape;
117 template <
typename A_,
typename B_>
124 template <
typename A_,
typename B_>
126 typedef Shape<A_::kD / B_::kD, A_::kH / B_::kH, A_::kW / B_::kW, A_::kC / B_::kC>
Shape;
131 template <
typename A_,
typename B_>
134 (A_::kH > B_::kH ? A_::kH : B_::kH),
135 (A_::kW > B_::kW ? A_::kW : B_::kW),
136 (A_::kC > B_::kC ? A_::kC : B_::kC)>
142 template <
typename A_,
typename B_>
144 typedef Shape<(A_::kD < B_::kD ? A_::kD : B_::kD),
145 (A_::kH < B_::kH ? A_::kH : B_::kH),
146 (A_::kW < B_::kW ? A_::kW : B_::kW),
147 (A_::kC < B_::kC ? A_::kC : B_::kC)>
153 template <
typename Shape_>
164 template <
typename Shape_>
166 static CUTLASS_DEVICE
int get(
int d,
int h,
int w,
int c) {
168 return d * Shape_::kH * Shape_::kW * Shape_::kC +
169 h * Shape_::kW * Shape_::kC +
184 template <
int kSh_,
int kSw_,
int kSc_>
186 static CUTLASS_DEVICE
int get(
int d,
int h,
int w,
int c) {
187 return h * kSw_ * kSc_ + w * kSc_ + c;
198 template <
int kSh_,
int kSw_>
200 static CUTLASS_DEVICE
int get(
int d,
int h,
int w,
int c) {
return h * kSw_ + w; }
209 template <
typename Str
ides_>
211 static CUTLASS_DEVICE
int get(
int d,
int h,
int w,
int c) {
212 return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
224 template <
int S_h_,
int S_w_,
int S_c_>
226 static CUTLASS_DEVICE
int get(
int d,
int h,
int w,
int c) {
227 return h * S_h_ + w * S_w_ + c * S_c_;
238 template <
int S_h_,
int S_w_>
240 static CUTLASS_DEVICE
int get(
int d,
int h,
int w,
int c) {
return h * S_h_ + w * S_w_; }
251 template <
typename Threads_,
typename Str
ides_>
253 static CUTLASS_DEVICE
int get() {
255 int c = threadIdx.x % Threads_::kC;
256 int w = threadIdx.x / Threads_::kC % Threads_::kW;
257 int h = threadIdx.x / Threads_::kC / Threads_::kW % Threads_::kH;
258 int d = threadIdx.x / Threads_::kC / Threads_::kW / Threads_::kH;
261 return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
269 template <
int T_h_,
int T_w_,
int T_c_,
int S_h_,
int S_w_,
int S_c_>
271 static CUTLASS_DEVICE
int get() {
273 int c = threadIdx.x % T_c_;
274 int w = threadIdx.x / T_c_ % T_w_;
275 int h = threadIdx.x / T_c_ / T_w_ % T_h_;
278 return h * S_h_ + w * S_w_ + c * S_c_;
287 template <
int T_h_,
int T_w_,
int S_h_,
int S_w_>
289 static CUTLASS_DEVICE
int get() {
291 int w = threadIdx.x % T_w_;
292 int h = threadIdx.x / T_w_;
295 return h * S_h_ + w * S_w_;
Decompose threadId.x into coordinate of a cube whose dimensions are specified by Threads_. Afterwards compute the offset of those coordinates using Strides_.
Definition: shape.h:252
static int const kWc
The number of elements per row.
Definition: shape.h:81
Shape< A_::kD+B_::kD, A_::kH+B_::kH, A_::kW+B_::kW, A_::kC+B_::kC > Shape
Definition: shape.h:105
Shape< A_::kD *kScale_, A_::kH *kScale_, A_::kW *kScale_, A_::kC *kScale_ > Shape
Definition: shape.h:98
Shape< Shape_::kH *Shape_::kW *Shape_::kC, Shape_::kW *Shape_::kC, Shape_::kC, 1 > Shape
Definition: shape.h:155
Shape< A_::kD *B_::kD, A_::kH *B_::kH, A_::kW *B_::kW, A_::kC *B_::kC > Shape
Definition: shape.h:119
Shape< A_::kD - B_::kD, A_::kH - B_::kH, A_::kW - B_::kW, A_::kC - B_::kC > Shape
Definition: shape.h:112
static int const kH
The height of the cube.
Definition: shape.h:68
static int const kC
The number of scalars per element.
Definition: shape.h:72
Compute the offset for the given coordinates in a cube.
Definition: shape.h:165
Shape< A_::kD/B_::kD, A_::kH/B_::kH, A_::kW/B_::kW, A_::kC/B_::kC > Shape
Definition: shape.h:126
static int const kDhw
The number of pixels per cube.
Definition: shape.h:87
Compute the offset for the given coordinates in a cube.
Definition: shape.h:210
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
static int const kCount
The number of elements in the 4D space.
Definition: shape.h:91
static int const kDhwc
The number of elements in the 4D space.
Definition: shape.h:89
static int const kW
The width of the cube.
Definition: shape.h:70
static int const kHw
The number of pixels per image.
Definition: shape.h:83
static int const kD
The depth of the cube.
Definition: shape.h:66
Shape<(A_::kD > B_::kD ? A_::kD :B_::kD),(A_::kH > B_::kH ? A_::kH :B_::kH),(A_::kW > B_::kW ? A_::kW :B_::kW),(A_::kC > B_::kC ? A_::kC :B_::kC)> Shape
Definition: shape.h:137
Basic include for CUTLASS macros.
Shape<(A_::kD< B_::kD ? A_::kD :B_::kD),(A_::kH< B_::kH ? A_::kH :B_::kH),(A_::kW< B_::kW ? A_::kW :B_::kW),(A_::kC< B_::kC ? A_::kC :B_::kC)> Shape
Definition: shape.h:148
Compute derived counted of a Layout Concept based class.
Definition: shape.h:79
static int const kHwc
The number of elements per image.
Definition: shape.h:85