46 #if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16) 47 CUTLASS_DEVICE
bool is_zero(half x) {
return reinterpret_cast<int16_t&
>(x) == int16_t(0); }
52 template <
typename GemmEpilogueTraits_>
57 typedef typename Traits::Params
Params;
68 typedef typename Traits::Scalar
Scalar;
73 static_assert(Iterations::kD == 1 && Iterations::kC == 1,
"Unsupported 3D/4D shapes");
93 typedef typename Traits::Index
Index;
96 typedef typename GlobalLoadIteratorC::Scalar
ScalarC;
98 typedef typename GlobalStoreIteratorD::Scalar
ScalarD;
110 epilogue_with_or_without_beta<true>(block, accumulators);
112 epilogue_with_or_without_beta<false>(block, accumulators);
116 template <
bool kBetaIsZero_>
125 typename GlobalLoadIteratorC::Fragment fragment_c;
127 typename GlobalTransformerC::OutputFragment transformed_c;
130 for (
int h = 0; h < Iterations::kH; ++h) {
132 int const pointer_offset =
133 ((
params.iterator_d.inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
134 params.iterator_d.inc_advance) *
138 int const predicate_offset =
139 ((
params.iterator_d.predicate_inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
140 params.iterator_d.predicate_inc_advance) *
147 params.iterator_c, bounds, block, pointer_offset, predicate_offset);
154 params.iterator_d, bounds, block, pointer_offset, predicate_offset);
157 for (
int w = 0; w < Iterations::kW; ++w) {
167 int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements;
170 typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
171 shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d);
183 typename SharedLoadIteratorD::Fragment fetched_d;
187 typename GlobalTransformerD::InputFragment fragment_d;
190 functor.evaluate(fetched_d, fragment_d);
193 transformer_c.transform(fragment_c, transformed_c);
195 functor.evaluate(fetched_d, transformed_c, fragment_d);
199 typename GlobalTransformerD::OutputFragment transformed_d;
200 transformer_d.transform(fragment_d, transformed_d);
GlobalStoreIteratorD::Scalar ScalarD
The scalar for D.
Definition: gemm_epilogue.h:98
Traits::SharedStoreIteratorD SharedStoreIteratorD
The iterator to store D in shared memory.
Definition: gemm_epilogue.h:84
CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment)
Loads a fragment from a shared memory input iterator.
Definition: iterator_access.h:75
Traits::Params Params
The params.
Definition: gemm_epilogue.h:57
Definition: gemm_epilogue.h:53
CUTLASS_DEVICE void epilogue_with_or_without_beta(Coord< 3 > const &block, Accumulators &accumulators)
Definition: gemm_epilogue.h:117
CUTLASS_DEVICE GemmEpilogue(Params const &params_, SharedStorage &shared_storage_, Index m_, Index n_)
Ctor.
Definition: gemm_epilogue.h:101
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:241
Traits::SharedStorage SharedStorage
The shared storage.
Definition: gemm_epilogue.h:59
Traits::GlobalTransformerD GlobalTransformerD
The transformer for D.
Definition: gemm_epilogue.h:80
Traits::OutputTile OutputTile
The output tile.
Definition: gemm_epilogue.h:62
Traits::Accumulators Accumulators
The accumulators.
Definition: gemm_epilogue.h:66
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:60
CUTLASS_DEVICE void shared_load_fence()
The memory fence for shared loads.
Definition: gemm_epilogue.h:209
SharedStorage & shared_storage
The shared storage.
Definition: gemm_epilogue.h:217
GemmEpilogueTraits_ Traits
The traits class.
Definition: gemm_epilogue.h:55
CUTLASS_DEVICE bool is_zero(T x)
Definition: gemm_epilogue.h:42
Params const & params
The params.
Definition: gemm_epilogue.h:215
Traits::SharedLoadIteratorD SharedLoadIteratorD
The iterator to load D in shared memory.
Definition: gemm_epilogue.h:88
Traits::Index Index
The index.
Definition: gemm_epilogue.h:93
Traits::SharedStoreTransformerD SharedStoreTransformerD
The shared store transformer for D.
Definition: gemm_epilogue.h:86
CUTLASS_DEVICE void shared_iterator_store(OutputIterator &iterator, Fragment const &fragment)
Stores a fragment to a shared memory output iterator.
Definition: iterator_access.h:228
Traits::GlobalStoreIteratorD GlobalStoreIteratorD
The iterator for D in global memory.
Definition: gemm_epilogue.h:82
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:48
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment)
Stores a fragment to an output iterator.
Definition: iterator_access.h:193
GlobalLoadIteratorC::Scalar ScalarC
The scalar for C.
Definition: gemm_epilogue.h:96
Index n
Definition: gemm_epilogue.h:219
CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment)
Loads a fragment from an input iterator.
Definition: iterator_access.h:41
Traits::Functor Functor
The functor in charge of the math.
Definition: gemm_epilogue.h:70
Traits::Iterations Iterations
The number of iterations.
Definition: gemm_epilogue.h:64
CUTLASS_DEVICE void epilogue(Coord< 3 > const &block, Accumulators &accumulators)
Execute the epilogue.
Definition: gemm_epilogue.h:108
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
Copy< typename SharedLoadIteratorD::Fragment > SharedLoadTransformerD
The shared load transformer for D.
Definition: gemm_epilogue.h:90
Traits::Scalar Scalar
The scalar.
Definition: gemm_epilogue.h:68
Defines conversion operations among Fragments of different base type.
Index m
The dimensions of the GEMM.
Definition: gemm_epilogue.h:219
CUTLASS_DEVICE void shared_store_fence()
The memory fence for shared stores.
Definition: gemm_epilogue.h:212
Traits::GlobalTransformerC GlobalTransformerC
The transformer for C.
Definition: gemm_epilogue.h:78
Traits::GlobalLoadIteratorC GlobalLoadIteratorC
We do not support 3D or 4D shapes.
Definition: gemm_epilogue.h:73