/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Example of a GETT targeting Hopper tensor cores using the CUTLASS 3.x API.

    CUTLASS has long provided implementations of Generalized Matrix times Matrix (GEMM) kernels.
    However, a plethora of workloads compute on higher ranked tensors. Products of such tensors,
    called tensor contractions, can be executed as multiple batched GEMMs; however, they can be
    further accelerated with kernels that natively operate on these higher ranked tensors to
    perform Generalized Tensor times Tensor contractions (GETT). CuTe's hierarchical layouts
    and CUTLASS 3.0's unified micro-kernels make implementation of GETTs trivial. In this example,
    we show how CUTLASS 3.0, CuTe, and Hopper's TMA feature together can accelerate GETTs while
    making the process of authoring custom GETT kernels easier than ever before.

    The modes of a tensor that participate in a GETT can be fundamentally grouped into four
    semantic categories. The contraction modes (or K-modes) only appear in the A and B (left and right)
    inputs but not in the C output tensor. Row modes (or M-modes) only appear in the left
    input tensor (A) and the output tensor (C). Column modes (or N-modes) only appear in the
    right (B) input tensor and the output tensor (C). Batch modes (or L-modes) appear in all
    input and output tensors. Folding the many modes of a tensor contraction into these four
    categories allows us to represent the input and output tensors as rank-3 "matrices"
    that can be computed upon as if we were computing a batched GEMM!
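
    For instance, a contraction

        C[m0,m1, n0,n1, l] = sum over (k0,k1) of A[m0,m1, k0,k1, l] * B[k0,k1, n0,n1, l]

    has M-modes (m0,m1), N-modes (n0,n1), K-modes (k0,k1), and L-mode (l); folding each group
    into a single mode recovers the familiar batched GEMM C[M,N,L] = sum over K of A[M,K,L] * B[K,N,L].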

    This is exactly what CuTe's hierarchical layout representation allows us to do! Instead of having
    simple integers as strides for these four modes, we can have nested strides for each of these
    semantic categories that themselves have multiple modes within them -- multi-mode strides!
    In CUTLASS 3.0, all one has to do to take advantage of this capability is to substitute the
    required multi-mode strides instead of the default ones provided by gemm::detail::TagToStrideX.
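
    As a sketch (the concrete stride types this example actually uses are defined in main() below),
    an ordinary M-major batched GEMM would describe A with a flat stride such as

        cute::Stride<cute::Int<1>, int64_t, int64_t>                  // (M, K, L)

    whereas a GETT simply nests a multi-mode stride into each of those slots, e.g.

        cute::Stride<cute::Stride<cute::Int<1>, int64_t>,             // M = (M0, M1)
                     cute::Stride<int64_t, int64_t>,                  // K = (K0, K1)
                     int64_t>                                         // L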

    In the following example, we illustrate how every Hopper GEMM in CUTLASS 3.0 is a GETT in disguise.
    We begin by defining the four modes detailed above as Row, Col (column), Red (reduction), and
    Bat (batch) strides, which we then nest for each of the in/out tensors to create our rank-3 stride
    tuples. Note that although we do not define the problem shape type explicitly, it too remains a
    rank-4 shape tuple just like any other batched GEMM, but instead with multi-mode shapes for each
    of the four corresponding multi-modes within it. After this, the same CollectiveMma and
    CollectiveBuilder we describe in examples 50 and 49 are used to create our kernel type. Nothing
    else changes from a user's point of view. Note that multi-mode strides do not affect our
    specializations in any way -- the lexical spelling of our kernels remains the same. The
    only difference between a CUTLASS 3 batched GEMM and a GETT is the instanced CuTe Layouts.

    CollectiveBuilders rely on detecting the static-1 in the stride tuples to determine the major mode,
    which is what the example demonstrates. However, it is possible to have all modes be dynamic as well
    if the user assembles a CollectiveMma manually and ensures that the runtime strides are compatible
    with the static micro-kernel of the collective (TiledMma, TiledCopy, and smem layouts). On the other
    hand, a user can have more than one static stride too (which need not correspond to the major mode).

    In particular, this example demonstrates a GETT where the 0th M-mode (M0) in A and the 0th K-mode (K0)
    in B are major. All other combinations of major modes are supported, with the exception of mixed
    K-major scenarios in which A and B are major along different K-modes (e.g. K0 is major in A but
    K1 is major in B).
    NVIDIA Hopper architecture's TMA feature makes the predication required to implement these complicated
    kernels trivial, as it is all handled by TMA itself without requiring any programmer effort.

    Example executions, where the stride order defines the major-order (major on the left):
        51_hopper_gett --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096
        51_hopper_gett --modeC=l,m,n --modeA=m,l,k --modeB=k,n,l --extents=m:128,n:128,k:128,l:64
        51_hopper_gett --modeC=m,a,b,p,q,n,l --modeA=m,l,b,k,a --modeB=k,n,p,q,l --extents=m:32,a:32,b:3,n:128,k:128,l:4,p:3,q:3
*/

#include <cstdlib>
#include <iostream>
#include <numeric>

#include "gett_kernel.cuh"

#include "thrust/host_vector.h"
#include "thrust/device_vector.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"

#include "cutlass/util/gett_commandline.hpp"
#include "cutlass/util/reference/device/gett.hpp"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/print_error.hpp"

namespace example {

// Returns true if the left-most value in the tuple is statically known to be 1
template<class Stride>
constexpr bool
is_left_major() {
  // Account for stride types with and without batch mode and batch modes with static zero stride
  return cute::is_constant<1, decltype(cute::size<0,0>(Stride{}))>::value;
}
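
// A minimal usage sketch (illustrative only): the M-mode stride of A defined in
// main() below carries a static Int<1> in its leftmost slot, so the check holds
// for it, while a fully dynamic stride does not.
static_assert( is_left_major<cute::Stride<cute::Int<1>, int64_t, int64_t>>());
static_assert(!is_left_major<cute::Stride<int64_t, int64_t, int64_t>>());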

// Same as cute::make_int_tuple but inserts a major stride (Int<1>) for the leftmost mode if required
template <int Rank, bool IsMajor, class Indexable>
static constexpr
auto
make_stride_tuple(Indexable const& t, int n, int64_t init_default = 0) {
  static_assert(Rank > 1);
  if constexpr (IsMajor) {
    return cute::transform(cute::make_seq<Rank>{}, [&](auto i) {
      if constexpr (i == 0) {
        return cute::Int<1>{};
      }
      else {
        return i < n ? t[i] : init_default;
      }
    });
  }
  else {
    return cute::make_int_tuple<Rank>(t, n, init_default);
  }
}
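
// A minimal usage sketch (values are illustrative only):
//
//   int64_t ld[3] = {1, 32, 4096};
//   auto s = make_stride_tuple<4, /*IsMajor=*/true>(ld, 3);
//   // s == (_1, 32, 4096, 0): the leftmost stride becomes a static Int<1>, the
//   // remaining runtime strides are copied through, and modes beyond n are filled
//   // with init_default (a stride of 0 broadcasts along that mode).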

} // namespace example

//////////////////////////////////////////////////////////////////////////////

int
main(int argc, char const* argv[]) {
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
  using namespace cute;

  if (argc != 5) {
    std::cout << "Number of command line args must be 4.\n";
    cutlass::GettCommandLine::print_usage();
    return 0;
  }

  //
  // Define the stride types for A, B, C, and D
  //

  // Stride for A (left input). If the reduction mode is major in A, it must also be major in B.
  // For this example, M0 is major in A.
  using RowModeStridesA = cute::Stride<cute::Int<1>, int64_t, int64_t, int64_t>;
  using RedModeStridesA = cute::Stride<int64_t, int64_t, int64_t>;
  using BatModeStridesA = cute::Stride<int64_t, int64_t, int64_t, int64_t>;

  // Stride for B (right input). If the reduction mode is major in B, it must also be major in A.
  // For this example, K0 is major in B.
  using ColModeStridesB = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
  using RedModeStridesB = cute::Stride<cute::Int<1>, int64_t, int64_t>;
  using BatModeStridesB = cute::Stride<int64_t, int64_t, int64_t, int64_t>;

  // Strides for the output, which can all be dynamic.
  using RowModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
  using ColModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
  using BatModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;

  // Assemble our rank-3 multi-mode strides for the in/out tensors
  using StrideA = cute::Stride<RowModeStridesA, RedModeStridesA, BatModeStridesA>;
  using StrideB = cute::Stride<ColModeStridesB, RedModeStridesB, BatModeStridesB>;
  using StrideC = cute::Stride<RowModeStridesC, ColModeStridesC, BatModeStridesC>;

  // Note: C and D share strides here for simplicity.
  // In general, they need not have the same layout.
  using StrideD = StrideC;
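
  // Fully spelled out, StrideA above is the nested rank-3 stride
  //   ((_1, int64_t, int64_t, int64_t), (int64_t, int64_t, int64_t), (int64_t, int64_t, int64_t, int64_t))
  // i.e. one multi-mode stride for each of the M, K, and L modes of a batched GEMM.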

  //
  // Define element types for tensors and intermediate values
  //
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;
  using ElementD = float;
  using ElementAccumulator = float;
  using ElementEpilogue = float;

  // The following constexpr values set the max number of modes in each MNKL mode
  constexpr int MaxRank_M = cute::rank(RowModeStridesA{}); // Max row modes
  constexpr int MaxRank_N = cute::rank(ColModeStridesB{}); // Max column modes
  constexpr int MaxRank_K = cute::rank(RedModeStridesA{}); // Max contraction modes
  constexpr int MaxRank_L = cute::rank(BatModeStridesA{}); // Max batch modes
  static_assert(cute::rank(RowModeStridesA{}) == cute::rank(RowModeStridesC{}));
  static_assert(cute::rank(ColModeStridesB{}) == cute::rank(ColModeStridesC{}));
  static_assert(cute::rank(RedModeStridesA{}) == cute::rank(RedModeStridesB{}));
  static_assert(cute::rank(BatModeStridesA{}) == cute::rank(BatModeStridesC{}));
  static_assert(cute::rank(BatModeStridesB{}) == cute::rank(BatModeStridesC{}));

  // Parse command line to get modes, extents, and strides
  cutlass::GettCommandLine cmd;
  auto parsed_args = cmd.parse(argc, argv, true);

  auto& m = parsed_args.M;
  auto& ldAm = parsed_args.ldAm;
  auto& ldCm = parsed_args.ldCm;
  int rank_m = int(m.size());

  auto& n = parsed_args.N;
  auto& ldBn = parsed_args.ldBn;
  auto& ldCn = parsed_args.ldCn;
  int rank_n = int(n.size());

  auto& k = parsed_args.K;
  auto& ldAk = parsed_args.ldAk;
  auto& ldBk = parsed_args.ldBk;
  int rank_k = int(k.size());

  auto& l = parsed_args.L;
  auto& ldAl = parsed_args.ldAl;
  auto& ldBl = parsed_args.ldBl;
  auto& ldCl = parsed_args.ldCl;
  int rank_l = int(l.size());

  if ((rank_m > MaxRank_M) || (rank_n > MaxRank_N) || (rank_k > MaxRank_K) || (rank_l > MaxRank_L)) {
    std::cerr << "ERROR: Input has more modes than statically configured.\n";
    return 1;
  }

  // Check that the user-provided major strides match the static major strides.
  if (example::is_left_major<RowModeStridesA>() && (ldAm[0] != 1)) {
    std::cerr << "ERROR: A_M0 is expected to be major, but was not in the provided input!\n";
    return 1;
  }

  if (example::is_left_major<RedModeStridesA>() && (ldAk[0] != 1)) {
    std::cerr << "ERROR: A_K0 is expected to be major, but was not in the provided input!\n";
    return 1;
  }

  if (example::is_left_major<ColModeStridesB>() && (ldBn[0] != 1)) {
    std::cerr << "ERROR: B_N0 is expected to be major, but was not in the provided input!\n";
    return 1;
  }

  if (example::is_left_major<RedModeStridesB>() && (ldBk[0] != 1)) {
    std::cerr << "ERROR: B_K0 is expected to be major, but was not in the provided input!\n";
    return 1;
  }

  // Convert to `cute::Tuple`s and set up arguments
  auto M = make_int_tuple<MaxRank_M>(m.data(), rank_m, 1);
  auto dAm = example::make_stride_tuple<MaxRank_M, example::is_left_major<RowModeStridesA>()>(ldAm.data(), rank_m);
  auto dCm = example::make_stride_tuple<MaxRank_M, example::is_left_major<RowModeStridesC>()>(ldCm.data(), rank_m);

  auto N = make_int_tuple<MaxRank_N>(n.data(), rank_n, 1);
  auto dBn = example::make_stride_tuple<MaxRank_N, example::is_left_major<ColModeStridesB>()>(ldBn.data(), rank_n);
  auto dCn = example::make_stride_tuple<MaxRank_N, example::is_left_major<ColModeStridesC>()>(ldCn.data(), rank_n);

  auto K = make_int_tuple<MaxRank_K>(k.data(), rank_k, 1);
  auto dAk = example::make_stride_tuple<MaxRank_K, example::is_left_major<RedModeStridesA>()>(ldAk.data(), rank_k);
  auto dBk = example::make_stride_tuple<MaxRank_K, example::is_left_major<RedModeStridesB>()>(ldBk.data(), rank_k);

  auto L = make_int_tuple<MaxRank_L>(l.data(), rank_l, 1);
  auto dAl = make_int_tuple<MaxRank_L>(ldAl.data(), rank_l, 0);
  auto dBl = make_int_tuple<MaxRank_L>(ldBl.data(), rank_l, 0);
  auto dCl = make_int_tuple<MaxRank_L>(ldCl.data(), rank_l, 0);

  // Concat tuples to turn them into a rank-4 problem shape and rank-3 strides, just like GEMM
  auto problem_shape = make_shape(M, N, K, L);
  StrideA stride_A = make_stride(dAm, dAk, dAl);
  StrideB stride_B = make_stride(dBn, dBk, dBl);
  StrideC stride_C = make_stride(dCm, dCn, dCl);
  StrideD stride_D = stride_C;
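
  // Illustration (hypothetical values): extents beyond the user-provided ranks are
  // padded with 1 and strides with 0 above, so a run with "--extents=m:4096,n:4096,k:4096"
  // produces a problem shape resembling ((4096,1,1,1), (4096,1,1,1), (4096,1,1), (1,1,1,1)),
  // i.e. the usual rank-4 (M, N, K, L) shape of a batched GEMM with multi-mode entries.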

  auto alpha = ElementEpilogue(1.0f);
  auto beta = ElementEpilogue(1.0f);

  //
  // Allocate and init tensors
  //
  auto M_size = std::accumulate(std::begin(m), std::end(m), 1, std::multiplies<>{});
  auto N_size = std::accumulate(std::begin(n), std::end(n), 1, std::multiplies<>{});
  auto K_size = std::accumulate(std::begin(k), std::end(k), 1, std::multiplies<>{});
  auto L_size = std::accumulate(std::begin(l), std::end(l), 1, std::multiplies<>{});
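
  // Each *_size is the product of that category's extents, i.e. the linear extent of
  // the folded mode, e.g. (illustrative) m == {32, 8} gives M_size == 256.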

  thrust::host_vector<ElementA> h_A(M_size * K_size * L_size);
  thrust::host_vector<ElementB> h_B(N_size * K_size * L_size);
  thrust::host_vector<ElementC> h_C(M_size * N_size * L_size);
  thrust::host_vector<ElementD> h_D(M_size * N_size * L_size);

  // Note: the cast to int here is to avoid false-negative ref-checks which can
  // occur due to floating point arithmetic not being purely associative.
  for (auto& a : h_A) a = ElementA(int(4*(rand() / double(RAND_MAX)) - 1));
  for (auto& b : h_B) b = ElementB(int(4*(rand() / double(RAND_MAX)) - 1));
  for (auto& c : h_C) c = ElementC(int(4*(rand() / double(RAND_MAX)) - 1));
  for (auto& d : h_D) d = ElementD(-1);

  thrust::device_vector<ElementA> d_A = h_A;
  thrust::device_vector<ElementB> d_B = h_B;
  thrust::device_vector<ElementC> d_C = h_C;
  thrust::device_vector<ElementD> cutlass_result = h_D;
  thrust::device_vector<ElementD> reference_result = h_D;

  //
  // Compute GETT
  //
  auto status = example::gett_kernel(
    problem_shape,
    d_A.data().get(), stride_A,
    d_B.data().get(), stride_B,
    ElementAccumulator{},
    d_C.data().get(), stride_C,
    cutlass_result.data().get(), stride_D,
    alpha, beta);

  if (cutlass::Status::kSuccess != status) {
    std::cerr << "ERROR: GETT operator launch failed.\n";
    return 1;
  }

  auto cuda_err = cudaDeviceSynchronize();
  if (cudaSuccess != cuda_err) {
    std::cerr << "ERROR: GETT operator execution failed with error: "
              << cudaGetErrorString(cuda_err) << "\n";
    return 1;
  }

  //
  // Verify
  //

  cutlass::reference::device::gett(
    problem_shape,
    d_A.data().get(), stride_A,
    d_B.data().get(), stride_B,
    ElementAccumulator{},
    d_C.data().get(), stride_C,
    reference_result.data().get(), stride_D,
    alpha, beta);

  cuda_err = cudaDeviceSynchronize();
  if (cudaSuccess != cuda_err) {
    std::cerr << "ERROR: GETT reference execution failed with error: "
              << cudaGetErrorString(cuda_err) << "\n";
    return 1;
  }

  // Check whether the outputs of the CUTLASS kernel and the reference kernel match
  bool passed = cutlass::reference::device::BlockCompareEqual(
    reference_result.data().get(), cutlass_result.data().get(), cutlass_result.size());

  if (passed) {
    std::cout << "GETT verification passed.\n";
    return 0;
  }
  else {
    std::cerr << "ERROR: GETT verification failed! Printing detailed stats.\n";
    h_D = reference_result;
    thrust::host_vector<ElementD> h_cutlass_result = cutlass_result;
    print_relative_error(h_cutlass_result.size(), h_cutlass_result.data(), h_D.data());

    std::cout << "StrideA: "; print(stride_A); std::cout << '\n';
    std::cout << "StrideB: "; print(stride_B); std::cout << '\n';
    std::cout << "StrideC: "; print(stride_C); std::cout << '\n';
    std::cout << "StrideD: "; print(stride_D); std::cout << '\n';
    return 1;
  }
#else
  std::cerr << "Unsupported example. Please ensure CUTLASS_ARCH_MMA_SM90_SUPPORTED is defined.\n";
  return 0;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
}