diff --git a/examples/51_hopper_gett/51_hopper_gett.cu b/examples/51_hopper_gett/51_hopper_gett.cu new file mode 100644 index 00000000..e9969505 --- /dev/null +++ b/examples/51_hopper_gett/51_hopper_gett.cu @@ -0,0 +1,371 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Example of a GETT targeting Hopper tensor cores using the CUTLASS 3.x API. + + CUTLASS has long provided implementations of Generalized Matrix times Matrix (GEMM) kernels. + However, a plethora of workloads compute on higher ranked tensors. Products of such tensors, + called tensor contractions, can be executed as multiple batched GEMMs, however, they can be + further accelerated with kernels that natively operate on these higher ranked tensors to + perform Generalized Tensor times Tensor contractions (GETT). CuTe's hierarchical layouts + and CUTLASS 3.0's unified micro-kernels make implementation of GETTs trivial. In this example, + we show how CUTLASS 3.0, CuTe, and Hopper's TMA feature together can accelerate GETTs while + making the process of authoring custom GETT kernels easier than ever before. + + The modes of a tensor that participate in a GETT can be fundamentally grouped into four + semantic categories. The contraction modes (or K-modes) only appear in the A and B (left and right) + inputs but not in the C output tensor. Row modes (or M-modes) only appear in the left + input tensor (A) and the output tensor (C). Column modes (or N-modes) only appear in the + right (B) input tensor and the output tensor (C). Batch modes (or L-modes) appear in all + input and output tensors. 
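As a concrete (hypothetical) instance of this grouping, consider the contraction

    D_{m,a,n,l} = \sum_{k0,k1} A_{m,a,k0,k1,l} * B_{k0,k1,n,l}

where {m,a} are row (M) modes, {n} is a column (N) mode, {k0,k1} are contraction (K) modes, and {l} is a batch (L) mode.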
If we fold the many modes of a tensor contraction into these four + categories, we can represent the input and output tensors as rank-3 "matrices" + that can be computed upon as if we were computing a batched GEMM! + + This is exactly what CuTe's hierarchical layout representation allows us to do! Instead of having + simple integers as strides for these four modes, we can have nested strides for each of these + semantic categories that themselves have multiple modes within them -- multi-mode strides! + In CUTLASS 3.0, all one has to do to take advantage of this capability is to substitute the + required multi-mode strides instead of the default ones provided by gemm::detail::TagToStrideX. + + In the following example, we illustrate how every Hopper GEMM in CUTLASS 3.0 is a GETT in disguise. + We begin by defining the four modes detailed above as Row, Col (column), Red (reduction), and + Bat (batch) strides, which we then nest for each of the in/out tensors to create our rank-3 stride + tuples. Note that although we do not define the problem shape type explicitly, it too remains a + rank-4 shape tuple just like any other batched GEMM, but instead with multi-mode shapes for each + of the four corresponding multi-modes within it. After this, the same CollectiveMma and + CollectiveBuilder we describe in examples 50 and 49 are used to create our kernel type. Nothing + else changes from a user's point of view. Note that multi-mode strides do not affect our + specializations in any way -- the lexical spelling of our kernels remains the same. The + only difference between a CUTLASS 3 batched GEMM and a GETT is the instantiated CuTe Layouts. + + CollectiveBuilders rely on detecting the static-1 in the stride tuples to determine the major mode, + which is what the example demonstrates. However, it is possible to have all modes be dynamic as well + if the user assembles a CollectiveMma manually and ensures that the runtime strides are compatible + with the static micro-kernel of the collective (TiledMma, TiledCopy, and smem layouts). On the other + hand, a user can have more than one static stride too (which need not correspond to the major mode). + + In particular, this example demonstrates a GETT where the 0th M-mode (M0) in A and the 0th K-mode (K0) + in B are major. All other combinations of major modes are supported, with the exception of mixed + K-major scenarios where A and B are K-major along different K-modes (e.g. K0 is major in A but K1 is major in B). + NVIDIA Hopper architecture's TMA feature makes the predication required to implement these complicated + kernels trivial, as it is all handled by TMA itself without requiring any programmer effort.
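To make the nesting concrete, here is a small, self-contained CuTe sketch (the extents and strides are invented purely for illustration and are not the ones used by this example) showing how a multi-mode A tensor folds into a single rank-3 layout that indexes like a batched-GEMM operand:

    #include <cute/tensor.hpp>

    // Illustrative only: A viewed as ((M0,M1),(K0,K1),L) with nested multi-mode strides.
    inline void folding_sketch() {
      using namespace cute;
      auto layout_A = make_layout(
          make_shape (make_shape ( 32,   4), make_shape (  8,    8),    3),   // ((M0,M1),(K0,K1), L)
          make_stride(make_stride(  1,  32), make_stride(128, 1024), 8192));  // matching nested strides
      // A hierarchical coordinate maps to a linear offset just like a rank-3 (m,k,l) GEMM coordinate:
      auto offset = layout_A(make_coord(make_coord(3, 1), make_coord(2, 5), 0));
      (void) offset;  // 3*1 + 1*32 + 2*128 + 5*1024 + 0*8192 = 5411
    }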
+ + Example executions, where the stride order defines the major-order (major on the left): + 51_hopper_gett --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096 + 51_hopper_gett --modeC=l,m,n --modeA=m,l,k --modeB=k,n,l --extents=m:128,n:128,k:128,l:64 + 51_hopper_gett --modeC=m,a,b,p,q,n,l --modeA=m,l,b,k,a --modeB=k,n,p,q,l --extents=m:32,a:32,b:3,n:128,k:128,l:4,p:3,q:3 +*/ + +#include "gett_kernel.cuh" +#include "thrust/host_vector.h" +#include "thrust/device_vector.h" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +#include "cutlass/util/gett_commandline.hpp" +#include "cutlass/util/reference/device/gett.hpp" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/print_error.hpp" + +namespace example { + +// Returns true if the left-most value in the tuple is statically known to be 1 +template +constexpr bool +is_left_major() { + // Account for stride types with and without batch mode and batch modes with static zero stride + return cute::is_constant<1, decltype(cute::size<0,0>(Stride{}))>::value; +} + +// Same as cute::make_int_tuple but inserts a major stride (Int<1>) for the leftmost mode if required +template +static constexpr +auto +make_stride_tuple(Indexable const& t, int n, int64_t init_default = 0) { + static_assert(Rank > 1); + if constexpr (IsMajor) { + return cute::transform(cute::make_seq{}, [&](auto i) { + if constexpr (i == 0) { + return cute::Int<1>{}; + } + else { + return i < n ? t[i] : init_default; + } + }); + } + else { + return cute::make_int_tuple(t, n, init_default); + } +} + +} // namespace example + +////////////////////////////////////////////////////////////////////////////// + +int +main(int argc, char const* argv[]) { +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + using namespace cute; + + if (argc != 5) { + std::cout << "Number of command line args must be 4.\n"; + cutlass::GettCommandLine::print_usage(); + return 0; + } + + // + // Define the stride types for A, B, C, and D + // + + // Stride for A (left input). If reduction mode is major, same must be major in B + // For this example, M0 is major in A. + using RowModeStridesA = cute::Stride, int64_t, int64_t, int64_t>; + using RedModeStridesA = cute::Stride; + using BatModeStridesA = cute::Stride; + + // Stride for B (right input). If reduction mode is major, same must be major in A + // For this example, K0 is major in B. + using ColModeStridesB = cute::Stride; + using RedModeStridesB = cute::Stride, int64_t, int64_t>; + using BatModeStridesB = cute::Stride; + + // Strides for output, which can all be dynamic. + using RowModeStridesC = cute::Stride; + using ColModeStridesC = cute::Stride; + using BatModeStridesC = cute::Stride; + + // Assmble our rank-3 multi-mode strides for the in/out tensors + using StrideA = cute::Stride; + using StrideB = cute::Stride; + using StrideC = cute::Stride; + + // Note: C and D share strides here for simplicity. + // In general, they need not have the same layout. 
+ using StrideD = StrideC; + + // + // Define element types for tensors and intermediate values + // + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementD = float; + using ElementAccumulator = float; + using ElementEpilogue = float; + + // The following constexpr values set the max number of modes in each MNKL mode + constexpr int MaxRank_M = rank(RowModeStridesA{}); // Max row modes + constexpr int MaxRank_N = rank(ColModeStridesB{}); // Max column modes + constexpr int MaxRank_K = rank(RedModeStridesA{}); // Max contraction modes + constexpr int MaxRank_L = rank(BatModeStridesA{}); // Max batch modes + static_assert(rank(RowModeStridesA{}) == rank(RowModeStridesC{})); + static_assert(rank(ColModeStridesB{}) == rank(RowModeStridesC{})); + static_assert(rank(RedModeStridesA{}) == rank(RedModeStridesB{})); + static_assert(rank(BatModeStridesA{}) == rank(BatModeStridesC{})); + static_assert(rank(BatModeStridesB{}) == rank(BatModeStridesC{})); + + // Parse command line to get modes, extents, and strides + cutlass::GettCommandLine cmd; + auto parsed_args = cmd.parse(argc, argv, true); + + auto& m = parsed_args.M; + auto& ldAm = parsed_args.ldAm; + auto& ldCm = parsed_args.ldCm; + int rank_m = int(m.size()); + + auto& n = parsed_args.N; + auto& ldBn = parsed_args.ldBn; + auto& ldCn = parsed_args.ldCn; + int rank_n = int(n.size()); + + auto& k = parsed_args.K; + auto& ldAk = parsed_args.ldAk; + auto& ldBk = parsed_args.ldBk; + int rank_k = int(k.size()); + + auto& l = parsed_args.L; + auto& ldAl = parsed_args.ldAl; + auto& ldBl = parsed_args.ldBl; + auto& ldCl = parsed_args.ldCl; + int rank_l = int(l.size()); + + if ((rank_m > MaxRank_M) || (rank_n > MaxRank_N) || (rank_k > MaxRank_K) || (rank_l > MaxRank_L)) { + std::cerr << "ERROR: Input has more modes than statically configured."; + return 1; + } + + // Check that the user input major stride match the static major strides. 
+ if (example::is_left_major() && (ldAm[0] != 1)) { + std::cerr << "ERROR: A_M0 is expected to be major, but was not in the provided input!\n"; + return 1; + } + + if (example::is_left_major() && (ldAk[0] != 1)) { + std::cerr << "ERROR: A_K0 is expected to be major, but was not in the provided input!\n"; + return 1; + } + + if (example::is_left_major() && (ldBn[0] != 1)) { + std::cerr << "ERROR: B_N0 is expected to be major, but was not in the provided input!\n"; + return 1; + } + + if (example::is_left_major() && (ldBk[0] != 1)) { + std::cerr << "ERROR: B_K0 is expected to be major, but was not in the provided input!\n"; + return 1; + } + + // Convert to `cute::Tuple`s and set up arguments + auto M = make_int_tuple(m.data(), rank_m, 1); + auto dAm = example::make_stride_tuple()>(ldAm.data(), rank_m); + auto dCm = example::make_stride_tuple()>(ldCm.data(), rank_m); + + auto N = make_int_tuple(n.data(), rank_n, 1); + auto dBn = example::make_stride_tuple()>(ldBn.data(), rank_n); + auto dCn = example::make_stride_tuple()>(ldCn.data(), rank_n); + + auto K = make_int_tuple(k.data(), rank_k, 1); + auto dAk = example::make_stride_tuple()>(ldAk.data(), rank_k); + auto dBk = example::make_stride_tuple()>(ldBk.data(), rank_k); + + auto L = make_int_tuple(l.data(), rank_l, 1); + auto dAl = make_int_tuple(ldAl.data(), rank_l, 0); + auto dBl = make_int_tuple(ldBl.data(), rank_l, 0); + auto dCl = make_int_tuple(ldCl.data(), rank_l, 0); + + // Concat tuples to turn it into rank-4 problem shape and rank-3 strides, just like GEMM + auto problem_shape = make_shape(M, N, K, L); + StrideA stride_A = make_stride(dAm, dAk, dAl); + StrideB stride_B = make_stride(dBn, dBk, dBl); + StrideC stride_C = make_stride(dCm, dCn, dCl); + StrideD stride_D = stride_C; + + auto alpha = ElementEpilogue(1.0f); + auto beta = ElementEpilogue(1.0f); + + // + // Allocate and init tensors + // + auto M_size = std::accumulate(std::begin(m), std::end(m), 1, std::multiplies<>{}); + auto N_size = std::accumulate(std::begin(n), std::end(n), 1, std::multiplies<>{}); + auto K_size = std::accumulate(std::begin(k), std::end(k), 1, std::multiplies<>{}); + auto L_size = std::accumulate(std::begin(l), std::end(l), 1, std::multiplies<>{}); + + thrust::host_vector h_A(M_size * K_size * L_size); + thrust::host_vector h_B(N_size * K_size * L_size); + thrust::host_vector h_C(M_size * N_size * L_size); + thrust::host_vector h_D(M_size * N_size * L_size); + + // Note: the cast to int here is to avoid false-negative ref-checks which can + // occur due to floating point arithmetic not being purely associative. 
+ for (auto& a : h_A) a = ElementA(int(4*(rand() / double(RAND_MAX)) - 1)); + for (auto& b : h_B) b = ElementB(int(4*(rand() / double(RAND_MAX)) - 1)); + for (auto& c : h_C) c = ElementC(int(4*(rand() / double(RAND_MAX)) - 1)); + for (auto& d : h_D) d = ElementD(-1); + + thrust::device_vector d_A = h_A; + thrust::device_vector d_B = h_B; + thrust::device_vector d_C = h_C; + thrust::device_vector cutlass_result = h_D; + thrust::device_vector reference_result = h_D; + + // + // Compute GETT + // + auto status = example::gett_kernel( + problem_shape, + d_A.data().get(), stride_A, + d_B.data().get(), stride_B, + ElementAccumulator{}, + d_C.data().get(), stride_C, + cutlass_result.data().get(), stride_D, + alpha, beta); + + if (cutlass::Status::kSuccess != status) { + std::cerr << "ERROR: GETT operator launch failed.\n"; + return 1; + } + + auto cuda_err = cudaDeviceSynchronize(); + if (cudaSuccess != cuda_err) { + std::cerr << "ERROR: GETT operator execution failed. with error :"; + std::cerr << cudaGetErrorString(cuda_err) << "\n"; + return 1; + } + + // + // Verify + // + + cutlass::reference::device::gett( + problem_shape, + d_A.data().get(), stride_A, + d_B.data().get(), stride_B, + ElementAccumulator{}, + d_C.data().get(), stride_C, + reference_result.data().get(), stride_D, + alpha, beta); + + cuda_err = cudaDeviceSynchronize(); + if (cudaSuccess != cuda_err) { + std::cerr << "ERROR: GETT reference execution failed. with error :"; + std::cerr << cudaGetErrorString(cuda_err) << "\n"; + return 1; + } + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + reference_result.data().get(), cutlass_result.data().get(), cutlass_result.size()); + if (passed) { + std::cout << "GETT verification passed.\n"; + return 0; + } + else { + std::cerr << "ERROR: GETT verification failed! Printing detailed stats.\n"; + h_D = reference_result; + thrust::host_vector h_cutlass_result = cutlass_result; + print_relative_error(h_cutlass_result.size(), h_cutlass_result.data(), h_D.data()); + + std::cout << "StrideA: "; print(stride_A); std::cout << '\n'; + std::cout << "StrideB: "; print(stride_B); std::cout << '\n'; + std::cout << "StrideC: "; print(stride_C); std::cout << '\n'; + std::cout << "StrideD: "; print(stride_D); std::cout << '\n'; + return 1; + } +#else + std::cerr << "Unsupported example. Please ensure CUTLASS_ARCH_MMA_SM90_SUPPORTED is defined.\n"; + return 0; +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +} diff --git a/examples/51_hopper_gett/CMakeLists.txt b/examples/51_hopper_gett/CMakeLists.txt new file mode 100644 index 00000000..d85bfbd2 --- /dev/null +++ b/examples/51_hopper_gett/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 51_hopper_gett + 51_hopper_gett.cu +) diff --git a/examples/51_hopper_gett/gett_kernel.cuh b/examples/51_hopper_gett/gett_kernel.cuh new file mode 100644 index 00000000..aa6b8357 --- /dev/null +++ b/examples/51_hopper_gett/gett_kernel.cuh @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +namespace example { + +// +// GETT entry point +// +template < + class ProblemShapeMNKL, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class ElementAccumulator, + class ElementC, + class StrideC, + class ElementD, + class StrideD, + class ElementEpilogue> +cutlass::Status +gett_kernel( + ProblemShapeMNKL problem_shape_mnkl, + ElementA const* ptr_A, StrideA stride_a_mkl, + ElementB const* ptr_B, StrideB stride_b_nkl, + ElementAccumulator _, + ElementC const* ptr_C, StrideC stride_c_mnl, + ElementD * ptr_D, StrideD stride_d_mnl, + ElementEpilogue alpha, ElementEpilogue beta, + cudaStream_t stream = 0) { + using namespace cute; + + // TileShape -- GETT configuration + // Specify the number of elements to take from each mode + // BLK_M = (M0,M1,...) BLK_N = (M0,M1,...) BLK_K = (K0,K1,...) + + // Take 128 from m0, 128 from n0, 64 from k0 + using TileShape = Shape, Shape<_128>, Shape<_64>>; + + /* Other examples: + * Take 32 elements from m0 and 4 elements from m1 + * Take 64 elements from n0 and 2 elements from n1 + * Take 8 elements from k0 and 8 elements from k1 + **/ + // using TileShape = Shape, Shape<_64,_2>, Shape<_8,_8>>; + + using EpilogueThreadOp = cutlass::epilogue::thread::LinearCombination< + ElementD, 1, ElementAccumulator, ElementEpilogue, cutlass::epilogue::thread::ScaleType::Default, + cutlass::FloatRoundStyle::round_to_nearest, ElementC>; + + // No changes are required to the default epilogue + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + StrideC, + StrideD, + EpilogueThreadOp>; + + // CollectiveMma for GETTs can be built using the CollectiveBuilders + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, StrideA, 128 / cutlass::sizeof_bits::value, + ElementB, StrideB, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + TileShape, Shape<_1,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + // The GETT kernel is a composition of a collective mainloop and epilogue, just like any 3.x GEMM + using GettKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShapeMNKL, + CollectiveMainloop, + CollectiveEpilogue>; + + using GettOperator = cutlass::gemm::device::GemmUniversalAdapter; + + typename GettOperator::Arguments args { + cutlass::gemm::GemmUniversalMode::kBatched, + problem_shape_mnkl, + ptr_A, stride_a_mkl, + ptr_B, stride_b_nkl, + { ptr_C, stride_c_mnl, ptr_D, stride_d_mnl, {alpha, beta} } + }; + +#if CUTLASS_DEBUG_TRACE_LEVEL > 0 + print("Problem shape:"); + print("\tM: "); print(cute::get<0>(problem_shape_mnkl)); print("\n"); + print("\tN: "); print(cute::get<1>(problem_shape_mnkl)); print("\n"); + print("\tK: "); print(cute::get<2>(problem_shape_mnkl)); print("\n"); + print("\tL: "); print(cute::get<3>(problem_shape_mnkl)); print("\n"); + print("TileSape:"); print(TileShape{}); print("\n"); +#endif + + GettOperator op; + return op(args, stream); +} + +} // namespace example diff 
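One note on the TileShape used above: however the CTA tile is split across multi-modes, each mode group still has to flatten to the overall tile extents the builder sizes the TiledMma and shared-memory layouts for (128x128x64 in both configurations shown). An illustrative compile-time check for the commented-out alternative:

    // Illustrative only: the alternative multi-mode tile flattens to the same 128x128x64 CTA tile.
    using TileShapeAlt = cute::Shape<cute::Shape<cute::_32, cute::_4>,
                                     cute::Shape<cute::_64, cute::_2>,
                                     cute::Shape<cute::_8,  cute::_8>>;
    static_assert(cute::size<0>(TileShapeAlt{}) == 128, "M modes must flatten to the CTA tile M");
    static_assert(cute::size<1>(TileShapeAlt{}) == 128, "N modes must flatten to the CTA tile N");
    static_assert(cute::size<2>(TileShapeAlt{}) ==  64, "K modes must flatten to the CTA tile K");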
--git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index a063bd81..fe884a5b 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -131,6 +131,7 @@ foreach(EXAMPLE 48_hopper_warp_specialized_gemm 49_hopper_gemm_schedules_with_collective_builder 50_hopper_gemm_with_epilogue_swizzle + 51_hopper_gett ) add_subdirectory(${EXAMPLE}) diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index f414ebc4..4e98ea32 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -762,7 +762,17 @@ make_tma_copy(CopyOp, print("layout_tv : "); print(layout_tv); print("\n"); #endif - return TiledCopy, decltype(layout_tv), decltype(cta_tile)>{tma_desc, gmem_stride_bases}; + // If CTA_Tile and SLayout are incompatible, product_each makes sure + // that the TiledCopy generates consistent accesses. + auto cta_tile_tiled = [&]() { + if constexpr (compatible(shape(CTA_Tile{}), shape(SLayout{}))) { + return cta_tile; + } else { + return product_each(cta_tile); + } + }(); + + return TiledCopy, decltype(layout_tv), decltype(cta_tile_tiled)>{tma_desc, gmem_stride_bases}; } // Explicit defaulting diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index c1444a98..580d5ca2 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -61,7 +61,7 @@ template constexpr cute::GMMA::Major tag_to_gmma_major_A() { // MN major mode is only valid for non-TF32 and non-int MMAs - if constexpr (std::is_same_v && + if constexpr (cutlass::gemm::detail::is_mn_major_A() && not std::is_same_v && not std::is_same_v && not std::is_same_v) { @@ -77,7 +77,7 @@ template constexpr cute::GMMA::Major tag_to_gmma_major_B() { // MN major mode is only valid for non-TF32 and non-int MMAs - if constexpr (std::is_same_v && + if constexpr (cutlass::gemm::detail::is_mn_major_B() && not std::is_same_v && not std::is_same_v && not std::is_same_v) { @@ -113,7 +113,7 @@ make_cp_async_gmem_tiled_copy() { // Maximize the number of threads along the gmem major mode to promote coalesced reads // While making sure our thread layout tiles the threadblock tile evenly - if constexpr (cute::size<1>(StrideType{}) == 1) { + if constexpr (cutlass::gemm::detail::is_k_major()) { // K major thread layout for K major gmem constexpr int threads_major = TileSizeK / Alignment; constexpr int threads_minor = ThreadCount / threads_major; @@ -126,7 +126,7 @@ make_cp_async_gmem_tiled_copy() { Stride, _1>>{}, Layout>>{}); } - else if constexpr (cute::size<0>(StrideType{}) == 1) { + else if constexpr (cutlass::gemm::detail::is_mn_major()) { // MN major thread layout for MN major gmem constexpr int threads_major = TileSizeMN / Alignment; constexpr int threads_minor = ThreadCount / threads_major; @@ -257,7 +257,8 @@ struct CollectiveBuilder< not std::is_same_v && // dispatch TN tf32 and int8 kernels only to TMA builder ((sizeof(ElementA) == 2 && sizeof(ElementB) == 2) || - (std::is_same_v && std::is_same_v))> + (cutlass::gemm::detail::is_k_major_A() && + cutlass::gemm::detail::is_k_major_B()))> > { static_assert(is_static::value); static_assert(is_static::value); @@ -346,7 +347,8 @@ struct CollectiveBuilder< ((sizeof(ElementB) * AlignmentB) % detail::tma_alignment_bytes != 0) || // dispatch non-TN tf32 and int8 kernels only to cp_async builder ((sizeof(ElementA) != 2 || sizeof(ElementB) != 2) && 
- (not std::is_same_v || not std::is_same_v))> + (not cutlass::gemm::detail::is_k_major_A() || + not cutlass::gemm::detail::is_k_major_B()))> > { static_assert(is_static::value); static_assert(is_static::value); diff --git a/include/cutlass/gemm/gemm.h b/include/cutlass/gemm/gemm.h index 4b76101b..0b1fd25b 100644 --- a/include/cutlass/gemm/gemm.h +++ b/include/cutlass/gemm/gemm.h @@ -37,8 +37,7 @@ #include "cutlass/coord.h" #include "cutlass/layout/matrix.h" #include "cute/layout.hpp" -#include "cute/arch/copy_sm90.hpp" - +#include "cute/arch/copy_sm90_tma.hpp" namespace cutlass { namespace gemm { @@ -426,7 +425,9 @@ enum class SharedMemoryClearOption { // For each cutlass::layout, provides its corresponding cute stride types, 64b by default template -struct TagToStrideA {}; +struct TagToStrideA { + using type = L; +}; // Maps to modes [M, K, L] template <> @@ -443,7 +444,9 @@ struct TagToStrideA { }; template -struct TagToStrideB {}; +struct TagToStrideB { + using type = L; +}; // Maps to modes [N, K, L] template <> @@ -479,13 +482,19 @@ using TagToStrideC_t = typename TagToStrideC::type; namespace detail { +template +constexpr bool +is_mn_major() { + // Account for stride types with and without batch mode and batch modes with static zero stride + return cute::is_constant<1, decltype(cute::size<0,0>(Stride{}))>::value; +} + // Note : This method can be used for deducing the Layout Tag of A, C, D Matrices template constexpr auto stride_to_layout_tag_A() { - // Account for stride types with and without batch mode and batch modes with static zero stride - if constexpr (cute::size<0>(StrideAC{}) == 1) { // M major + if constexpr (is_mn_major()) { // M major return layout::ColumnMajor{}; } else { // K major @@ -499,8 +508,7 @@ template constexpr auto stride_to_layout_tag_B() { - // Account for stride types with and without batch mode and batch modes with static zero stride - if constexpr (cute::size<0>(StrideB{}) == 1) { // N major + if constexpr (is_mn_major()) { // N major return layout::RowMajor{}; } else { // K major @@ -515,12 +523,12 @@ template constexpr int get_alignment_count_from_gmem_tiled_copy() { // For TMA tiled copies, we know the alignment has to be 128 bits - if constexpr (std::is_base_of_v || - std::is_base_of_v) { + if constexpr ( std::is_base_of_v + || std::is_base_of_v + ) { return 128 / sizeof_bits::value; } - else - { + else { // For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN return GmemTiledCopy::NumValSrc; } @@ -551,6 +559,37 @@ using StrideToLayoutTagB_t = typename StrideToLayoutTagB::type; template using StrideToLayoutTagC_t = typename StrideToLayoutTagC::type; +template +constexpr +bool +is_k_major() { + return ! 
is_mn_major(); +} + +template +constexpr bool +is_mn_major_A() { + return is_mn_major>(); +} + +template +constexpr bool +is_mn_major_B() { + return is_mn_major>(); +} + +template +constexpr bool +is_k_major_A() { + return is_k_major>(); +} + +template +constexpr bool +is_k_major_B() { + return is_k_major>(); +} + /////////////////////////////////////////////////////////////////////////////// // The following two metafunctions are used to detect whether a `kernel::Gemm` or `kernel::GemmUniversal` diff --git a/tools/util/include/cutlass/util/gett_commandline.hpp b/tools/util/include/cutlass/util/gett_commandline.hpp new file mode 100644 index 00000000..e2a992f8 --- /dev/null +++ b/tools/util/include/cutlass/util/gett_commandline.hpp @@ -0,0 +1,369 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GETT command line parser to gather semantic modes, their stride order, and extents. 
+*/ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" + +namespace cutlass { + +// Output shortcuts +std::ostream& operator<<(std::ostream& os, std::vector data) { + for (auto& a : data) os << a; + return os; +} + +template +std::ostream& operator<<(std::ostream& os, std::vector data) { + for (auto& a : data) os << a << " "; + return os; +} + +struct GettCommandLine { + struct GettProblem { + using extent_type = int; + using stride_type = int64_t; + + // Row modes: appear in A and C/D + std::vector M; + std::vector ldAm; + std::vector ldCm; + + // Column modes: appear in B and C/D + std::vector N; + std::vector ldBn; + std::vector ldCn; + + // Reduction modes: appear in A and B + std::vector K; + std::vector ldAk; + std::vector ldBk; + + // Batch modes: appear in all in/out tensors + std::vector L; + std::vector ldAl; + std::vector ldBl; + std::vector ldCl; + }; + + static GettProblem + parse(int argc, char const* argv[], bool parse_verbose = false) { + using extent_type = typename GettProblem::extent_type; + using stride_type = typename GettProblem::stride_type; + + cutlass::CommandLine cmd(argc, argv); + + // modeA + std::vector a_mode; + cmd.get_cmd_line_arguments("modeA", a_mode); + + // modeB + std::vector b_mode; + cmd.get_cmd_line_arguments("modeB", b_mode); + + // modeC + std::vector c_mode; + cmd.get_cmd_line_arguments("modeC", c_mode); + + + // mode_sizes + std::map mode_size; + // First, initialize all modes in a, b, c to make sure they're in map + for (char a : a_mode) mode_size[a] = 1; + for (char b : b_mode) mode_size[b] = 1; + for (char c : c_mode) mode_size[c] = 1; + + // Then, overwrite the ones in -extent + std::vector > extent_tokens; + cmd.get_cmd_line_argument_pairs("extents", extent_tokens); + for (auto e : extent_tokens) { + if (std::get<0>(e).size() > 1) { + std::cerr << "ERROR: Mode name must only be 1 character long.\n"; + print_usage(); + exit(1); + } + char label = std::get<0>(e)[0]; + int size = std::stoi(std::get<1>(e)); + mode_size[label] = size; + } + + // Print out symbolic modes and their extents + if (parse_verbose) { + std::cout << "C_" << c_mode << " = A_" << a_mode << " * B_" << b_mode << "\n"; + for (auto e : mode_size) std::cout << " " << std::get<0>(e) << " : " << std::get<1>(e) << "\n"; + } + + // + // Collect/Compute strides + // + + std::map mode_ldA; + std::map mode_ldB; + std::map mode_ldC; + + { + stride_type current; + + current = 1; + for (char a : a_mode) { mode_ldA[a] = current; current *= mode_size[a]; } + + current = 1; + for (char b : b_mode) { mode_ldB[b] = current; current *= mode_size[b]; } + + current = 1; + for (char c : c_mode) { mode_ldC[c] = current; current *= mode_size[c]; } + } + + // + // Collect mode categories + // + + std::vector row_mode; // rows + std::vector col_mode; // columns + std::vector red_mode; // reductions + std::vector bat_mode; // batches + + { + std::vector a_label = a_mode; + std::vector b_label = b_mode; + std::vector c_label = c_mode; + + std::sort(std::begin(a_label), std::end(a_label)); + std::sort(std::begin(b_label), std::end(b_label)); + std::sort(std::begin(c_label), std::end(c_label)); + + // std::set_intersections to find semantic category of each symbolic mode + std::set_intersection(std::begin(a_label), std::end(a_label), + std::begin(c_label), std::end(c_label), + std::back_inserter(row_mode)); + + std::set_intersection(std::begin(b_label), std::end(b_label), + std::begin(c_label), std::end(c_label), + 
std::back_inserter(col_mode)); + + std::set_intersection(std::begin(a_label), std::end(a_label), + std::begin(b_label), std::end(b_label), + std::back_inserter(red_mode)); + + std::set_intersection(std::begin(row_mode), std::end(row_mode), + std::begin(col_mode), std::end(col_mode), + std::back_inserter(bat_mode)); + + // std::set_difference to remove batch modes from other semantic modes + for (char l : bat_mode) { + row_mode.erase(std::remove(std::begin(row_mode), std::end(row_mode), l), std::end(row_mode)); + col_mode.erase(std::remove(std::begin(col_mode), std::end(col_mode), l), std::end(col_mode)); + red_mode.erase(std::remove(std::begin(red_mode), std::end(red_mode), l), std::end(red_mode)); + } + } + + // Print out the semantic association of each symbolic mode + if (parse_verbose) { + std::cout << " rows : " << row_mode << '\n'; + std::cout << " cols : " << col_mode << '\n'; + std::cout << " reds : " << red_mode << '\n'; + std::cout << " bats : " << bat_mode << '\n'; + } + + // + // Permute modes + // + + // Permute the batched modes to promote coalescing + // Sort the batched modes by min(ldAl,ldBl) and tie-broken by the size + std::sort(std::begin(bat_mode), std::end(bat_mode), [&](char l1, char l2) { + return std::tie(std::min(mode_ldA[l1],mode_ldB[l1]),mode_size[l1]) + < std::tie(std::min(mode_ldA[l2],mode_ldB[l2]),mode_size[l2]); + }); + // Compute sizes and strides of ordered reduction modes + std::vector L; + std::vector ldAl; + std::vector ldBl; + std::vector ldCl; + for (char l : bat_mode) { + L.push_back(mode_size[l]); + ldAl.push_back(mode_ldA[l]); + ldBl.push_back(mode_ldB[l]); + ldCl.push_back(mode_ldC[l]); + } + + // Permute the reduction modes to promote coalescing + // Sort the reduction modes by min(ldAk,ldBk) and tie-broken by the size + std::sort(std::begin(red_mode), std::end(red_mode), [&](char k1, char k2) { + return std::tie(std::min(mode_ldA[k1],mode_ldB[k1]),mode_size[k1]) + < std::tie(std::min(mode_ldA[k2],mode_ldB[k2]),mode_size[k2]); + }); + // Compute sizes and strides of ordered reduction modes + std::vector K; + std::vector ldAk; + std::vector ldBk; + for (char k : red_mode) { + K.push_back(mode_size[k]); + ldAk.push_back(mode_ldA[k]); + ldBk.push_back(mode_ldB[k]); + } + + // Permute the row modes to promote coalescing + // Sort the row modes by min(ldAm,ldCm) and tie-broken by ldAm + std::sort(std::begin(row_mode), std::end(row_mode), [&](char m1, char m2) { + return std::tie(std::min(mode_ldA[m1],mode_ldC[m1]),mode_ldA[m1]) + < std::tie(std::min(mode_ldA[m2],mode_ldC[m2]),mode_ldA[m2]); + }); + // Compute sizes and strides of ordered row modes + std::vector M; + std::vector ldAm; + std::vector ldCm; + for (char m : row_mode) { + M.push_back(mode_size[m]); + ldAm.push_back(mode_ldA[m]); + ldCm.push_back(mode_ldC[m]); + } + + // Permute the col modes to promote coalescing + // Sort the col modes by min(ldBn,ldCn) and tie-broken by ldBn + std::sort(std::begin(col_mode), std::end(col_mode), [&](char n1, char n2) { + return std::tie(std::min(mode_ldB[n1],mode_ldC[n1]),mode_ldB[n1]) + < std::tie(std::min(mode_ldB[n2],mode_ldC[n2]),mode_ldB[n2]); + }); + // Compute sizes and strides of ordered col modes + std::vector N; + std::vector ldBn; + std::vector ldCn; + for (char n : col_mode) { + N.push_back(mode_size[n]); + ldBn.push_back(mode_ldB[n]); + ldCn.push_back(mode_ldC[n]); + } + + if (parse_verbose) { + std::cout << "C_"; + if (! row_mode.empty()) { + std::cout << "(" << row_mode << ")"; + } + if (! 
col_mode.empty()) { + std::cout << "(" << col_mode << ")"; + } + if (! bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << " = A_"; + if (! row_mode.empty()) { + std::cout << "(" << row_mode << ")"; + } + if (! red_mode.empty()) { + std::cout << "(" << red_mode << ")"; + } + if (! bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << " * B_"; + if (! col_mode.empty()) { + std::cout << "(" << col_mode << ")"; + } + if (! red_mode.empty()) { + std::cout << "(" << red_mode << ")"; + } + if (! bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << '\n'; + + int M_size = std::accumulate(std::begin(M), std::end(M), 1, std::multiplies<>{}); + int N_size = std::accumulate(std::begin(N), std::end(N), 1, std::multiplies<>{}); + int K_size = std::accumulate(std::begin(K), std::end(K), 1, std::multiplies<>{}); + int L_size = std::accumulate(std::begin(L), std::end(L), 1, std::multiplies<>{}); + + std::cout << " M : (" << M_size << ") "; + for (char m : row_mode) std::cout << m << ":" << mode_size[m] << " "; + std::cout << '\n'; + std::cout << " N : (" << N_size << ") "; + for (char n : col_mode) std::cout << n << ":" << mode_size[n] << " "; + std::cout << '\n'; + std::cout << " K : (" << K_size << ") "; + for (char k : red_mode) std::cout << k << ":" << mode_size[k] << " "; + std::cout << '\n'; + std::cout << " L : (" << L_size << ") "; + for (char l : bat_mode) std::cout << l << ":" << mode_size[l] << " "; + std::cout << '\n'; + + std::cout << " ldAm : " << ldAm << '\n'; + std::cout << " ldAk : " << ldAk << '\n'; + std::cout << " ldAl : " << ldAl << '\n'; + std::cout << " ldBn : " << ldBn << '\n'; + std::cout << " ldBk : " << ldBk << '\n'; + std::cout << " ldBl : " << ldBl << '\n'; + std::cout << " ldCm : " << ldCm << '\n'; + std::cout << " ldCn : " << ldCn << '\n'; + std::cout << " ldCl : " << ldCl << '\n'; + } + + return {M, ldAm, ldCm, + N, ldBn, ldCn, + K, ldAk, ldBk, + L, ldAl, ldBl, ldCl}; + } + + static void + print_usage() { + std::cout << + "GETT problem command line parser:\n" + " --modeA=\n" + " A comma delimited list of characters that correspond to the row, reduction, and batch modes in A tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --modeB=\n" + " A comma delimited list of characters that correspond to the column, reduction, and batch modes in B tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --modeC=\n" + " A comma delimited list of characters that correspond to the row, column, and batch modes in B tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --extents=\n" + " A command delimited list of symbolic mode and its corresponding extent.\n" + " Extents are defaulted to 1 if any are not provided.\n\n" + + "Example usage: gett.exe --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extent=m:4096,n:4096,k:4096\n"; + } +}; + +} // namespace cutlass diff --git a/tools/util/include/cutlass/util/print_error.hpp b/tools/util/include/cutlass/util/print_error.hpp index f867f88e..0da8ecb0 100644 --- a/tools/util/include/cutlass/util/print_error.hpp +++ b/tools/util/include/cutlass/util/print_error.hpp @@ -60,7 +60,7 @@ struct matrix_inf_norm_result { // and thus are best passed by reference or const reference. 
template matrix_inf_norm_result -matrix_inf_norm(const cute::Tensor& host_matrix) +matrix_inf_norm(cute::Tensor const& host_matrix) { using std::abs; using error_type = decltype(std::declval().inf_norm); @@ -68,17 +68,14 @@ matrix_inf_norm(const cute::Tensor& host_matrix) error_type inf_norm = 0.0; bool found_nan = false; - const auto shape = host_matrix.shape(); - using index_type = std::decay_t(shape))>; // Computing the infinity norm requires that we be able // to treat the input as a matrix, with rows and columns. - static_assert(std::is_integral_v); - const index_type num_rows = cute::get<0>(shape); - const index_type num_cols = cute::get<1>(shape); + const int64_t num_rows = cute::size<0>(host_matrix); + const int64_t num_cols = cute::size<1>(host_matrix); - for(index_type i = 0; i < num_rows; ++i) { + for(int64_t i = 0; i < num_rows; ++i) { error_type row_abs_sum = 0.0; - for(index_type j = 0; j < num_cols; ++j) { + for(int64_t j = 0; j < num_cols; ++j) { row_abs_sum += abs(host_matrix(i, j)); } if(std::isnan(row_abs_sum)) { @@ -94,39 +91,27 @@ matrix_inf_norm(const cute::Tensor& host_matrix) // Infinity norm of (X - Y). template matrix_inf_norm_result -matrix_diff_inf_norm(const cute::Tensor& X, - const cute::Tensor& Y) +matrix_diff_inf_norm(cute::Tensor const& X, + cute::Tensor const& Y) { using std::abs; using error_type = decltype(std::declval().inf_norm); - const auto X_shape = X.shape(); - const auto Y_shape = Y.shape(); + assert(cute::size<0>(X) == cute::size<0>(Y)); + assert(cute::size<1>(X) == cute::size<1>(Y)); - using index_type = std::decay_t(X_shape))>; // Computing the infinity norm requires that we be able // to treat the input as a matrix, with rows and columns. - static_assert(std::is_integral_v); - const index_type num_rows = cute::get<0>(X_shape); - const index_type num_cols = cute::get<1>(X_shape); - - assert(num_rows == cute::get<0>(Y_shape)); - assert(num_cols == cute::get<1>(Y_shape)); - - auto matrix_ij = [&](const auto& A, std::size_t i, std::size_t j) { - return A(i, j); - }; - auto diff_ij = [&](std::size_t i, std::size_t j) { - return matrix_ij(X, i, j) - matrix_ij(Y, i, j); - }; + const int64_t num_rows = cute::size<0>(X); + const int64_t num_cols = cute::size<1>(X); error_type inf_norm = 0.0; bool found_nan = false; - for(index_type i = 0; i < num_rows; ++i) { + for(int64_t i = 0; i < num_rows; ++i) { error_type row_abs_sum = 0.0; - for(index_type j = 0; j < num_cols; ++j) { - row_abs_sum += abs(diff_ij(i, j)); + for(int64_t j = 0; j < num_cols; ++j) { + row_abs_sum += abs(X(i,j) - Y(i,j)); } if(std::isnan(row_abs_sum)) { found_nan = true; @@ -140,22 +125,22 @@ matrix_diff_inf_norm(const cute::Tensor& X, template + typename EngineType_C, typename LayoutType_C, + typename EngineType_C_ref, typename LayoutType_C_ref> void print_matrix_multiply_mollified_relative_error( - const char A_value_type_name[], - const cute::Tensor& A, - const char B_value_type_name[], - const cute::Tensor& B, - const char C_value_type_name[], - const cute::Tensor& C_computed, - const cute::Tensor& C_expected) + char const A_value_type_name[], + cute::Tensor const& A, + char const B_value_type_name[], + cute::Tensor const& B, + char const C_value_type_name[], + cute::Tensor const& C, + cute::Tensor const& C_ref) { const auto [A_norm, A_has_nan] = matrix_inf_norm(A); const auto [B_norm, B_has_nan] = matrix_inf_norm(B); - const auto [C_norm, C_has_nan] = matrix_inf_norm(C_expected); - const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C_computed, C_expected); + const auto 
[C_norm, C_has_nan] = matrix_inf_norm(C_ref); + const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C, C_ref); const auto A_norm_times_B_norm = A_norm * B_norm; const auto relative_error = A_norm_times_B_norm == 0.0 ? @@ -164,18 +149,19 @@ print_matrix_multiply_mollified_relative_error( // For expected error bounds, please refer to the LAPACK Users' Guide, // in particular https://netlib.org/lapack/lug/node108.html . // Printing the infinity norm of C is a way to check - // that both the function being tested (C_computed) - // and the reference implementation (C_expected) + // that both the function being tested (C) + // and the reference implementation (C_ref) // don't just do nothing (or fill with zeros). using std::cout; - cout << "Value type of A: " << A_value_type_name << '\n' + using cute::shape; + cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n' + << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n' + << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n' << std::scientific << "Infinity norm of A: " << A_norm << '\n' - << "Value type of B: " << B_value_type_name << '\n' << "Infinity norm of B: " << B_norm << '\n' - << "Value type of C: " << C_value_type_name << '\n' - << "Infinity norm of C_expected: " << C_norm << '\n' - << "Infinity norm of (C_computed - C_expected): " << diff_norm << '\n'; + << "Infinity norm of C: " << C_norm << '\n' + << "Infinity norm of (C - C_ref): " << diff_norm << '\n'; if(A_norm_times_B_norm == 0.0) { cout << "Mollified relative error: " << relative_error << '\n'; @@ -183,11 +169,12 @@ print_matrix_multiply_mollified_relative_error( cout << "Relative error: " << relative_error << '\n'; } - cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n' - << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n' - << "Did we encounter NaN in C_expected? " << (C_has_nan ? "yes" : "no") << '\n' - << "Did we encounter NaN in (C_computed - C_expected)? " - << (diff_has_nan ? "yes" : "no") << '\n'; + if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) { + cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? 
"yes" : "no") << '\n'; + } } template @@ -233,3 +220,70 @@ auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X) auto X_data_const = const_cast >(X_data); return cute::make_tensor(X_data_const, layout); }; + + +template +double +print_relative_error( + std::size_t n, + T1 const& data, + T2 const& reference, + bool print_verbose = false, + bool print_error = true) { + using std::abs; using std::sqrt; + + // Use either double or complex for error computation + using value_type = cute::remove_cvref_t; + using error_type = std::conditional_t::value, + cute::complex, + double>; + + if (print_verbose) { + std::cout << "Idx:\t"<< "Val\t" << "RefVal\t" << "RelError" << std::endl; + } + + double eps = 1e-200; + + double tot_error_sq = 0; + double tot_norm_sq = 0; + double tot_ind_rel_err = 0; + double max_ind_rel_err = 0; + for (std::size_t i = 0; i < n; ++i) + { + error_type val = data[i]; + error_type ref = reference[i]; + + double aref = abs(ref); + double diff = abs(ref - val); + double rel_error = diff / (aref + eps); + + // Individual relative error + tot_ind_rel_err += rel_error; + + // Maximum relative error + max_ind_rel_err = std::max(max_ind_rel_err, rel_error); + + // Total relative error + tot_error_sq += diff * diff; + tot_norm_sq += aref * aref; + + if (print_verbose) { + std::cout << i << ":\t" << val << "\t" << ref << "\t" << rel_error << std::endl; + } + } + + printf("Vector reference norm: [%.5e]\n", sqrt(tot_norm_sq)); + + double tot_rel_err = sqrt(tot_error_sq/(tot_norm_sq+eps)); + if (print_error) + printf("Vector relative error: [%.5e]\n", tot_rel_err); + + double ave_rel_err = tot_ind_rel_err / double(n); + if (print_error) + printf("Average relative error: [%.5e]\n", ave_rel_err); + + if (print_error) + printf("Maximum relative error: [%.5e]\n", max_ind_rel_err); + + return tot_rel_err; +} diff --git a/tools/util/include/cutlass/util/reference/device/gett.hpp b/tools/util/include/cutlass/util/reference/device/gett.hpp new file mode 100644 index 00000000..84b7037e --- /dev/null +++ b/tools/util/include/cutlass/util/reference/device/gett.hpp @@ -0,0 +1,146 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GETT device reference code +*/ +#pragma once + +#include + +namespace cutlass::reference::device { + +template < + class ATensor, + class BTensor, + class CTensor, + class DTensor, + class ElementAccumulator, + class ElementEpilogue> +__global__ static +void +gett_kernel( + DTensor D, + ATensor const A, + BTensor const B, + CTensor const C, + ElementEpilogue alpha, ElementEpilogue beta, + ElementAccumulator acc_init) +{ + using namespace cute; + + static_assert(DTensor::rank == 3, "(M,N,L)"); + static_assert(ATensor::rank == 3, "(M,K,L)"); + static_assert(BTensor::rank == 3, "(N,K,L)"); + static_assert(CTensor::rank == 3, "(M,N,L)"); + + assert(size<0>(A) == size<0>(D)); // M + assert(size<0>(C) == size<0>(D)); // M + assert(size<0>(B) == size<1>(D)); // N + assert(size<1>(C) == size<1>(D)); // N + assert(size<1>(A) == size<1>(B)); // K + assert(size<2>(A) == size<2>(D)); // L + assert(size<2>(B) == size<2>(D)); // L + assert(size<2>(C) == size<2>(D)); // L + + NumericConverter a_converter; + NumericConverter b_converter; + NumericConverter acc_converter; + NumericConverter source_converter; + NumericConverter output_converter; + + // Thread id to each element of D + for (int tid = threadIdx.x + blockDim.x * blockIdx.x; + tid < size(D); + tid += blockDim.x * gridDim.x) { + // (m,n,l) coordinate + auto mnl_coord = idx2crd(tid, product_each(shape(D))); + auto m = get<0>(mnl_coord); + auto n = get<1>(mnl_coord); + auto l = get<2>(mnl_coord); + + auto A_ml = A(m,_,l); + auto B_nl = B(n,_,l); + + ElementAccumulator accum = ElementAccumulator(0); + for (int k = 0; k < size<1>(A); ++k) { + ElementAccumulator a = a_converter(A_ml(k)); + ElementAccumulator b = b_converter(B_nl(k)); + accum += a * b; + } + + ElementEpilogue scaled_output = (alpha * acc_converter(accum)) + (beta * source_converter(C(m,n,l))); + D(m,n,l) = output_converter(scaled_output); + } +} + +// Most general version +template < + class ProblemShapeMNKL, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class ElementAccumulator, + class ElementC, + class StrideC, + class ElementD, + class StrideD, + class ElementEpilogue> +void +gett( + ProblemShapeMNKL problem_shape_mnkl, + ElementA const* ptr_A, StrideA stride_a_mkl, + ElementB const* ptr_B, StrideB stride_b_nkl, + ElementAccumulator _, + ElementC const* ptr_C, StrideC stride_c_mnl, + ElementD * ptr_D, StrideD stride_d_mnl, + ElementEpilogue alpha, ElementEpilogue beta, + cudaStream_t stream = 0) { + using namespace cute; + + static_assert(rank(ProblemShapeMNKL{}) == 4); + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto K = get<2>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Represent the full tensors + auto A = make_tensor(make_gmem_ptr(ptr_A), make_shape(M,K,L), stride_a_mkl); // (M,K,L) + auto B = make_tensor(make_gmem_ptr(ptr_B), 
make_shape(N,K,L), stride_b_nkl); // (N,K,L) + auto C = make_tensor(make_gmem_ptr(ptr_C), make_shape(M,N,L), stride_c_mnl); // (M,N,L) + auto D = make_tensor(make_gmem_ptr(ptr_D), make_shape(M,N,L), stride_d_mnl); // (M,N,L) + + dim3 dimBlock(256); + dim3 dimGrid(240); + gett_kernel<<< dimGrid, dimBlock, 0, stream >>>(D, A, B, C, alpha, beta, ElementAccumulator(0)); +} + +} // namespace cutlass::reference::device
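For clarity, the reference kernel above assigns D elements to threads with a grid-stride loop and recovers the (m,n,l) coordinate from the flat index via cute::idx2crd over product_each(shape(D)); the scalar m, n, l then index the multi-mode tensors directly, because CuTe resolves a flat index within a nested mode colexicographically. A host-side sketch of that mapping (shape values invented for illustration):

    // Illustrative only: turn a flat index into an (m,n,l) coordinate the way the reference kernel does.
    inline void coord_sketch() {
      using namespace cute;
      auto shape_D = make_shape(make_shape(32, 4), make_shape(64, 2), 3);  // ((M0,M1),(N0,N1),L)
      auto flat    = product_each(shape_D);                                // (128, 128, 3)
      auto coord   = idx2crd(777, flat);                                   // (9, 6, 0): mode-0 varies fastest
      (void) coord;
    }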