CUTLASS 3.0 Hopper GEMMs are GETTs in disguise (#897)
parent 1eef5c3cf1
commit 15d9d31f1f

examples/51_hopper_gett/51_hopper_gett.cu (new file, 371 lines)
@@ -0,0 +1,371 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Example of a GETT targeting Hopper tensor cores using the CUTLASS 3.x API.

    CUTLASS has long provided implementations of Generalized Matrix times Matrix (GEMM) kernels.
    However, a plethora of workloads compute on higher ranked tensors. Products of such tensors,
    called tensor contractions, can be executed as multiple batched GEMMs; however, they can be
    further accelerated with kernels that natively operate on these higher ranked tensors to
    perform Generalized Tensor times Tensor contractions (GETT). CuTe's hierarchical layouts
    and CUTLASS 3.0's unified micro-kernels make implementation of GETTs trivial. In this example,
    we show how CUTLASS 3.0, CuTe, and Hopper's TMA feature together can accelerate GETTs while
    making the process of authoring custom GETT kernels easier than ever before.

    The modes of a tensor that participate in a GETT can be fundamentally grouped into four
    semantic categories. The contraction modes (or K-modes) only appear in the A and B (left and right)
    inputs but not in the C output tensor. Row modes (or M-modes) only appear in the left
    input tensor (A) and the output tensor (C). Column modes (or N-modes) only appear in the
    right (B) input tensor and the output tensor (C). Batch modes (or L-modes) appear in all
    input and output tensors. Folding the many modes of a tensor contraction into these four
    categories allows us to represent the input and output tensors as rank-3 "matrices"
    that can be computed upon as if we were computing a batched GEMM!

    This is exactly what CuTe's hierarchical layout representation allows us to do! Instead of having
    simple integers as strides for these four modes, we can have nested strides for each of these
    semantic categories that themselves have multiple modes within them -- multi-mode strides!
    In CUTLASS 3.0, all one has to do to take advantage of this capability is to substitute the
    required multi-mode strides for the default ones provided by gemm::detail::TagToStrideX.

    In the following example, we illustrate how every Hopper GEMM in CUTLASS 3.0 is a GETT in disguise.
    We begin by defining the four mode categories detailed above as Row, Col (column), Red (reduction),
    and Bat (batch) strides, which we then nest for each of the in/out tensors to create our rank-3
    stride tuples. Note that although we do not define the problem shape type explicitly, it too remains
    a rank-4 shape tuple just like any other batched GEMM, but with multi-mode shapes for each
    of the four corresponding multi-modes within it. After this, the same CollectiveMma and
    CollectiveBuilder we describe in examples 50 and 49 are used to create our kernel type. Nothing
    else changes from a user's point of view. Note that multi-mode strides do not affect our
    specializations in any way -- the lexical spelling of our kernels remains the same. The
    only difference between a CUTLASS 3 batched GEMM and a GETT is the instantiated CuTe Layouts.

    CollectiveBuilders rely on detecting the static-1 in the stride tuples to determine the major mode,
    which is what this example demonstrates. However, it is possible to have all modes be dynamic as well
    if the user assembles a CollectiveMma manually and ensures that the runtime strides are compatible
    with the static micro-kernel of the collective (TiledMma, TiledCopy, and smem layouts). On the other
    hand, a user can have more than one static stride too (which need not correspond to the major mode).

    In particular, this example demonstrates a GETT where the 0th M-mode (M0) in A and the 0th K-mode (K0)
    in B are major. All other combinations of major modes are supported, with the exception of mixed
    K-major scenarios where A and B are K-major along different K-modes (e.g. K0 is major in A but
    K1 is major in B). NVIDIA Hopper architecture's TMA feature makes the predication required to
    implement these complicated kernels trivial, as it is all handled by TMA itself without requiring
    any programmer effort.

    Example executions, where the stride order defines the major-order (major on the left):
      51_hopper_gett --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096
      51_hopper_gett --modeC=l,m,n --modeA=m,l,k --modeB=k,n,l --extents=m:128,n:128,k:128,l:64
      51_hopper_gett --modeC=m,a,b,p,q,n,l --modeA=m,l,b,k,a --modeB=k,n,p,q,l --extents=m:32,a:32,b:3,n:128,k:128,l:4,p:3,q:3
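
    For instance, in the third execution above the mode categories are deduced as follows:
    m, a, and b are M-modes (they appear in A and C); n, p, and q are N-modes (in B and C);
    k is the sole K-mode (in A and B); and l is an L-mode (in all three tensors). Folded,
    this yields the batched-GEMM-like problem shape M=(32,32,3), N=(128,3,3), K=(128), L=(4)
    (up to the coalescing reordering that GettCommandLine applies within each category).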
*/

#include "gett_kernel.cuh"
#include "thrust/host_vector.h"
#include "thrust/device_vector.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"

#include "cutlass/util/gett_commandline.hpp"
#include "cutlass/util/reference/device/gett.hpp"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/print_error.hpp"

namespace example {

// Returns true if the left-most value in the tuple is statically known to be 1
template<class Stride>
constexpr bool
is_left_major() {
  // Account for stride types with and without batch mode and batch modes with static zero stride
  return cute::is_constant<1, decltype(cute::size<0,0>(Stride{}))>::value;
}
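
// e.g. is_left_major<cute::Stride<cute::Int<1>, int64_t, int64_t>>() is true, while
// is_left_major<cute::Stride<int64_t, cute::Int<1>, int64_t>>() is false, since only
// the left-most mode is inspected.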

// Same as cute::make_int_tuple, but inserts a major stride (Int<1>) for the left-most mode if required
template <int Rank, bool IsMajor, class Indexable>
static constexpr
auto
make_stride_tuple(Indexable const& t, int n, int64_t init_default = 0) {
  static_assert(Rank > 1);
  if constexpr (IsMajor) {
    return cute::transform(cute::make_seq<Rank>{}, [&](auto i) {
      if constexpr (i == 0) {
        return cute::Int<1>{};
      }
      else {
        return i < n ? t[i] : init_default;
      }
    });
  }
  else {
    return cute::make_int_tuple<Rank>(t, n, init_default);
  }
}
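
// For example, with Rank=4, IsMajor=true, and runtime strides ld = {1, 128, 4096} (n = 3),
// make_stride_tuple<4, true>(ld, 3) yields (_1, 128, 4096, 0): the left-most mode becomes a
// compile-time Int<1>, the remaining modes stay dynamic, and unused modes default to 0.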

} // namespace example

//////////////////////////////////////////////////////////////////////////////

int
main(int argc, char const* argv[]) {
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
  using namespace cute;

  if (argc != 5) {
    std::cout << "Number of command line args must be 4.\n";
    cutlass::GettCommandLine::print_usage();
    return 0;
  }

  //
  // Define the stride types for A, B, C, and D
  //

  // Stride for A (left input). If a reduction mode is major in A, the same K-mode must be major in B.
  // For this example, M0 is major in A.
  using RowModeStridesA = cute::Stride<cute::Int<1>, int64_t, int64_t, int64_t>;
  using RedModeStridesA = cute::Stride<int64_t, int64_t, int64_t>;
  using BatModeStridesA = cute::Stride<int64_t, int64_t, int64_t, int64_t>;

  // Stride for B (right input). If a reduction mode is major in B, the same K-mode must be major in A.
  // For this example, K0 is major in B.
  using ColModeStridesB = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
  using RedModeStridesB = cute::Stride<cute::Int<1>, int64_t, int64_t>;
  using BatModeStridesB = cute::Stride<int64_t, int64_t, int64_t, int64_t>;

  // Strides for the output, which can all be dynamic.
  using RowModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
  using ColModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
  using BatModeStridesC = cute::Stride<int64_t, int64_t, int64_t, int64_t>;

  // Assemble our rank-3 multi-mode strides for the in/out tensors
  using StrideA = cute::Stride<RowModeStridesA, RedModeStridesA, BatModeStridesA>;
  using StrideB = cute::Stride<ColModeStridesB, RedModeStridesB, BatModeStridesB>;
  using StrideC = cute::Stride<RowModeStridesC, ColModeStridesC, BatModeStridesC>;

  // Note: C and D share strides here for simplicity.
  // In general, they need not have the same layout.
  using StrideD = StrideC;
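
  // An ordinary batched GEMM is the degenerate case of the above in which every multi-mode
  // contains a single mode, e.g. a column-major A would be
  //   cute::Stride<cute::Stride<cute::Int<1>>, cute::Stride<int64_t>, cute::Stride<int64_t>>
  // which carries the same information as the flat Stride<Int<1>, int64_t, int64_t> that
  // gemm::detail::TagToStrideA_t<layout::ColumnMajor> provides -- hence "GETTs in disguise".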

  //
  // Define element types for tensors and intermediate values
  //
  using ElementA = cutlass::half_t;
  using ElementB = cutlass::half_t;
  using ElementC = cutlass::half_t;
  using ElementD = float;
  using ElementAccumulator = float;
  using ElementEpilogue = float;

  // The following constexpr values set the max number of modes in each MNKL category
  constexpr int MaxRank_M = rank(RowModeStridesA{}); // Max row modes
  constexpr int MaxRank_N = rank(ColModeStridesB{}); // Max column modes
  constexpr int MaxRank_K = rank(RedModeStridesA{}); // Max contraction modes
  constexpr int MaxRank_L = rank(BatModeStridesA{}); // Max batch modes
  static_assert(rank(RowModeStridesA{}) == rank(RowModeStridesC{}));
  static_assert(rank(ColModeStridesB{}) == rank(ColModeStridesC{}));
  static_assert(rank(RedModeStridesA{}) == rank(RedModeStridesB{}));
  static_assert(rank(BatModeStridesA{}) == rank(BatModeStridesC{}));
  static_assert(rank(BatModeStridesB{}) == rank(BatModeStridesC{}));

  // Parse command line to get modes, extents, and strides
  cutlass::GettCommandLine cmd;
  auto parsed_args = cmd.parse(argc, argv, true);

  auto& m = parsed_args.M;
  auto& ldAm = parsed_args.ldAm;
  auto& ldCm = parsed_args.ldCm;
  int rank_m = int(m.size());

  auto& n = parsed_args.N;
  auto& ldBn = parsed_args.ldBn;
  auto& ldCn = parsed_args.ldCn;
  int rank_n = int(n.size());

  auto& k = parsed_args.K;
  auto& ldAk = parsed_args.ldAk;
  auto& ldBk = parsed_args.ldBk;
  int rank_k = int(k.size());

  auto& l = parsed_args.L;
  auto& ldAl = parsed_args.ldAl;
  auto& ldBl = parsed_args.ldBl;
  auto& ldCl = parsed_args.ldCl;
  int rank_l = int(l.size());

  if ((rank_m > MaxRank_M) || (rank_n > MaxRank_N) || (rank_k > MaxRank_K) || (rank_l > MaxRank_L)) {
    std::cerr << "ERROR: Input has more modes than statically configured.\n";
    return 1;
  }

  // Check that the user-input major strides match the static major strides.
  if (example::is_left_major<RowModeStridesA>() && (ldAm[0] != 1)) {
    std::cerr << "ERROR: A_M0 is expected to be major, but was not in the provided input!\n";
    return 1;
  }

  if (example::is_left_major<RedModeStridesA>() && (ldAk[0] != 1)) {
    std::cerr << "ERROR: A_K0 is expected to be major, but was not in the provided input!\n";
    return 1;
  }

  if (example::is_left_major<ColModeStridesB>() && (ldBn[0] != 1)) {
    std::cerr << "ERROR: B_N0 is expected to be major, but was not in the provided input!\n";
    return 1;
  }

  if (example::is_left_major<RedModeStridesB>() && (ldBk[0] != 1)) {
    std::cerr << "ERROR: B_K0 is expected to be major, but was not in the provided input!\n";
    return 1;
  }

  // Convert to cute::Tuples and set up arguments
  auto M = make_int_tuple<MaxRank_M>(m.data(), rank_m, 1);
  auto dAm = example::make_stride_tuple<MaxRank_M, example::is_left_major<RowModeStridesA>()>(ldAm.data(), rank_m);
  auto dCm = example::make_stride_tuple<MaxRank_M, example::is_left_major<RowModeStridesC>()>(ldCm.data(), rank_m);

  auto N = make_int_tuple<MaxRank_N>(n.data(), rank_n, 1);
  auto dBn = example::make_stride_tuple<MaxRank_N, example::is_left_major<ColModeStridesB>()>(ldBn.data(), rank_n);
  auto dCn = example::make_stride_tuple<MaxRank_N, example::is_left_major<ColModeStridesC>()>(ldCn.data(), rank_n);

  auto K = make_int_tuple<MaxRank_K>(k.data(), rank_k, 1);
  auto dAk = example::make_stride_tuple<MaxRank_K, example::is_left_major<RedModeStridesA>()>(ldAk.data(), rank_k);
  auto dBk = example::make_stride_tuple<MaxRank_K, example::is_left_major<RedModeStridesB>()>(ldBk.data(), rank_k);

  auto L = make_int_tuple<MaxRank_L>(l.data(), rank_l, 1);
  auto dAl = make_int_tuple<MaxRank_L>(ldAl.data(), rank_l, 0);
  auto dBl = make_int_tuple<MaxRank_L>(ldBl.data(), rank_l, 0);
  auto dCl = make_int_tuple<MaxRank_L>(ldCl.data(), rank_l, 0);

  // Concatenate the tuples into a rank-4 problem shape and rank-3 strides, just like a GEMM
  auto problem_shape = make_shape(M, N, K, L);
  StrideA stride_A = make_stride(dAm, dAk, dAl);
  StrideB stride_B = make_stride(dBn, dBk, dBl);
  StrideC stride_C = make_stride(dCm, dCn, dCl);
  StrideD stride_D = stride_C;
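
  // With the second example invocation from the file header (rank-1 categories,
  // m:128, n:128, k:128, l:64), this folds to a familiar batched-GEMM problem shape:
  // M=(128,1,1,1), N=(128,1,1,1), K=(128,1,1), L=(64,1,1,1), with unused trailing modes
  // padded to extent 1 and stride 0.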

  auto alpha = ElementEpilogue(1.0f);
  auto beta  = ElementEpilogue(1.0f);

  //
  // Allocate and initialize tensors
  //
  auto M_size = std::accumulate(std::begin(m), std::end(m), 1, std::multiplies<>{});
  auto N_size = std::accumulate(std::begin(n), std::end(n), 1, std::multiplies<>{});
  auto K_size = std::accumulate(std::begin(k), std::end(k), 1, std::multiplies<>{});
  auto L_size = std::accumulate(std::begin(l), std::end(l), 1, std::multiplies<>{});

  thrust::host_vector<ElementA> h_A(M_size * K_size * L_size);
  thrust::host_vector<ElementB> h_B(N_size * K_size * L_size);
  thrust::host_vector<ElementC> h_C(M_size * N_size * L_size);
  thrust::host_vector<ElementD> h_D(M_size * N_size * L_size);

  // Note: the cast to int here is to avoid false-negative ref-checks which can
  // occur due to floating point arithmetic not being purely associative.
  for (auto& a : h_A) a = ElementA(int(4*(rand() / double(RAND_MAX)) - 1));
  for (auto& b : h_B) b = ElementB(int(4*(rand() / double(RAND_MAX)) - 1));
  for (auto& c : h_C) c = ElementC(int(4*(rand() / double(RAND_MAX)) - 1));
  for (auto& d : h_D) d = ElementD(-1);

  thrust::device_vector<ElementA> d_A = h_A;
  thrust::device_vector<ElementB> d_B = h_B;
  thrust::device_vector<ElementC> d_C = h_C;
  thrust::device_vector<ElementD> cutlass_result = h_D;
  thrust::device_vector<ElementD> reference_result = h_D;

  //
  // Compute GETT
  //
  auto status = example::gett_kernel(
    problem_shape,
    d_A.data().get(), stride_A,
    d_B.data().get(), stride_B,
    ElementAccumulator{},
    d_C.data().get(), stride_C,
    cutlass_result.data().get(), stride_D,
    alpha, beta);

  if (cutlass::Status::kSuccess != status) {
    std::cerr << "ERROR: GETT operator launch failed.\n";
    return 1;
  }

  auto cuda_err = cudaDeviceSynchronize();
  if (cudaSuccess != cuda_err) {
    std::cerr << "ERROR: GETT operator execution failed with error: ";
    std::cerr << cudaGetErrorString(cuda_err) << "\n";
    return 1;
  }

  //
  // Verify
  //

  cutlass::reference::device::gett(
    problem_shape,
    d_A.data().get(), stride_A,
    d_B.data().get(), stride_B,
    ElementAccumulator{},
    d_C.data().get(), stride_C,
    reference_result.data().get(), stride_D,
    alpha, beta);

  cuda_err = cudaDeviceSynchronize();
  if (cudaSuccess != cuda_err) {
    std::cerr << "ERROR: GETT reference execution failed with error: ";
    std::cerr << cudaGetErrorString(cuda_err) << "\n";
    return 1;
  }

  // Check whether the outputs from the CUTLASS kernel and the reference kernel are equal
  bool passed = cutlass::reference::device::BlockCompareEqual(
    reference_result.data().get(), cutlass_result.data().get(), cutlass_result.size());

  if (passed) {
    std::cout << "GETT verification passed.\n";
    return 0;
  }
  else {
    std::cerr << "ERROR: GETT verification failed! Printing detailed stats.\n";
    h_D = reference_result;
    thrust::host_vector<ElementD> h_cutlass_result = cutlass_result;
    print_relative_error(h_cutlass_result.size(), h_cutlass_result.data(), h_D.data());

    std::cout << "StrideA: "; print(stride_A); std::cout << '\n';
    std::cout << "StrideB: "; print(stride_B); std::cout << '\n';
    std::cout << "StrideC: "; print(stride_C); std::cout << '\n';
    std::cout << "StrideD: "; print(stride_D); std::cout << '\n';
    return 1;
  }
#else
  std::cerr << "Unsupported example. Please ensure CUTLASS_ARCH_MMA_SM90_SUPPORTED is defined.\n";
  return 0;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
}

examples/51_hopper_gett/CMakeLists.txt (new file, 32 lines)
@@ -0,0 +1,32 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
  51_hopper_gett
  51_hopper_gett.cu
)
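
# As a usage sketch (assuming a CUTLASS build configured for SM90; the binary path may
# vary by generator), the example can be built and run with:
#   cmake --build . --target 51_hopper_gett
#   ./examples/51_hopper_gett/51_hopper_gett --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096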

examples/51_hopper_gett/gett_kernel.cuh (new file, 136 lines)
@@ -0,0 +1,136 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cute/tensor.hpp"

#include "cutlass/arch/arch.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"

namespace example {

//
// GETT entry point
//
template <
  class ProblemShapeMNKL,
  class ElementA,
  class StrideA,
  class ElementB,
  class StrideB,
  class ElementAccumulator,
  class ElementC,
  class StrideC,
  class ElementD,
  class StrideD,
  class ElementEpilogue>
cutlass::Status
gett_kernel(
    ProblemShapeMNKL problem_shape_mnkl,
    ElementA const* ptr_A, StrideA stride_a_mkl,
    ElementB const* ptr_B, StrideB stride_b_nkl,
    ElementAccumulator _,
    ElementC const* ptr_C, StrideC stride_c_mnl,
    ElementD      * ptr_D, StrideD stride_d_mnl,
    ElementEpilogue alpha, ElementEpilogue beta,
    cudaStream_t stream = 0) {
  using namespace cute;

  // TileShape -- GETT configuration
  // Specify the number of elements to take from each mode:
  // BLK_M = (M0,M1,...)  BLK_N = (N0,N1,...)  BLK_K = (K0,K1,...)

  // Take 128 from m0, 128 from n0, 64 from k0
  using TileShape = Shape<Shape<_128>, Shape<_128>, Shape<_64>>;

  /* Another example:
   * Take 32 elements from m0 and 4 elements from m1,
   * take 64 elements from n0 and 2 elements from n1,
   * take 8 elements from k0 and 8 elements from k1.
   */
  // using TileShape = Shape<Shape<_32,_4>, Shape<_64,_2>, Shape<_8,_8>>;
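
  // Note that the per-mode products (32*4 = 128, 64*2 = 128, 8*8 = 64) match the
  // 128x128x64 tile above, so both configurations move the same overall tile volume.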

  using EpilogueThreadOp = cutlass::epilogue::thread::LinearCombination<
    ElementD, 1, ElementAccumulator, ElementEpilogue, cutlass::epilogue::thread::ScaleType::Default,
    cutlass::FloatRoundStyle::round_to_nearest, ElementC>;

  // No changes are required to the default epilogue
  using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
    StrideC,
    StrideD,
    EpilogueThreadOp>;

  // The CollectiveMma for a GETT can be built using the CollectiveBuilder, just as for a GEMM
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    ElementA, StrideA, 128 / cutlass::sizeof_bits<ElementA>::value,
    ElementB, StrideB, 128 / cutlass::sizeof_bits<ElementB>::value,
    ElementAccumulator,
    TileShape, Shape<_1,_2,_1>,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::collective::KernelScheduleAuto
  >::CollectiveOp;
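
  // The builder deduces the major mode of A and B from the static Int<1> inside the
  // multi-mode StrideA/StrideB (see the discussion in 51_hopper_gett.cu), so no
  // layout tags are needed here.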

  // The GETT kernel is a composition of a collective mainloop and epilogue, just like any 3.x GEMM
  using GettKernel = cutlass::gemm::kernel::GemmUniversal<
      ProblemShapeMNKL,
      CollectiveMainloop,
      CollectiveEpilogue>;

  using GettOperator = cutlass::gemm::device::GemmUniversalAdapter<GettKernel>;

  typename GettOperator::Arguments args {
    cutlass::gemm::GemmUniversalMode::kBatched,
    problem_shape_mnkl,
    ptr_A, stride_a_mkl,
    ptr_B, stride_b_nkl,
    { ptr_C, stride_c_mnl, ptr_D, stride_d_mnl, {alpha, beta} }
  };

#if CUTLASS_DEBUG_TRACE_LEVEL > 0
  print("Problem shape:");
  print("\tM: "); print(cute::get<0>(problem_shape_mnkl)); print("\n");
  print("\tN: "); print(cute::get<1>(problem_shape_mnkl)); print("\n");
  print("\tK: "); print(cute::get<2>(problem_shape_mnkl)); print("\n");
  print("\tL: "); print(cute::get<3>(problem_shape_mnkl)); print("\n");
  print("TileShape:"); print(TileShape{}); print("\n");
#endif

  GettOperator op;
  return op(args, stream);
}

} // namespace example

@@ -131,6 +131,7 @@ foreach(EXAMPLE
   48_hopper_warp_specialized_gemm
   49_hopper_gemm_schedules_with_collective_builder
   50_hopper_gemm_with_epilogue_swizzle
+  51_hopper_gett
 )

 add_subdirectory(${EXAMPLE})

@@ -762,7 +762,17 @@ make_tma_copy(CopyOp,
   print("layout_tv : "); print(layout_tv); print("\n");
 #endif

-  return TiledCopy<Copy_Atom<Traits,T>, decltype(layout_tv), decltype(cta_tile)>{tma_desc, gmem_stride_bases};
+  // If CTA_Tile and SLayout are incompatible, product_each makes sure
+  // that the TiledCopy generates consistent accesses.
+  auto cta_tile_tiled = [&]() {
+    if constexpr (compatible(shape(CTA_Tile{}), shape(SLayout{}))) {
+      return cta_tile;
+    } else {
+      return product_each(cta_tile);
+    }
+  }();
+
+  return TiledCopy<Copy_Atom<Traits,T>, decltype(layout_tv), decltype(cta_tile_tiled)>{tma_desc, gmem_stride_bases};
 }

 // Explicit defaulting

@@ -61,7 +61,7 @@ template <class ElementA, class LayoutA>
 constexpr cute::GMMA::Major
 tag_to_gmma_major_A() {
   // MN major mode is only valid for non-TF32 and non-int MMAs
-  if constexpr (std::is_same_v<LayoutA, cutlass::layout::ColumnMajor> &&
+  if constexpr (cutlass::gemm::detail::is_mn_major_A<LayoutA>() &&
                 not std::is_same_v<ElementA, tfloat32_t> &&
                 not std::is_same_v<ElementA, int8_t> &&
                 not std::is_same_v<ElementA, uint8_t>) {

@@ -77,7 +77,7 @@ template <class ElementB, class LayoutB>
 constexpr cute::GMMA::Major
 tag_to_gmma_major_B() {
   // MN major mode is only valid for non-TF32 and non-int MMAs
-  if constexpr (std::is_same_v<LayoutB, cutlass::layout::RowMajor> &&
+  if constexpr (cutlass::gemm::detail::is_mn_major_B<LayoutB>() &&
                 not std::is_same_v<ElementB, tfloat32_t> &&
                 not std::is_same_v<ElementB, int8_t> &&
                 not std::is_same_v<ElementB, uint8_t>) {

@@ -113,7 +113,7 @@ make_cp_async_gmem_tiled_copy() {

   // Maximize the number of threads along the gmem major mode to promote coalesced reads
   // While making sure our thread layout tiles the threadblock tile evenly
-  if constexpr (cute::size<1>(StrideType{}) == 1) {
+  if constexpr (cutlass::gemm::detail::is_k_major<StrideType>()) {
     // K major thread layout for K major gmem
     constexpr int threads_major = TileSizeK / Alignment;
     constexpr int threads_minor = ThreadCount / threads_major;

@@ -126,7 +126,7 @@ make_cp_async_gmem_tiled_copy() {
                       Stride<Int<threads_major>, _1>>{},
                Layout<Shape<_1,Int<Alignment>>>{});
   }
-  else if constexpr (cute::size<0>(StrideType{}) == 1) {
+  else if constexpr (cutlass::gemm::detail::is_mn_major<StrideType>()) {
     // MN major thread layout for MN major gmem
     constexpr int threads_major = TileSizeMN / Alignment;
     constexpr int threads_minor = ThreadCount / threads_major;

@@ -257,7 +257,8 @@ struct CollectiveBuilder<
     not std::is_same_v<KernelScheduleType, KernelMultistage> &&
     // dispatch TN tf32 and int8 kernels only to TMA builder
     ((sizeof(ElementA) == 2 && sizeof(ElementB) == 2) ||
-     (std::is_same_v<GmemLayoutA, layout::RowMajor> && std::is_same_v<GmemLayoutB, layout::ColumnMajor>))>
+     (cutlass::gemm::detail::is_k_major_A<GmemLayoutA>() &&
+      cutlass::gemm::detail::is_k_major_B<GmemLayoutB>()))>
 > {
   static_assert(is_static<TileShape_MNK>::value);
   static_assert(is_static<ClusterShape_MNK>::value);

@@ -346,7 +347,8 @@ struct CollectiveBuilder<
     ((sizeof(ElementB) * AlignmentB) % detail::tma_alignment_bytes != 0) ||
     // dispatch non-TN tf32 and int8 kernels only to cp_async builder
     ((sizeof(ElementA) != 2 || sizeof(ElementB) != 2) &&
-     (not std::is_same_v<GmemLayoutA, layout::RowMajor> || not std::is_same_v<GmemLayoutB, layout::ColumnMajor>))>
+     (not cutlass::gemm::detail::is_k_major_A<GmemLayoutA>() ||
+      not cutlass::gemm::detail::is_k_major_B<GmemLayoutB>()))>
 > {
   static_assert(is_static<TileShape_MNK>::value);
   static_assert(is_static<ClusterShape_MNK>::value);

@@ -37,8 +37,7 @@
 #include "cutlass/coord.h"
 #include "cutlass/layout/matrix.h"
 #include "cute/layout.hpp"
-#include "cute/arch/copy_sm90.hpp"
-
+#include "cute/arch/copy_sm90_tma.hpp"
 namespace cutlass {
 namespace gemm {


@@ -426,7 +425,9 @@ enum class SharedMemoryClearOption {
 // For each cutlass::layout, provides its corresponding cute stride types, 64b by default

 template <class L>
-struct TagToStrideA {};
+struct TagToStrideA {
+  using type = L;
+};

 // Maps to modes [M, K, L]
 template <>

@@ -443,7 +444,9 @@ struct TagToStrideA<layout::ColumnMajor> {
 };

 template <class L>
-struct TagToStrideB {};
+struct TagToStrideB {
+  using type = L;
+};

 // Maps to modes [N, K, L]
 template <>

@@ -479,13 +482,19 @@ using TagToStrideC_t = typename TagToStrideC<LayoutTag>::type;

 namespace detail {

+template<class Stride>
+constexpr bool
+is_mn_major() {
+  // Account for stride types with and without batch mode and batch modes with static zero stride
+  return cute::is_constant<1, decltype(cute::size<0,0>(Stride{}))>::value;
+}
+
 // Note : This method can be used for deducing the Layout Tag of A, C, D Matrices
 template<class StrideAC>
 constexpr
 auto
 stride_to_layout_tag_A() {
-  // Account for stride types with and without batch mode and batch modes with static zero stride
-  if constexpr (cute::size<0>(StrideAC{}) == 1) { // M major
+  if constexpr (is_mn_major<StrideAC>()) { // M major
     return layout::ColumnMajor{};
   }
   else { // K major

@@ -499,8 +508,7 @@ template<class StrideB>
 constexpr
 auto
 stride_to_layout_tag_B() {
-  // Account for stride types with and without batch mode and batch modes with static zero stride
-  if constexpr (cute::size<0>(StrideB{}) == 1) { // N major
+  if constexpr (is_mn_major<StrideB>()) { // N major
     return layout::RowMajor{};
   }
   else { // K major

@@ -515,12 +523,12 @@ template <class GmemTiledCopy, class Element>
 constexpr int
 get_alignment_count_from_gmem_tiled_copy() {
   // For TMA tiled copies, we know the alignment has to be 128 bits
-  if constexpr (std::is_base_of_v<cute::SM90_TMA_LOAD, GmemTiledCopy> ||
-                std::is_base_of_v<cute::SM90_TMA_LOAD_MULTICAST, GmemTiledCopy>) {
+  if constexpr (   std::is_base_of_v<cute::SM90_TMA_LOAD, GmemTiledCopy>
+                || std::is_base_of_v<cute::SM90_TMA_LOAD_MULTICAST, GmemTiledCopy>
+                ) {
     return 128 / sizeof_bits<Element>::value;
   }
-  else
-  {
+  else {
     // For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN
     return GmemTiledCopy::NumValSrc;
   }

@@ -551,6 +559,37 @@ using StrideToLayoutTagB_t = typename StrideToLayoutTagB<S>::type;
 template<class S>
 using StrideToLayoutTagC_t = typename StrideToLayoutTagC<S>::type;

+template<class Stride>
+constexpr
+bool
+is_k_major() {
+  return ! is_mn_major<Stride>();
+}
+
+template<class LayoutA>
+constexpr bool
+is_mn_major_A() {
+  return is_mn_major<TagToStrideA_t<LayoutA>>();
+}
+
+template<class LayoutB>
+constexpr bool
+is_mn_major_B() {
+  return is_mn_major<TagToStrideB_t<LayoutB>>();
+}
+
+template<class LayoutA>
+constexpr bool
+is_k_major_A() {
+  return is_k_major<TagToStrideA_t<LayoutA>>();
+}
+
+template<class LayoutB>
+constexpr bool
+is_k_major_B() {
+  return is_k_major<TagToStrideB_t<LayoutB>>();
+}
+
 ///////////////////////////////////////////////////////////////////////////////

 // The following two metafunctions are used to detect whether a `kernel::Gemm` or `kernel::GemmUniversal`

tools/util/include/cutlass/util/gett_commandline.hpp (new file, 369 lines)
@@ -0,0 +1,369 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief GETT command line parser to gather semantic modes, their stride order, and extents.
*/
#pragma once

#include <iostream>
#include <iomanip>
#include <utility>
#include <type_traits>
#include <vector>
#include <map>
#include <algorithm>
#include <numeric>

#include "cutlass/util/command_line.h"

namespace cutlass {

// Output shortcuts
std::ostream& operator<<(std::ostream& os, std::vector<char> data) {
  for (auto& a : data) os << a;
  return os;
}

template <class T>
std::ostream& operator<<(std::ostream& os, std::vector<T> data) {
  for (auto& a : data) os << a << " ";
  return os;
}

struct GettCommandLine {
  struct GettProblem {
    using extent_type = int;
    using stride_type = int64_t;

    // Row modes: appear in A and C/D
    std::vector<extent_type> M;
    std::vector<stride_type> ldAm;
    std::vector<stride_type> ldCm;

    // Column modes: appear in B and C/D
    std::vector<extent_type> N;
    std::vector<stride_type> ldBn;
    std::vector<stride_type> ldCn;

    // Reduction modes: appear in A and B
    std::vector<extent_type> K;
    std::vector<stride_type> ldAk;
    std::vector<stride_type> ldBk;

    // Batch modes: appear in all in/out tensors
    std::vector<extent_type> L;
    std::vector<stride_type> ldAl;
    std::vector<stride_type> ldBl;
    std::vector<stride_type> ldCl;
  };
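
  // For example, "--modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096"
  // parses to M={4096}, N={4096}, K={4096}, and L={1} (the extent of l defaults to 1), with
  // compact left-major strides such as ldAm={1}, ldAk={4096}, ldAl={16777216}.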

  static GettProblem
  parse(int argc, char const* argv[], bool parse_verbose = false) {
    using extent_type = typename GettProblem::extent_type;
    using stride_type = typename GettProblem::stride_type;

    cutlass::CommandLine cmd(argc, argv);

    // modeA
    std::vector<char> a_mode;
    cmd.get_cmd_line_arguments("modeA", a_mode);

    // modeB
    std::vector<char> b_mode;
    cmd.get_cmd_line_arguments("modeB", b_mode);

    // modeC
    std::vector<char> c_mode;
    cmd.get_cmd_line_arguments("modeC", c_mode);

    // mode sizes
    std::map<char,extent_type> mode_size;
    // First, initialize all modes in a, b, c to make sure they're in the map
    for (char a : a_mode) mode_size[a] = 1;
    for (char b : b_mode) mode_size[b] = 1;
    for (char c : c_mode) mode_size[c] = 1;

    // Then, overwrite the ones provided in --extents
    std::vector<std::pair<std::string, std::string> > extent_tokens;
    cmd.get_cmd_line_argument_pairs("extents", extent_tokens);
    for (auto e : extent_tokens) {
      if (std::get<0>(e).size() > 1) {
        std::cerr << "ERROR: Mode name must only be 1 character long.\n";
        print_usage();
        exit(1);
      }
      char label = std::get<0>(e)[0];
      int size = std::stoi(std::get<1>(e));
      mode_size[label] = size;
    }

    // Print out symbolic modes and their extents
    if (parse_verbose) {
      std::cout << "C_" << c_mode << " = A_" << a_mode << " * B_" << b_mode << "\n";
      for (auto e : mode_size) std::cout << " " << std::get<0>(e) << " : " << std::get<1>(e) << "\n";
    }

    //
    // Collect/Compute strides
    //

    std::map<char,stride_type> mode_ldA;
    std::map<char,stride_type> mode_ldB;
    std::map<char,stride_type> mode_ldC;

    {
      stride_type current;

      current = 1;
      for (char a : a_mode) { mode_ldA[a] = current; current *= mode_size[a]; }

      current = 1;
      for (char b : b_mode) { mode_ldB[b] = current; current *= mode_size[b]; }

      current = 1;
      for (char c : c_mode) { mode_ldC[c] = current; current *= mode_size[c]; }
    }
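
    // e.g. for --modeA=m,k,l with --extents=m:4096,k:4096,l:64, the compact strides
    // computed above are mode_ldA['m'] = 1, mode_ldA['k'] = 4096, mode_ldA['l'] = 16777216.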

    //
    // Collect mode categories
    //

    std::vector<char> row_mode; // rows
    std::vector<char> col_mode; // columns
    std::vector<char> red_mode; // reductions
    std::vector<char> bat_mode; // batches

    {
      std::vector<char> a_label = a_mode;
      std::vector<char> b_label = b_mode;
      std::vector<char> c_label = c_mode;

      std::sort(std::begin(a_label), std::end(a_label));
      std::sort(std::begin(b_label), std::end(b_label));
      std::sort(std::begin(c_label), std::end(c_label));

      // std::set_intersection to find the semantic category of each symbolic mode
      std::set_intersection(std::begin(a_label), std::end(a_label),
                            std::begin(c_label), std::end(c_label),
                            std::back_inserter(row_mode));

      std::set_intersection(std::begin(b_label), std::end(b_label),
                            std::begin(c_label), std::end(c_label),
                            std::back_inserter(col_mode));

      std::set_intersection(std::begin(a_label), std::end(a_label),
                            std::begin(b_label), std::end(b_label),
                            std::back_inserter(red_mode));

      std::set_intersection(std::begin(row_mode), std::end(row_mode),
                            std::begin(col_mode), std::end(col_mode),
                            std::back_inserter(bat_mode));
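
      // A mode present in both (A ∩ C) and (B ∩ C) must appear in all three tensors,
      // so the intersection of the row and column candidates is exactly the batch modes;
      // they are removed from the other three categories below.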

      // std::set_difference to remove batch modes from the other semantic modes
      for (char l : bat_mode) {
        row_mode.erase(std::remove(std::begin(row_mode), std::end(row_mode), l), std::end(row_mode));
        col_mode.erase(std::remove(std::begin(col_mode), std::end(col_mode), l), std::end(col_mode));
        red_mode.erase(std::remove(std::begin(red_mode), std::end(red_mode), l), std::end(red_mode));
      }
    }

    // Print out the semantic association of each symbolic mode
    if (parse_verbose) {
      std::cout << " rows : " << row_mode << '\n';
      std::cout << " cols : " << col_mode << '\n';
      std::cout << " reds : " << red_mode << '\n';
      std::cout << " bats : " << bat_mode << '\n';
    }

    //
    // Permute modes
    //

    // Permute the batch modes to promote coalescing:
    // sort them by min(ldAl, ldBl), breaking ties by size
    std::sort(std::begin(bat_mode), std::end(bat_mode), [&](char l1, char l2) {
      return std::tie(std::min(mode_ldA[l1],mode_ldB[l1]),mode_size[l1])
           < std::tie(std::min(mode_ldA[l2],mode_ldB[l2]),mode_size[l2]);
    });
    // Compute sizes and strides of the ordered batch modes
    std::vector<extent_type> L;
    std::vector<stride_type> ldAl;
    std::vector<stride_type> ldBl;
    std::vector<stride_type> ldCl;
    for (char l : bat_mode) {
      L.push_back(mode_size[l]);
      ldAl.push_back(mode_ldA[l]);
      ldBl.push_back(mode_ldB[l]);
      ldCl.push_back(mode_ldC[l]);
    }

    // Permute the reduction modes to promote coalescing:
    // sort them by min(ldAk, ldBk), breaking ties by size
    std::sort(std::begin(red_mode), std::end(red_mode), [&](char k1, char k2) {
      return std::tie(std::min(mode_ldA[k1],mode_ldB[k1]),mode_size[k1])
           < std::tie(std::min(mode_ldA[k2],mode_ldB[k2]),mode_size[k2]);
    });
    // Compute sizes and strides of the ordered reduction modes
    std::vector<extent_type> K;
    std::vector<stride_type> ldAk;
    std::vector<stride_type> ldBk;
    for (char k : red_mode) {
      K.push_back(mode_size[k]);
      ldAk.push_back(mode_ldA[k]);
      ldBk.push_back(mode_ldB[k]);
    }

    // Permute the row modes to promote coalescing:
    // sort them by min(ldAm, ldCm), breaking ties by ldAm
    std::sort(std::begin(row_mode), std::end(row_mode), [&](char m1, char m2) {
      return std::tie(std::min(mode_ldA[m1],mode_ldC[m1]),mode_ldA[m1])
           < std::tie(std::min(mode_ldA[m2],mode_ldC[m2]),mode_ldA[m2]);
    });
    // Compute sizes and strides of the ordered row modes
    std::vector<extent_type> M;
    std::vector<stride_type> ldAm;
    std::vector<stride_type> ldCm;
    for (char m : row_mode) {
      M.push_back(mode_size[m]);
      ldAm.push_back(mode_ldA[m]);
      ldCm.push_back(mode_ldC[m]);
    }

    // Permute the column modes to promote coalescing:
    // sort them by min(ldBn, ldCn), breaking ties by ldBn
    std::sort(std::begin(col_mode), std::end(col_mode), [&](char n1, char n2) {
      return std::tie(std::min(mode_ldB[n1],mode_ldC[n1]),mode_ldB[n1])
           < std::tie(std::min(mode_ldB[n2],mode_ldC[n2]),mode_ldB[n2]);
    });
    // Compute sizes and strides of the ordered column modes
    std::vector<extent_type> N;
    std::vector<stride_type> ldBn;
    std::vector<stride_type> ldCn;
    for (char n : col_mode) {
      N.push_back(mode_size[n]);
      ldBn.push_back(mode_ldB[n]);
      ldCn.push_back(mode_ldC[n]);
    }

    if (parse_verbose) {
      std::cout << "C_";
      if (! row_mode.empty()) {
        std::cout << "(" << row_mode << ")";
      }
      if (! col_mode.empty()) {
        std::cout << "(" << col_mode << ")";
      }
      if (! bat_mode.empty()) {
        std::cout << "(" << bat_mode << ")";
      }
      std::cout << " = A_";
      if (! row_mode.empty()) {
        std::cout << "(" << row_mode << ")";
      }
      if (! red_mode.empty()) {
        std::cout << "(" << red_mode << ")";
      }
      if (! bat_mode.empty()) {
        std::cout << "(" << bat_mode << ")";
      }
      std::cout << " * B_";
      if (! col_mode.empty()) {
        std::cout << "(" << col_mode << ")";
      }
      if (! red_mode.empty()) {
        std::cout << "(" << red_mode << ")";
      }
      if (! bat_mode.empty()) {
        std::cout << "(" << bat_mode << ")";
      }
      std::cout << '\n';

      int M_size = std::accumulate(std::begin(M), std::end(M), 1, std::multiplies<>{});
      int N_size = std::accumulate(std::begin(N), std::end(N), 1, std::multiplies<>{});
      int K_size = std::accumulate(std::begin(K), std::end(K), 1, std::multiplies<>{});
      int L_size = std::accumulate(std::begin(L), std::end(L), 1, std::multiplies<>{});

      std::cout << " M : (" << M_size << ") ";
      for (char m : row_mode) std::cout << m << ":" << mode_size[m] << " ";
      std::cout << '\n';
      std::cout << " N : (" << N_size << ") ";
      for (char n : col_mode) std::cout << n << ":" << mode_size[n] << " ";
      std::cout << '\n';
      std::cout << " K : (" << K_size << ") ";
      for (char k : red_mode) std::cout << k << ":" << mode_size[k] << " ";
      std::cout << '\n';
      std::cout << " L : (" << L_size << ") ";
      for (char l : bat_mode) std::cout << l << ":" << mode_size[l] << " ";
      std::cout << '\n';

      std::cout << " ldAm : " << ldAm << '\n';
      std::cout << " ldAk : " << ldAk << '\n';
      std::cout << " ldAl : " << ldAl << '\n';
      std::cout << " ldBn : " << ldBn << '\n';
      std::cout << " ldBk : " << ldBk << '\n';
      std::cout << " ldBl : " << ldBl << '\n';
      std::cout << " ldCm : " << ldCm << '\n';
      std::cout << " ldCn : " << ldCn << '\n';
      std::cout << " ldCl : " << ldCl << '\n';
    }

    return {M, ldAm, ldCm,
            N, ldBn, ldCn,
            K, ldAk, ldBk,
            L, ldAl, ldBl, ldCl};
  }

  static void
  print_usage() {
    std::cout <<
      "GETT problem command line parser:\n"
      "  --modeA=<m0,...>\n"
      "    A comma-delimited list of characters that correspond to the row, reduction, and batch modes in the A tensor.\n"
      "    The semantic association of each symbolic mode is determined automatically.\n\n"

      "  --modeB=<m0,...>\n"
      "    A comma-delimited list of characters that correspond to the column, reduction, and batch modes in the B tensor.\n"
      "    The semantic association of each symbolic mode is determined automatically.\n\n"

      "  --modeC=<m0,...>\n"
      "    A comma-delimited list of characters that correspond to the row, column, and batch modes in the C tensor.\n"
      "    The semantic association of each symbolic mode is determined automatically.\n\n"

      "  --extents=<mode:extent,...>\n"
      "    A comma-delimited list of symbolic modes and their corresponding extents.\n"
      "    Extents default to 1 if not provided.\n\n"

      "Example usage: gett.exe --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096\n";
  }
};

} // namespace cutlass

@@ -60,7 +60,7 @@ struct matrix_inf_norm_result {
 // and thus are best passed by reference or const reference.
 template <typename EngineType, typename LayoutType>
 matrix_inf_norm_result
-matrix_inf_norm(const cute::Tensor<EngineType, LayoutType>& host_matrix)
+matrix_inf_norm(cute::Tensor<EngineType, LayoutType> const& host_matrix)
 {
   using std::abs;
   using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);

@@ -68,17 +68,14 @@ matrix_inf_norm(const cute::Tensor<EngineType, LayoutType>& host_matrix)
   error_type inf_norm = 0.0;
   bool found_nan = false;

-  const auto shape = host_matrix.shape();
-  using index_type = std::decay_t<decltype(cute::get<0>(shape))>;
   // Computing the infinity norm requires that we be able
   // to treat the input as a matrix, with rows and columns.
-  static_assert(std::is_integral_v<index_type>);
-  const index_type num_rows = cute::get<0>(shape);
-  const index_type num_cols = cute::get<1>(shape);
+  const int64_t num_rows = cute::size<0>(host_matrix);
+  const int64_t num_cols = cute::size<1>(host_matrix);

-  for(index_type i = 0; i < num_rows; ++i) {
+  for(int64_t i = 0; i < num_rows; ++i) {
     error_type row_abs_sum = 0.0;
-    for(index_type j = 0; j < num_cols; ++j) {
+    for(int64_t j = 0; j < num_cols; ++j) {
       row_abs_sum += abs(host_matrix(i, j));
     }
     if(std::isnan(row_abs_sum)) {

@@ -94,39 +91,27 @@ matrix_diff_inf_norm(const cute::Tensor<EngineType, LayoutType>& X,
 // Infinity norm of (X - Y).
 template <typename EngineType, typename LayoutType>
 matrix_inf_norm_result
-matrix_diff_inf_norm(const cute::Tensor<EngineType, LayoutType>& X,
-                     const cute::Tensor<EngineType, LayoutType>& Y)
+matrix_diff_inf_norm(cute::Tensor<EngineType, LayoutType> const& X,
+                     cute::Tensor<EngineType, LayoutType> const& Y)
 {
   using std::abs;
   using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);

-  const auto X_shape = X.shape();
-  const auto Y_shape = Y.shape();
+  assert(cute::size<0>(X) == cute::size<0>(Y));
+  assert(cute::size<1>(X) == cute::size<1>(Y));

-  using index_type = std::decay_t<decltype(cute::get<0>(X_shape))>;
   // Computing the infinity norm requires that we be able
   // to treat the input as a matrix, with rows and columns.
-  static_assert(std::is_integral_v<index_type>);
-  const index_type num_rows = cute::get<0>(X_shape);
-  const index_type num_cols = cute::get<1>(X_shape);
-
-  assert(num_rows == cute::get<0>(Y_shape));
-  assert(num_cols == cute::get<1>(Y_shape));
-
-  auto matrix_ij = [&](const auto& A, std::size_t i, std::size_t j) {
-    return A(i, j);
-  };
-  auto diff_ij = [&](std::size_t i, std::size_t j) {
-    return matrix_ij(X, i, j) - matrix_ij(Y, i, j);
-  };
+  const int64_t num_rows = cute::size<0>(X);
+  const int64_t num_cols = cute::size<1>(X);

   error_type inf_norm = 0.0;
   bool found_nan = false;

-  for(index_type i = 0; i < num_rows; ++i) {
+  for(int64_t i = 0; i < num_rows; ++i) {
     error_type row_abs_sum = 0.0;
-    for(index_type j = 0; j < num_cols; ++j) {
-      row_abs_sum += abs(diff_ij(i, j));
+    for(int64_t j = 0; j < num_cols; ++j) {
+      row_abs_sum += abs(X(i,j) - Y(i,j));
     }
     if(std::isnan(row_abs_sum)) {
       found_nan = true;

@@ -140,22 +125,22 @@ matrix_diff_inf_norm(const cute::Tensor<EngineType, LayoutType>& X,

 template <typename EngineType_A, typename LayoutType_A,
           typename EngineType_B, typename LayoutType_B,
-          typename EngineType_C_computed, typename LayoutType_C_computed,
-          typename EngineType_C_expected, typename LayoutType_C_expected>
+          typename EngineType_C, typename LayoutType_C,
+          typename EngineType_C_ref, typename LayoutType_C_ref>
 void
 print_matrix_multiply_mollified_relative_error(
-  const char A_value_type_name[],
-  const cute::Tensor<EngineType_A, LayoutType_A>& A,
-  const char B_value_type_name[],
-  const cute::Tensor<EngineType_B, LayoutType_B>& B,
-  const char C_value_type_name[],
-  const cute::Tensor<EngineType_C_computed, LayoutType_C_computed>& C_computed,
-  const cute::Tensor<EngineType_C_expected, LayoutType_C_expected>& C_expected)
+  char const A_value_type_name[],
+  cute::Tensor<EngineType_A, LayoutType_A> const& A,
+  char const B_value_type_name[],
+  cute::Tensor<EngineType_B, LayoutType_B> const& B,
+  char const C_value_type_name[],
+  cute::Tensor<EngineType_C, LayoutType_C> const& C,
+  cute::Tensor<EngineType_C_ref, LayoutType_C_ref> const& C_ref)
 {
   const auto [A_norm, A_has_nan] = matrix_inf_norm(A);
   const auto [B_norm, B_has_nan] = matrix_inf_norm(B);
-  const auto [C_norm, C_has_nan] = matrix_inf_norm(C_expected);
-  const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C_computed, C_expected);
+  const auto [C_norm, C_has_nan] = matrix_inf_norm(C_ref);
+  const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C, C_ref);

   const auto A_norm_times_B_norm = A_norm * B_norm;
   const auto relative_error = A_norm_times_B_norm == 0.0 ?
@@ -164,18 +149,19 @@ print_matrix_multiply_mollified_relative_error(
   // For expected error bounds, please refer to the LAPACK Users' Guide,
   // in particular https://netlib.org/lapack/lug/node108.html .
   // Printing the infinity norm of C is a way to check
-  // that both the function being tested (C_computed)
-  // and the reference implementation (C_expected)
+  // that both the function being tested (C)
+  // and the reference implementation (C_ref)
   // don't just do nothing (or fill with zeros).
   using std::cout;
-  cout << "Value type of A: " << A_value_type_name << '\n'
+  using cute::shape;
+  cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n'
+       << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n'
+       << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n'
+       << std::scientific
        << "Infinity norm of A: " << A_norm << '\n'
-       << "Value type of B: " << B_value_type_name << '\n'
        << "Infinity norm of B: " << B_norm << '\n'
-       << "Value type of C: " << C_value_type_name << '\n'
-       << "Infinity norm of C_expected: " << C_norm << '\n'
-       << "Infinity norm of (C_computed - C_expected): " << diff_norm << '\n';
+       << "Infinity norm of C: " << C_norm << '\n'
+       << "Infinity norm of (C - C_ref): " << diff_norm << '\n';
 
   if(A_norm_times_B_norm == 0.0) {
     cout << "Mollified relative error: " << relative_error << '\n';
@@ -183,11 +169,12 @@ print_matrix_multiply_mollified_relative_error(
     cout << "Relative error: " << relative_error << '\n';
   }
 
-  cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n'
-       << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n'
-       << "Did we encounter NaN in C_expected? " << (C_has_nan ? "yes" : "no") << '\n'
-       << "Did we encounter NaN in (C_computed - C_expected)? "
-       << (diff_has_nan ? "yes" : "no") << '\n';
+  if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) {
+    cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n'
+         << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n'
+         << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n'
+         << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n';
+  }
 }
 
 template <typename EngineType, typename LayoutType>
@@ -233,3 +220,70 @@ auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X)
   auto X_data_const = const_cast<std::add_const_t< decltype(X_data)> >(X_data);
   return cute::make_tensor(X_data_const, layout);
 };
+
+
+template <typename T1, typename T2>
+double
+print_relative_error(
+  std::size_t n,
+  T1 const& data,
+  T2 const& reference,
+  bool print_verbose = false,
+  bool print_error = true) {
+  using std::abs; using std::sqrt;
+
+  // Use either double or complex<double> for error computation
+  using value_type = cute::remove_cvref_t<decltype(reference[0])>;
+  using error_type = std::conditional_t<cute::is_complex<value_type>::value,
+                                        cute::complex<double>,
+                                        double>;
+
+  if (print_verbose) {
+    std::cout << "Idx:\t"<< "Val\t" << "RefVal\t" << "RelError" << std::endl;
+  }
+
+  double eps = 1e-200;
+
+  double tot_error_sq = 0;
+  double tot_norm_sq = 0;
+  double tot_ind_rel_err = 0;
+  double max_ind_rel_err = 0;
+  for (std::size_t i = 0; i < n; ++i)
+  {
+    error_type val = data[i];
+    error_type ref = reference[i];
+
+    double aref = abs(ref);
+    double diff = abs(ref - val);
+    double rel_error = diff / (aref + eps);
+
+    // Individual relative error
+    tot_ind_rel_err += rel_error;
+
+    // Maximum relative error
+    max_ind_rel_err = std::max(max_ind_rel_err, rel_error);
+
+    // Total relative error
+    tot_error_sq += diff * diff;
+    tot_norm_sq += aref * aref;
+
+    if (print_verbose) {
+      std::cout << i << ":\t" << val << "\t" << ref << "\t" << rel_error << std::endl;
+    }
+  }
+
+  printf("Vector reference norm: [%.5e]\n", sqrt(tot_norm_sq));
+
+  double tot_rel_err = sqrt(tot_error_sq/(tot_norm_sq+eps));
+  if (print_error)
+    printf("Vector relative error: [%.5e]\n", tot_rel_err);
+
+  double ave_rel_err = tot_ind_rel_err / double(n);
+  if (print_error)
+    printf("Average relative error: [%.5e]\n", ave_rel_err);
+
+  if (print_error)
+    printf("Maximum relative error: [%.5e]\n", max_ind_rel_err);
+
+  return tot_rel_err;
+}
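In summary, with x the data vector and r the reference, the three figures printed above are

\[
\text{total} = \sqrt{\frac{\sum_i |x_i - r_i|^2}{\sum_i |r_i|^2 + \varepsilon}}, \qquad
\text{average} = \frac{1}{n} \sum_i \frac{|x_i - r_i|}{|r_i| + \varepsilon}, \qquad
\text{maximum} = \max_i \frac{|x_i - r_i|}{|r_i| + \varepsilon},
\]

where \(\varepsilon = 10^{-200}\) only guards against division by zero. A minimal caller sketch (the wrapper name and the std::vector choice are illustrative assumptions, not part of this commit):

  #include <vector>

  // Any types with operator[] work for data/reference; float values
  // promote to double inside print_relative_error.
  double check_vectors(std::vector<float> const& computed,
                       std::vector<float> const& reference) {
    return print_relative_error(computed.size(), computed, reference,
                                /*print_verbose=*/false, /*print_error=*/true);
  }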
146
tools/util/include/cutlass/util/reference/device/gett.hpp
Normal file
@@ -0,0 +1,146 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief GETT device reference code
*/
#pragma once

#include <cute/tensor.hpp>

namespace cutlass::reference::device {

template <
  class ATensor,
  class BTensor,
  class CTensor,
  class DTensor,
  class ElementAccumulator,
  class ElementEpilogue>
__global__ static
void
gett_kernel(
    DTensor D,
    ATensor const A,
    BTensor const B,
    CTensor const C,
    ElementEpilogue alpha, ElementEpilogue beta,
    ElementAccumulator acc_init)
{
  using namespace cute;

  static_assert(DTensor::rank == 3, "(M,N,L)");
  static_assert(ATensor::rank == 3, "(M,K,L)");
  static_assert(BTensor::rank == 3, "(N,K,L)");
  static_assert(CTensor::rank == 3, "(M,N,L)");

  assert(size<0>(A) == size<0>(D));  // M
  assert(size<0>(C) == size<0>(D));  // M
  assert(size<0>(B) == size<1>(D));  // N
  assert(size<1>(C) == size<1>(D));  // N
  assert(size<1>(A) == size<1>(B));  // K
  assert(size<2>(A) == size<2>(D));  // L
  assert(size<2>(B) == size<2>(D));  // L
  assert(size<2>(C) == size<2>(D));  // L

  NumericConverter<ElementAccumulator, typename ATensor::value_type> a_converter;
  NumericConverter<ElementAccumulator, typename BTensor::value_type> b_converter;
  NumericConverter<ElementEpilogue, ElementAccumulator> acc_converter;
  NumericConverter<ElementEpilogue, typename CTensor::value_type> source_converter;
  NumericConverter<typename DTensor::value_type, ElementEpilogue> output_converter;

  // Thread id to each element of D
  for (int tid = threadIdx.x + blockDim.x * blockIdx.x;
       tid < size(D);
       tid += blockDim.x * gridDim.x) {
    // (m,n,l) coordinate
    auto mnl_coord = idx2crd(tid, product_each(shape(D)));
    auto m = get<0>(mnl_coord);
    auto n = get<1>(mnl_coord);
    auto l = get<2>(mnl_coord);

    auto A_ml = A(m,_,l);
    auto B_nl = B(n,_,l);

    ElementAccumulator accum = ElementAccumulator(0);
    for (int k = 0; k < size<1>(A); ++k) {
      ElementAccumulator a = a_converter(A_ml(k));
      ElementAccumulator b = b_converter(B_nl(k));
      accum += a * b;
    }

    ElementEpilogue scaled_output = (alpha * acc_converter(accum)) + (beta * source_converter(C(m,n,l)));
    D(m,n,l) = output_converter(scaled_output);
  }
}
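Concretely, each thread of gett_kernel owns one (m,n,l) coordinate of D and evaluates the batched contraction

\[
D_{m,n,l} \;=\; \alpha \sum_{k} A_{m,k,l} \, B_{n,k,l} \;+\; \beta \, C_{m,n,l}.
\]

Because m, n, k, and l are CuTe modes, each may itself be hierarchical (multi-mode), which is how this single rank-3 reference covers both plain batched GEMM and general GETT problem shapes.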

// Most general version
template <
  class ProblemShapeMNKL,
  class ElementA,
  class StrideA,
  class ElementB,
  class StrideB,
  class ElementAccumulator,
  class ElementC,
  class StrideC,
  class ElementD,
  class StrideD,
  class ElementEpilogue>
void
gett(
    ProblemShapeMNKL problem_shape_mnkl,
    ElementA const* ptr_A, StrideA stride_a_mkl,
    ElementB const* ptr_B, StrideB stride_b_nkl,
    ElementAccumulator _,
    ElementC const* ptr_C, StrideC stride_c_mnl,
    ElementD * ptr_D, StrideD stride_d_mnl,
    ElementEpilogue alpha, ElementEpilogue beta,
    cudaStream_t stream = 0) {
  using namespace cute;

  static_assert(rank(ProblemShapeMNKL{}) == 4);
  auto M = get<0>(problem_shape_mnkl);
  auto N = get<1>(problem_shape_mnkl);
  auto K = get<2>(problem_shape_mnkl);
  auto L = get<3>(problem_shape_mnkl);

  // Represent the full tensors
  auto A = make_tensor(make_gmem_ptr(ptr_A), make_shape(M,K,L), stride_a_mkl); // (M,K,L)
  auto B = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), stride_b_nkl); // (N,K,L)
  auto C = make_tensor(make_gmem_ptr(ptr_C), make_shape(M,N,L), stride_c_mnl); // (M,N,L)
  auto D = make_tensor(make_gmem_ptr(ptr_D), make_shape(M,N,L), stride_d_mnl); // (M,N,L)

  dim3 dimBlock(256);
  dim3 dimGrid(240);
  gett_kernel<<< dimGrid, dimBlock, 0, stream >>>(D, A, B, C, alpha, beta, ElementAccumulator(0));
}
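A minimal host-side launch sketch follows; the element types and the fully packed, M-major strides below are illustrative assumptions, and all four pointers must be device allocations:

  #include <cute/tensor.hpp>
  #include <cutlass/util/reference/device/gett.hpp>

  void reference_gett_f32(float const* A, float const* B, float const* C, float* D,
                          int M, int N, int K, int L, cudaStream_t stream = 0) {
    using namespace cute;
    auto shape_mnkl = make_shape(M, N, K, L);          // rank-4, as the static_assert requires
    auto stride_A   = make_stride(Int<1>{}, M, M * K); // A is (M,K,L)
    auto stride_B   = make_stride(Int<1>{}, N, N * K); // B is (N,K,L)
    auto stride_C   = make_stride(Int<1>{}, M, M * N); // C and D are (M,N,L)
    cutlass::reference::device::gett(
        shape_mnkl,
        A, stride_A,
        B, stride_B,
        float{},      // ElementAccumulator: value unused, the type selects f32 accumulation
        C, stride_C,
        D, stride_C,  // D reuses C's strides in this sketch
        1.0f, 0.0f,   // alpha, beta
        stream);
  }

The fixed <<<240, 256>>> launch inside gett relies on the kernel's grid-stride loop, so any problem size is covered regardless of grid dimensions.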

} // namespace cutlass::reference::device