30 #if !defined(__CUDACC_RTC__) 42 template <
typename Gemm_>
45 __shared__
typename Gemm_::SharedStorage shared_storage;
48 Gemm_ gemm(params, shared_storage);
55 template <
typename Scalar_,
typename Index_ =
int>
81 template <
typename GemmTraits_>
97 typedef typename Traits::Epilogue::ScalarC
ScalarC;
99 typedef typename Traits::Epilogue::ScalarD
ScalarD;
101 typedef typename Traits::Index
Index;
104 static int const kThreads = Traits::GemmConfig::kThreads;
127 desc.
d_a =
reinterpret_cast<void const*
>(d_a);
129 desc.
d_b =
reinterpret_cast<void const*
>(d_b);
131 desc.
d_c =
reinterpret_cast<void const*
>(d_c);
133 desc.
d_d =
reinterpret_cast<void*
>(d_d);
135 return Traits::Params::initialize(desc);
139 #if !defined(__CUDACC_RTC__) 140 static __host__ cudaError_t
launch(Params
const&
params,
142 cudaStream_t stream = cudaStreamDefault) {
145 grid.x = (
params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW;
146 grid.y = (
params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH;
153 void const* params_ =
reinterpret_cast<void const*
>(&
params);
155 return cudaLaunchKernel(reinterpret_cast<void*>(&gemm_kernel<This_>),
158 const_cast<void**>(¶ms_),
164 static __host__ cudaError_t
launch(CUfunction kernel,
166 CUstream stream = CU_STREAM_LEGACY) {
169 grid.x = (
params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW;
170 grid.y = (
params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH;
177 void* params_[] = {
const_cast<void*
>(
reinterpret_cast<void const*
>(&
params))};
181 CUresult result = cuLaunchKernel(
182 kernel, grid.x, grid.y, grid.z, block.x, block.y, block.z, 0, stream, params_, 0);
184 if (result != CUDA_SUCCESS) {
185 return cudaErrorLaunchFailure;
199 typename Traits::BlockSwizzle block_swizzle;
200 dim3 block = block_swizzle.swizzle();
203 block.x *= Traits::OutputTile::kW;
204 block.y *= Traits::OutputTile::kH;
216 typedef typename Traits::MultiplyAdd MultiplyAdd;
219 Index const kUnroll =
static_cast<Index>(MultiplyAdd::AccumulatorsPerWarp::kD);
223 global_stream.residue(
params.k,
true);
227 global_stream.copy();
230 global_stream.commit();
233 Traits::shared_store_fence(
false);
236 int const kUnrollingSteps =
237 MultiplyAdd::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;
240 static_assert(kUnrollingSteps >= 2,
"The pipelining assumes at least two steps");
246 shared_load_stream.copy(0);
249 typename MultiplyAdd::Accumulators accumulators;
251 clear.
clear(accumulators);
254 typedef typename Traits::Index
Index;
255 for (
Index outer_k =
params.k - kUnroll; outer_k > -kUnroll; outer_k -= kUnroll) {
257 int const is_residue = outer_k <= kUnroll;
259 global_stream.residue(outer_k);
263 global_stream.copy();
266 for (
int step = 0; step < kUnrollingSteps - 1; ++step) {
268 shared_load_stream.copy(step + 1);
270 shared_load_stream.commit(step);
274 multiply_add.multiply_add(shared_load_stream.fragment_a(step),
275 shared_load_stream.fragment_b(step),
281 Traits::shared_load_fence(
true);
284 global_stream.commit();
287 Traits::shared_store_fence(
true);
290 shared_load_stream.inc_stage();
292 shared_load_stream.copy(0);
294 shared_load_stream.commit(kUnrollingSteps - 1);
298 multiply_add.multiply_add(shared_load_stream.fragment_a(kUnrollingSteps - 1),
299 shared_load_stream.fragment_b(kUnrollingSteps - 1),
305 typedef typename Traits::Epilogue Epilogue;
SharedStorage & shared_storage
The shared storage.
Definition: gemm.h:313
Traits::Epilogue::ScalarD ScalarD
The scalar for D.
Definition: gemm.h:99
Scalar_ beta
Definition: gemm.h:60
Index_ k
Definition: gemm.h:58
Traits::SharedStorage SharedStorage
The shared storage.
Definition: gemm.h:88
The params.
Definition: gemm.h:107
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:241
Params const & params
The params.
Definition: gemm.h:311
Index_ m
The dimensions of the GEMM.
Definition: gemm.h:58
Traits::Epilogue::ScalarC ScalarC
The scalar for C.
Definition: gemm.h:97
Index_ ldb
The stride for B.
Definition: gemm.h:68
CUTLASS_DEVICE void multiply_add()
Do the GEMM.
Definition: gemm.h:197
GemmTraits_ Traits
The traits.
Definition: gemm.h:86
Traits::Epilogue::Scalar ScalarEpilogue
The scalar in the epilogue.
Definition: gemm.h:95
Index_ n
Definition: gemm.h:58
Traits::ScalarB ScalarB
The scalar for B.
Definition: gemm.h:93
Definition: clear_accumulators.h:38
void * d_d
The destination matrix D.
Definition: gemm.h:74
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:60
static __host__ cudaError_t launch(CUfunction kernel, Params const &params, CUstream stream=CU_STREAM_LEGACY)
Launch the kernel.
Definition: gemm.h:164
void const * d_a
The source matrix A.
Definition: gemm.h:62
__global__ void gemm_kernel(typename Gemm_::Params params)
Definition: gemm.h:43
CUTLASS_HOST_DEVICE int initialize(Index m, Index n, Index k, ScalarEpilogue alpha, ScalarA const *d_a, Index lda, ScalarB const *d_b, Index ldb, ScalarEpilogue beta, ScalarC const *d_c, Index ldc, ScalarD *d_d, Index ldd)
Definition: gemm.h:108
Index_ lda
The stride for A.
Definition: gemm.h:64
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Gemm< GemmTraits_ > This_
This class.
Definition: gemm.h:84
Index_ ldc
The stride for C.
Definition: gemm.h:72
CUTLASS_DEVICE Gemm(Params const &params_, SharedStorage &shared_storage_)
Ctor.
Definition: gemm.h:193
Index_ ldd
The stride for D.
Definition: gemm.h:76
Traits::ScalarA ScalarA
The scalar for A.
Definition: gemm.h:91
CUTLASS_DEVICE void clear(Fragment_ &fragment)
Clear the fragment.
Definition: clear_accumulators.h:47
static int const kThreads
The number of threads.
Definition: gemm.h:104
Scalar_ alpha
The alpha/beta scaling values.
Definition: gemm.h:60
void const * d_c
The source matrix C.
Definition: gemm.h:70
static __host__ cudaError_t launch(Params const &params, cudaStream_t stream=cudaStreamDefault)
Launch the kernel.
Definition: gemm.h:141
Traits::Index Index
The index.
Definition: gemm.h:101
void const * d_b
The source matrix B.
Definition: gemm.h:66