/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Testbed for Ptr-Array and Grouped GEMM interface

    Host-side helpers that allocate and initialize per-group operands, build the
    kernel argument structs, and hold host reference tensors for verification.
*/

#pragma once

// NOTE(review): the five bare `#include` directives below have lost their header
// names in this copy of the file (text between '<' and '>' appears to have been
// stripped), so they will not preprocess as written. Restore the intended standard
// headers; code in this file visibly uses std::vector, std::ofstream, std::cout and
// std::numeric_limits.
#include
#include
#include
#include
#include
#include "../../common/cutlass_unit_test.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/fusion/operations.hpp"
#include "cutlass/complex.h"
#include "testbed_utils.h"

#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/gemm/gemm.h"

#include "cute/int_tuple.hpp"
#include "cute/layout.hpp"
#include "cute/numeric/int.hpp"

namespace test {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

// Whether scalar epilogue arguments (alpha, beta, scale factors) are kept on the
// host only (ON_HOST) or also copied to device memory before kernel launch (ON_DEVICE).
enum class ScalarLoc {
  ON_HOST = 0,
  ON_DEVICE = 1
};

// Only consulted when the fusion supports a per-row scale: DISABLED passes beta as
// a single scalar, ENABLED initializes a per-row (length-M) beta vector instead.
// NOTE(review): the epilogue helper constructors in this file take a VectorBeta
// argument but do not appear to store it in their member initializer lists, so the
// member keeps its DISABLED default — verify this is intended.
enum class VectorBeta {
  DISABLED = 0,
  ENABLED = 1
};

// How device results are compared against host references: EXACT uses
// TensorEquals; RELATIVE uses TensorRelativelyEquals for non-complex element
// types (complex element types are always compared exactly).
enum class CheckEquality {
  EXACT = 0,
  RELATIVE = 1
};

namespace detail{

// Helper classes that take default data type when
// the Gemm::EpilogueOutputOp does not have ElementCompute
// and ElementScalar.
// (e.g. when Sm90TreeVisitor is used as FusionCallbacks)
//
// ElementComputeType<...>::Type / ElementScalarType<...>::Type resolve to
// Gemm::EpilogueOutputOp::ElementCompute / ::ElementScalar when the partial
// specialization applies, and to the `Default` type otherwise.
//
// NOTE(review): the template parameter lists and partial-specialization arguments
// of the four declarations below are missing in this copy (text between '<' and '>'
// appears stripped), so they do not compile as written. From the names used (Gemm,
// Default) this presumably was a detection-idiom pair — a primary template with a
// defaulted trailing parameter plus a specialization keyed on the member type —
// confirm against the upstream source before restoring.
template
struct ElementComputeType {
  using Type = Default;
};

template
struct ElementComputeType> {
  using Type = typename Gemm::EpilogueOutputOp::ElementCompute;
};

template
struct ElementScalarType {
  using Type = Default;
};

template
struct ElementScalarType> {
  using Type = typename Gemm::EpilogueOutputOp::ElementScalar;
};

// The maximum swizzle size to use
//
// This class, like Splits below, makes it harder to confuse
// the order of arguments of the various run(...) functions in this file.
// MaxSwizzleSize() picks the default maximum swizzle size, 1.
class MaxSwizzleSize {
public:
  MaxSwizzleSize() = default;

  // Explicit construction from an integral value. The (partially stripped)
  // constraint appears to admit integral, non-bool types only.
  // NOTE(review): the template parameter list / constraint text is incomplete in
  // this copy ('<...>' stripped); restore from upstream.
  template && !cute::is_same_v)) >
  explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {}

  // Explicit on purpose: callers must write static_cast<int>(...).
  explicit operator int() const { return max_swizzle_size_; }

private:
  int max_swizzle_size_ = 1;
};

// Returns an iterator for building cute tensors over host data: when the element
// type satisfies cute::is_subbyte_v the raw pointer is wrapped in a cute
// subbyte_iterator, otherwise the pointer is returned unchanged.
// NOTE(review): template parameter list and the type arguments of is_subbyte_v /
// subbyte_iterator are missing in this copy ('<...>' stripped).
template
auto make_iterator(T* ptr) {
  using namespace cute;
  if constexpr (cute::is_subbyte_v) {
    return subbyte_iterator(ptr);
  }
  else {
    return ptr;
  }
}

// Trait: `value` is false for any type by default; the two partial
// specializations below set it to true. Presumably they match the collective
// epilogue types served by the default-epilogue host helper (HostCollectiveEpilogue
// static_asserts that this trait is false for its epilogue).
// NOTE(review): the template parameter lists and the specialized epilogue types are
// missing in this copy ('<...>' stripped) and cannot be recovered from this file;
// restore from upstream.
template
struct IsDefaultEpilogue {
  static constexpr bool value = false;
};

template
struct IsDefaultEpilogue> {
  static constexpr bool value = true;
};

template
struct IsDefaultEpilogue> {
  static constexpr bool value = true;
};

// The number of splits to test.
//
// This class makes it harder to confuse the order of arguments
// of the various run(...) functions in this file.  The constructor
// is explicit, so one can't just type 42 (or false, which the
// compiler unhelpfully turns into 0); one has to type Splits(42).
// Splits() picks the default number of splits, 1.
//
// The conversion-to-int operator (operator int()) MUST be explicit!
// Conversion to int MUST require static_cast<int>.
// Otherwise, that defeats a key purpose of this class,
// which is to catch common errors of confusing the order
// of function arguments.
class Splits {
public:
  Splits() = default;

  // NOTE(review): template parameter list / constraint text is incomplete in this
  // copy ('<...>' stripped); the parameter name indicates non-bool integral types.
  template && !cute::is_same_v)) >
  explicit Splits(IntegralNotBool splits) : splits_(splits) {}

  explicit operator int() const { return splits_; }

private:
  int splits_ = 1;
};

// The number of iterations to test.
//
// This class, like Splits above, makes it harder to confuse
// the order of arguments of the various run(...) functions in this file.
// Iterations() picks the default number of iterations, 20.
class Iterations { public: Iterations() = default; template && !cute::is_same_v)) > explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {} explicit operator int() const { return iterations_; } private: int iterations_ = 20; }; template bool initialize_tensor( cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) { if (dist_kind == cutlass::Distribution::Uniform) { double scope_max, scope_min; int bits_input = cutlass::sizeof_bits::value; if (bits_input == 1) { scope_max = 2; scope_min = 0; } else if (bits_input <= 8) { scope_max = 1; scope_min = -1; } else{ scope_max = 4; scope_min = -4; } cutlass::reference::host::TensorFillRandomUniform( view, seed, scope_max, scope_min, 0); } else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); } else if (dist_kind == cutlass::Distribution::Gaussian) { cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); } else if (dist_kind == cutlass::Distribution::Sequential) { cutlass::reference::host::BlockFillSequential( view.data(), view.capacity()); } else if (dist_kind == cutlass::Distribution::AllOnes) { cutlass::reference::host::TensorFill(view, Element(1)); } else { EXPECT_TRUE(false) << "Not implemented"; return false; } return true; } // Looks at Cute Stride to check Row / Column Major template static constexpr bool is_row_or_col_major(){ int stride_0 = int(cute::size<0>(Stride{})); int stride_1 = int(cute::size<1>(Stride{})); int depth = cute::depth(Stride{}); return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); } // // Default MMA input Operands : A , B // template< class ScheduleType_, class Gemm, class ElementA_ = typename Gemm::GemmKernel::ElementA, class ElementB_ = typename Gemm::GemmKernel::ElementB> struct HostCollectiveMainloop { // Kernel data types using ElementA = ElementA_; using StrideA = typename Gemm::GemmKernel::StrideA; using InternalStrideA = typename 
Gemm::GemmKernel::InternalStrideA; using ElementB = ElementB_; using StrideB = typename Gemm::GemmKernel::StrideB; using InternalStrideB = typename Gemm::GemmKernel::InternalStrideB; using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; static constexpr bool IsGroupGemm = !cute::is_same_v; using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; using ElementScalingFactor = ElementAccumulator; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; using Arguments = typename Gemm::GemmKernel::MainloopArguments; cutlass::ComplexTransform TransformA = Gemm::kTransformA; cutlass::ComplexTransform TransformB = Gemm::kTransformB; std::vector stride_a_host; std::vector stride_b_host; cutlass::DeviceAllocation stride_a_device; cutlass::DeviceAllocation stride_b_device; typename LayoutTagA::Stride stride_factor_A; typename LayoutTagB::Stride stride_factor_B; cutlass::Distribution::Kind init_A; cutlass::Distribution::Kind init_B; std::vector> tensors_A; std::vector> tensors_B; cutlass::DeviceAllocation device_tensors_A; cutlass::DeviceAllocation device_tensors_B; // Whether to use relative equality checks CheckEquality check_relative_equality = CheckEquality::EXACT; uint64_t seed; static constexpr uint64_t kDefaultSeed = 4096; // Note: this limitation comes from testbed / not the library static_assert(is_row_or_col_major(), "ERROR : A Layout is neither Row / Column Major)"); static_assert(is_row_or_col_major(), "ERROR : B Layout is neither Row / Column Major)"); HostCollectiveMainloop( CheckEquality check_relative_equality_ = CheckEquality::EXACT, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, uint64_t seed_ = kDefaultSeed, typename 
LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() ): stride_factor_A(stride_factor_A_), stride_factor_B(stride_factor_B_), init_A(init_A_), init_B(init_B_), seed(seed_), check_relative_equality(check_relative_equality_) { } bool initialize(ProblemShapeType problem_shapes) { // // Allocate the GEMM workspace // // for pointer array problem_shapes.groups() is 1 tensors_A.clear(); tensors_B.clear(); stride_a_host.clear(); stride_b_host.clear(); auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); L = max(problem_shapes.groups(), L); for(int32_t i = 0; i < L; ++i) { auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); stride_a_host.push_back(cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); stride_b_host.push_back(cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode auto a_coord = cutlass::make_Coord(M, K); // Cutlass has Row/Col major refers to MxK times KxN matrix product, // so the HostTensorB should be treated as KxN in "coord"'s view auto b_coord = cutlass::make_Coord(K, N); tensors_A.push_back(cutlass::HostTensor(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A))); tensors_B.push_back(cutlass::HostTensor(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B))); EXPECT_TRUE(initialize_tensor(tensors_A[i].host_view(), init_A, seed + 2022 + i)); EXPECT_TRUE(initialize_tensor(tensors_B[i].host_view(), init_B, seed + 2021 + i)); // It is possible to randomly initialize to all zeros, so override this with non-zeros // in the upper left corner of each operand. 
tensors_A[i].host_view().at({0, 0}) = ElementA(1); tensors_B[i].host_view().at({0, 0}) = ElementB(1); tensors_A[i].sync_device(); tensors_B[i].sync_device(); } return true; } Arguments to_args(ProblemShapeType problem_shapes) { auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); L = max(problem_shapes.groups(), L); std::vector ptr_A_host(L); std::vector ptr_B_host(L); for (int32_t i = 0; i < L; ++i) { ptr_A_host.at(i) = tensors_A[i].device_data(); ptr_B_host.at(i) = tensors_B[i].device_data(); } device_tensors_A.reset(L); device_tensors_A.copy_from_host(ptr_A_host.data()); device_tensors_B.reset(L); device_tensors_B.copy_from_host(ptr_B_host.data()); stride_a_device.reset(problem_shapes.groups()); stride_a_device.copy_from_host(stride_a_host.data()); stride_b_device.reset(problem_shapes.groups()); stride_b_device.copy_from_host(stride_b_host.data()); Arguments arguments; if constexpr (IsGroupGemm) { arguments = { device_tensors_A.get(), stride_a_device.get(), device_tensors_B.get(), stride_b_device.get() }; } else { arguments = { device_tensors_A.get(), stride_a_host[0], device_tensors_B.get(), stride_b_host[0] }; } return arguments; } auto to_host_args(ProblemShapeType problem_shapes, int batch) { using namespace cute; // // Allocate the GEMM workspace // auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); auto A = make_tensor(make_iterator(tensors_A[batch].host_data()), make_layout(make_shape(M, K, 1), stride_a_host[batch])); auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), make_layout(make_shape(N, K, 1), stride_b_host[batch])); cutlass::reference::host::GettMainloopParams mainloop_params{}; mainloop_params.A = A; mainloop_params.B = B; mainloop_params.transform_A = TransformA; mainloop_params.transform_B = TransformB; return mainloop_params; } void print_tensors(std::ofstream& file, int batch) { file << "A =\n" << tensors_A[batch].host_view() << "\nB =\n" << 
tensors_B[batch].host_view(); } template < class Element, class Layout > bool equality_check( cutlass::TensorView const& lhs, cutlass::TensorView const& rhs) const { // Factors used for calculating relative equality. CUTLASS's relative-equality // checks in include/cutlass/relatively_equal.h are inspired by // https://floating-point-gui.de/errors/comparison/. This reference suggests using // the minimum normal value of a given type as the nonzero_floor. Element epsilon(static_cast(0.1f)); Element nonzero_floor(std::numeric_limits::min()); if constexpr (!cutlass::is_complex::value) { if (check_relative_equality == CheckEquality::RELATIVE) { return cutlass::reference::host::TensorRelativelyEquals( lhs, rhs, epsilon, nonzero_floor); } else { return cutlass::reference::host::TensorEquals(lhs, rhs); } } else { return cutlass::reference::host::TensorEquals(lhs, rhs); } } bool compare_reference( ProblemShapeType problem_shapes, int batch) { EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_A[batch].host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_B[batch].host_view()), 0); bool passed = true; return passed; } }; template struct HostCollectiveDefaultEpilogue { // fusion types are potentially void if the fusion is not supported // helper so we don't try to construct HostTensor with void type template using non_void_t = cute::conditional_t, U, T>; using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; using kernel = typename Gemm::GemmKernel; using Epilogue = typename kernel::CollectiveEpilogue; using ElementD = typename kernel::ElementD; using StrideD = typename kernel::StrideD; using InternalStrideD = typename kernel::InternalStrideD; using ElementC = non_void_t; using StrideC = typename kernel::StrideC; using InternalStrideC = typename kernel::InternalStrideC; static constexpr bool IsGroupGemm = !cute::is_same_v; using FusionOp = typename Gemm::EpilogueOutputOp; static_assert(rank(InternalStrideC{}) == 3, 
"StrideCD must be rank-3: [M, N, L]"); static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(is_row_or_col_major(), "ERROR : C Layout is neither Row / Column Major)"); static_assert(is_row_or_col_major(), "ERROR : D Layout is neither Row / Column Major)"); // Deduce Cutlass Layouts (RowMajor & ColumnMajor) using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors using LayoutTagVector = cutlass::layout::PackedVectorLayout; using ElementAccumulator = typename kernel::ElementAccumulator; using ElementScalingFactor = ElementAccumulator; using ProblemShapeType = typename kernel::ProblemShape; using ElementCompute = typename ElementComputeType::Type; using ElementScalar = typename ElementScalarType::Type; using Arguments = typename Gemm::GemmKernel::EpilogueArguments; /// Initialization cutlass::DeviceAllocation stride_c_device; cutlass::DeviceAllocation stride_d_device; std::vector stride_c_host; std::vector stride_d_host; typename LayoutTagC::Stride stride_factor_C; typename LayoutTagD::Stride stride_factor_D; // Inputs ElementScalar alpha; ElementScalar beta; std::vector> tensors_C; std::vector> tensors_D; std::vector> references_D; cutlass::DeviceAllocation device_tensors_C; cutlass::DeviceAllocation device_tensors_D; // Whether to use relative equality checks CheckEquality check_relative_equality = CheckEquality::EXACT; // Are scalars copied to device memory before kernel launch ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; // If per-row scale is enabled and this is true, beta is passed as a host scalar instead of device vector VectorBeta disable_vector_beta = VectorBeta::DISABLED; cutlass::Distribution::Kind init_C; uint64_t seed; static constexpr uint64_t kDefaultSeed = 4096; HostCollectiveDefaultEpilogue( CheckEquality check_relative_equality_ = CheckEquality::EXACT, 
ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = kDefaultSeed ): init_C(init_C_), seed(seed_), stride_factor_C(typename LayoutTagC::Stride()), stride_factor_D(typename LayoutTagD::Stride()), check_relative_equality(check_relative_equality_), use_device_scalars(use_device_scalars_){ } bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { // Initialize Epilogue tensors tensors_C.clear(); tensors_D.clear(); references_D.clear(); stride_c_host.clear(); stride_d_host.clear(); auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); L = max(problem_shapes.groups(), L); for (int32_t i = 0; i < L; ++i) { auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode auto c_coord = cutlass::make_Coord(M, N); tensors_C.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); tensors_C[i].host_view().at({0, 0}) = ElementC(1); 
cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); tensors_C[i].sync_device(); tensors_D[i].sync_device(); } alpha = alpha_; beta = beta_; return true; } template < class Element, class Layout > bool equality_check( cutlass::TensorView const& lhs, cutlass::TensorView const& rhs) const { // Factors used for calculating relative equality. CUTLASS's relative-equality // checks in include/cutlass/relatively_equal.h are inspired by // https://floating-point-gui.de/errors/comparison/. This reference suggests using // the minimum normal value of a given type as the nonzero_floor. Element epsilon(static_cast(0.1f)); Element nonzero_floor(std::numeric_limits::min()); if constexpr (!cutlass::is_complex::value) { if (check_relative_equality == CheckEquality::RELATIVE) { return cutlass::reference::host::TensorRelativelyEquals( lhs, rhs, epsilon, nonzero_floor); } else { return cutlass::reference::host::TensorEquals(lhs, rhs); } } else { return cutlass::reference::host::TensorEquals(lhs, rhs); } } bool compare_reference( ProblemShapeType problem_shapes, ElementScalar alpha, ElementScalar beta, int batch) { auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); L = max(problem_shapes.groups(), L); tensors_D[batch].sync_host(); EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); if (tensors_D[batch].size() > 1) { EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); } if (references_D[batch].size() > 1) { EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); } bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); if(!passed) { std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); L = max(problem_shapes.groups(), L); std::vector ptr_C_host(L); std::vector ptr_D_host(L); for (int32_t i = 0; i < L; ++i) { ptr_C_host.at(i) = tensors_C[i].device_data(); ptr_D_host.at(i) 
= tensors_D[i].device_data(); } device_tensors_C.reset(L); device_tensors_C.copy_from_host(ptr_C_host.data()); device_tensors_D.reset(L); device_tensors_D.copy_from_host(ptr_D_host.data()); stride_c_device.reset(problem_shapes.groups()); stride_c_device.copy_from_host(stride_c_host.data()); stride_d_device.reset(problem_shapes.groups()); stride_d_device.copy_from_host(stride_d_host.data()); Arguments arguments; if constexpr (IsGroupGemm) { arguments = { {alpha, beta}, device_tensors_C.get(), stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() }; } else { arguments = { {alpha, beta}, device_tensors_C.get(), stride_c_host[0], device_tensors_D.get(), stride_d_host[0] }; } return arguments; } auto to_host_args(ProblemShapeType problem_shapes, int batch) { using namespace cute; // // Allocate the GEMM workspace // auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); L = max(problem_shapes.groups(), L); auto coord_0 = cutlass::make_Coord(0); auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); cutlass::reference::host::GettEpilogueParams< ElementScalar, ElementScalar, ElementAccumulator, ElementCompute, decltype(C), decltype(D)> epilogue_params{}; epilogue_params.C = C; epilogue_params.D = D; epilogue_params.alpha = alpha; epilogue_params.beta = beta; return epilogue_params; } }; template struct HostCollectiveEpilogue { // fusion types are potentially void if the fusion is not supported // helper so we don't try to construct HostTensor with void type template using non_void_t = cute::conditional_t, U, T>; using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; using kernel = typename Gemm::GemmKernel; using Epilogue = typename 
kernel::CollectiveEpilogue; static_assert(IsDefaultEpilogue::value == false, "Default Epilogue is not supported"); using ElementD = typename kernel::ElementD; using StrideD = typename kernel::StrideD; using InternalStrideD = typename kernel::InternalStrideD; using ElementC = non_void_t; using StrideC = typename kernel::StrideC; using InternalStrideC = typename kernel::InternalStrideC; static constexpr bool IsGroupGemm = !cute::is_same_v; static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(is_row_or_col_major(), "ERROR : C Layout is neither Row / Column Major)"); static_assert(is_row_or_col_major(), "ERROR : D Layout is neither Row / Column Major)"); // Deduce Cutlass Layouts (RowMajor & ColumnMajor) using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors using LayoutTagVector = cutlass::layout::PackedVectorLayout; using ElementAccumulator = typename kernel::ElementAccumulator; using ElementScalingFactor = ElementAccumulator; using ProblemShapeType = typename kernel::ProblemShape; // // FusionOperation derived types/queries // using EpiloguePolicy = typename Epilogue::DispatchPolicy; static constexpr bool IsLegacy = cute::is_same_v< EpiloguePolicy, cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize> >; using FusionOp = typename Gemm::EpilogueOutputOp; static_assert(cute::is_base_of_v); using ElementCompute = typename FusionOp::ElementCompute; using ElementScalar = typename FusionOp::ElementScalar; using ElementBias = non_void_t; using ElementAux = non_void_t; using ElementAmax = non_void_t; using LayoutTagAux = non_void_t; using ActivationFunctor = non_void_t>; static constexpr bool IsBiasEnabled = 
FusionOp::IsPerRowBiasSupported; static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported && (cute::is_same_v || cute::is_same_v); static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && (cute::is_same_v || cute::is_same_v); using Arguments = typename Gemm::GemmKernel::EpilogueArguments; /// Initialization cutlass::DeviceAllocation stride_c_device; cutlass::DeviceAllocation stride_d_device; std::vector stride_c_host; std::vector stride_d_host; typename LayoutTagC::Stride stride_factor_C; typename LayoutTagD::Stride stride_factor_D; // Inputs cutlass::HostTensor alpha; cutlass::HostTensor beta; cutlass::HostTensor scale_A; cutlass::HostTensor scale_B; cutlass::HostTensor scale_C; cutlass::HostTensor scale_D; cutlass::HostTensor scale_Aux; cutlass::HostTensor bias; std::vector> tensors_C; cutlass::DeviceAllocation device_tensors_C; cutlass::HostTensor norm_constant; // Outputs cutlass::HostTensor abs_max_Aux; cutlass::HostTensor abs_max_D; std::vector> tensors_Aux; cutlass::DeviceAllocation device_tensors_Aux; cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; std::vector> tensors_D; std::vector> references_D; cutlass::DeviceAllocation device_tensors_D; // References cutlass::HostTensor reference_dbias; std::vector> references_Aux; cutlass::HostTensor reference_abs_max_Aux; cutlass::HostTensor reference_abs_max_D; // Whether to use relative equality checks CheckEquality check_relative_equality = CheckEquality::EXACT; // Are scalars copied to device memory before kernel launch ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; // If per-row scale is 
enabled and this is true, beta is passed as a host scalar instead of device vector VectorBeta disable_vector_beta = VectorBeta::DISABLED; // Random distribution with which to initialize the A/B/C/D/Aux scaling factors cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; // Random distribution with which to initialize the bias vector cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; cutlass::Distribution::Kind init_C; uint64_t seed; static constexpr uint64_t kDefaultSeed = 4096; HostCollectiveEpilogue( CheckEquality check_relative_equality_ = CheckEquality::EXACT, ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = kDefaultSeed ): init_scale(init_scale_), init_bias(init_bias_), init_C(init_C_), seed(seed_), stride_factor_C(typename LayoutTagC::Stride()), stride_factor_D(typename LayoutTagD::Stride()), check_relative_equality(check_relative_equality_), use_device_scalars(use_device_scalars_){ } bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { // Initialize Epilogue tensors tensors_C.clear(); tensors_D.clear(); references_D.clear(); stride_c_host.clear(); stride_d_host.clear(); auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); L = max(problem_shapes.groups(), L); for (int32_t i = 0; i < L; ++i) { auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); auto c_coord = cutlass::make_Coord(M, N); tensors_C.push_back(cutlass::HostTensor(c_coord, 
cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); tensors_C[i].host_view().at({0, 0}) = ElementC(1); cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); tensors_C[i].sync_device(); tensors_D[i].sync_device(); } auto scalar_coord = cutlass::make_Coord(1); auto col_vector_coord = cutlass::make_Coord(M); if constexpr (IsPerRowScaleEnabled) { alpha.resize(col_vector_coord); EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); if (disable_vector_beta == VectorBeta::DISABLED) { beta.resize(scalar_coord, false); cutlass::reference::host::TensorFill(beta.host_view(), beta_); } else { beta.resize(col_vector_coord); EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); } } else { alpha.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); beta.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); cutlass::reference::host::TensorFill(beta.host_view(), beta_); } alpha.sync_device(); beta.sync_device(); if constexpr (IsScaleFactorEnabled) { scale_A.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); scale_B.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); scale_C.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); scale_D.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); EXPECT_TRUE(initialize_tensor(scale_A.host_view(), init_scale, seed + 2023)); EXPECT_TRUE(initialize_tensor(scale_B.host_view(), init_scale, seed + 2024)); 
EXPECT_TRUE(initialize_tensor(scale_C.host_view(), init_scale, seed + 2025)); EXPECT_TRUE(initialize_tensor(scale_D.host_view(), init_scale, seed + 2026)); scale_A.sync_device(); scale_B.sync_device(); scale_C.sync_device(); scale_D.sync_device(); } if constexpr (IsBiasEnabled) { bias.resize(col_vector_coord); EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); bias.sync_device(); } if constexpr (IsDeBiasEnabled) { bias.resize(col_vector_coord); reference_dbias.resize(col_vector_coord); cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0)); cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0)); bias.sync_device(); } if constexpr (IsAbsMaxEnabledD) { abs_max_D.resize(scalar_coord); // ensure in-place device reductions perform their own initialization cutlass::reference::host::TensorFill(abs_max_D.host_view(), CUTLASS_STL_NAMESPACE::numeric_limits::max()); abs_max_D.sync_device(); reference_abs_max_D.resize(scalar_coord); cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0)); } tensors_Aux.clear(); references_Aux.clear(); static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxInEnabled)); if constexpr (IsAuxInEnabled) { auto aux_coord = cutlass::make_Coord(M, N); auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); for (int32_t i = 0; i < L; ++i) { tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); EXPECT_TRUE(initialize_tensor(tensors_Aux[i].host_view(), init_C, seed + 2023)); tensors_Aux[i].sync_device(); } stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); } static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxOutEnabled)); if constexpr (IsAuxOutEnabled) { for (int32_t i = 0; i < L; ++i) { auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); auto aux_coord = cutlass::make_Coord(M, N); auto aux_layout 
= cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); references_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout, false)); tensors_Aux[i].sync_device(); } stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); if constexpr (IsScaleFactorEnabled) { scale_Aux.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); EXPECT_TRUE(initialize_tensor(scale_Aux.host_view(), init_scale, seed + 2027)); scale_Aux.sync_device(); } if constexpr (IsAbsMaxEnabledAux) { abs_max_Aux.resize(scalar_coord); // ensure in-place device reductions perform their own initialization cutlass::reference::host::TensorFill(abs_max_Aux.host_view(), CUTLASS_STL_NAMESPACE::numeric_limits::max()); abs_max_Aux.sync_device(); reference_abs_max_Aux.resize(scalar_coord); cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); } } return true; } template < class Element, class Layout > bool equality_check( cutlass::TensorView const& lhs, cutlass::TensorView const& rhs) const { // Factors used for calculating relative equality. CUTLASS's relative-equality // checks in include/cutlass/relatively_equal.h are inspired by // https://floating-point-gui.de/errors/comparison/. This reference suggests using // the minimum normal value of a given type as the nonzero_floor. 
Element epsilon(static_cast(0.1f)); Element nonzero_floor(std::numeric_limits::min()); if constexpr (!cutlass::is_complex::value) { if (check_relative_equality == CheckEquality::RELATIVE) { return cutlass::reference::host::TensorRelativelyEquals( lhs, rhs, epsilon, nonzero_floor); } else { return cutlass::reference::host::TensorEquals(lhs, rhs); } } else { return cutlass::reference::host::TensorEquals(lhs, rhs); } } bool compare_reference( ProblemShapeType problem_shapes, ElementScalar alpha, ElementScalar beta, int batch) { tensors_D[batch].sync_host(); EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); if (tensors_D[batch].size() > 1) { EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); } if (references_D[batch].size() > 1) { EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); } bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); if(!passed) { std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); L = max(problem_shapes.groups(), L); std::vector ptr_C_host(L); std::vector ptr_D_host(L); for (int32_t i = 0; i < L; ++i) { ptr_C_host.at(i) = tensors_C[i].device_data(); ptr_D_host.at(i) = tensors_D[i].device_data(); } device_tensors_C.reset(L); device_tensors_C.copy_from_host(ptr_C_host.data()); device_tensors_D.reset(L); device_tensors_D.copy_from_host(ptr_D_host.data()); stride_c_device.reset(problem_shapes.groups()); stride_c_device.copy_from_host(stride_c_host.data()); stride_d_device.reset(problem_shapes.groups()); stride_d_device.copy_from_host(stride_d_host.data()); std::vector ptr_Aux_host(L); if constexpr (IsAuxInEnabled || IsAuxOutEnabled) { for (int32_t i = 0; i < L; ++i) { ptr_Aux_host.at(i) = tensors_Aux[i].device_data(); } device_tensors_Aux.reset(L); device_tensors_Aux.copy_from_host(ptr_Aux_host.data()); } Arguments arguments; if constexpr (IsGroupGemm) { arguments = { {}, device_tensors_C.get(), 
stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() }; } else { arguments = { {}, device_tensors_C.get(), stride_c_host[0], device_tensors_D.get(), stride_d_host[0] }; } auto &fusion_args = arguments.thread; if constexpr (IsLegacy) { arguments.thread = { alpha.at(coord_0), beta.at(coord_0), alpha.device_data(), beta.device_data() }; arguments.ptr_Bias = bias.device_data(); arguments.ptr_T = device_tensors_Aux.get(); } else { fusion_args.alpha = alpha.at(coord_0); fusion_args.beta = beta.at(coord_0); fusion_args.alpha_ptr = alpha.device_data(); fusion_args.beta_ptr = beta.device_data(); // if disable_vector_beta is true this is nullptr if constexpr (IsScaleFactorEnabled) { fusion_args.scale_a = scale_A.at(coord_0); fusion_args.scale_b = scale_B.at(coord_0); fusion_args.scale_c = scale_C.at(coord_0); fusion_args.scale_d = scale_D.at(coord_0); fusion_args.scale_a_ptr = scale_A.device_data(); fusion_args.scale_b_ptr = scale_B.device_data(); fusion_args.scale_c_ptr = scale_C.device_data(); fusion_args.scale_d_ptr = scale_D.device_data(); } if constexpr (IsBiasEnabled) { fusion_args.bias_ptr = bias.device_data(); } if constexpr (IsDeBiasEnabled) { fusion_args.dbias_ptr = bias.device_data(); } // example of how to set kernel activation arguments // see ActivationFunctor::Arguments in activation.h for definition // if Arguments doesn't exist then fusion_args.activation is empty if constexpr (cute::is_same_v>) { fusion_args.activation.scale = ElementCompute(1); } // Treat Clamp as ReLU if constexpr (cute::is_same_v>) { fusion_args.activation.lower_bound = 0; fusion_args.activation.upper_bound = std::numeric_limits::max(); } if constexpr (IsAbsMaxEnabledD) { fusion_args.amax_D_ptr = abs_max_D.device_data(); } if constexpr (IsAuxInEnabled) { fusion_args.aux_ptr = device_tensors_Aux.get(); fusion_args.dAux = stride_Aux; } if constexpr (IsAuxOutEnabled) { fusion_args.aux_ptr = device_tensors_Aux.get(); fusion_args.dAux = stride_Aux; if constexpr 
(IsScaleFactorEnabled) { fusion_args.scale_aux = scale_Aux.at(coord_0); fusion_args.scale_aux_ptr = scale_Aux.device_data(); } if constexpr (IsAbsMaxEnabledAux) { fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); } } } return arguments; } auto to_host_args(ProblemShapeType problem_shapes, int batch) { using namespace cute; // // Allocate the GEMM workspace // auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); auto coord_0 = cutlass::make_Coord(0); auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), cute::make_layout(cute::make_shape(M, cute::_1{}))); auto Aux_layout = cute::make_layout(cute::make_shape(M, N, 1), stride_Aux); auto Aux = [&]() { auto ptr = recast_ptr(nullptr); if (IsAuxInEnabled) { ptr = detail::make_iterator(tensors_Aux[batch].host_data()); } else if (IsAuxOutEnabled) { ptr = detail::make_iterator(references_Aux[batch].host_data()); } return cute::make_tensor(ptr, Aux_layout); }(); auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()), cute::make_layout(cute::make_shape(M, cute::_1{}))); auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), cute::make_layout(cute::make_shape(M, cute::_1{}))); cutlass::reference::host::GettEpilogueParams< ElementScalar, ElementScalar, ElementAccumulator, ElementCompute, decltype(C), decltype(D), decltype(Bias), decltype(Aux), decltype(Valpha), decltype(Vbeta), ActivationFunctor > epilogue_params{}; epilogue_params.C = C; epilogue_params.D = D; epilogue_params.alpha = alpha.at(coord_0); epilogue_params.beta = beta.at(coord_0); if constexpr (IsScaleFactorEnabled) { 
epilogue_params.scale_a = scale_A.at(coord_0); epilogue_params.scale_b = scale_B.at(coord_0); epilogue_params.scale_c = scale_C.at(coord_0); epilogue_params.scale_d = scale_D.at(coord_0); } if constexpr (IsBiasEnabled or IsDeBiasEnabled) { epilogue_params.Bias = Bias; } if constexpr (IsAbsMaxEnabledD) { epilogue_params.abs_max_D = reference_abs_max_D.host_data(); } if constexpr (IsAuxInEnabled) { epilogue_params.Aux = Aux; } if constexpr (IsAuxOutEnabled) { epilogue_params.Aux = Aux; if constexpr (IsScaleFactorEnabled) { epilogue_params.scale_aux = scale_Aux.at(coord_0); } if constexpr (IsAbsMaxEnabledAux) { epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); } } if constexpr (IsPerRowScaleEnabled) { epilogue_params.Valpha = Valpha; if (disable_vector_beta == VectorBeta::ENABLED) { epilogue_params.Vbeta = Vbeta; } } return epilogue_params; } }; template < typename Gemm, template class ActivationFunctor_ = cutlass::epilogue::thread::Identity, bool force_legacy_epilogue = false, typename ElementA = typename Gemm::GemmKernel::ElementA, typename ElementB = typename Gemm::GemmKernel::ElementB > struct TestbedImpl { // Kernel data types using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type using HostCollectiveMainloopType = HostCollectiveMainloop; using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, HostCollectiveDefaultEpilogue, HostCollectiveEpilogue>; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; using ElementCompute = typename ElementComputeType::Type; using ElementScalar = typename ElementScalarType::Type; using LayoutTagA = typename HostCollectiveMainloopType::LayoutTagA; using LayoutTagB = typename HostCollectiveMainloopType::LayoutTagB; using LayoutTagC = typename 
CollectiveEpilogue::LayoutTagC; using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; uint32_t sm_count; // Used to force multi-wave tests for persistent kernel schedules constexpr static int MaxSmCount = 16; static constexpr uint64_t kDefaultSeed = 4096; static constexpr uint32_t mma_promotion_interval = 4; using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; HostCollectiveMainloopType collective_mma_inputs; CollectiveEpilogue collective_epilogue; static constexpr bool IsGroupGemm = CollectiveEpilogue::IsGroupGemm; // // Methods // TestbedImpl( CheckEquality check_relative_equality_ = CheckEquality::EXACT, ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = kDefaultSeed ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, disable_vector_beta_, init_C_, init_scale_, init_bias_, seed_)) { } TestbedImpl( typename LayoutTagA::Stride stride_factor_A_, typename LayoutTagB::Stride stride_factor_B_, typename LayoutTagC::Stride stride_factor_C_, typename LayoutTagD::Stride stride_factor_D_, CheckEquality check_relative_equality_ = CheckEquality::EXACT, ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, VectorBeta disable_vector_beta_ = VectorBeta::DISABLED, cutlass::Distribution::Kind init_A_ = 
cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = kDefaultSeed ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, stride_factor_A_, stride_factor_B_, init_A_, init_B_, seed_)), collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, disable_vector_beta_, init_C_, init_scale_, init_bias_, seed_)) { } /// Initializes data structures bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { collective_mma_inputs.initialize(problem_shapes); collective_epilogue.initialize(problem_shapes, alpha_, beta_); return true; } /// Compares computed reference with device reference and outputs to a file if incorrect bool compare_reference( ProblemShapeType problem_shapes, ElementScalar alpha, ElementScalar beta, int batch) { auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); bool passed = collective_mma_inputs.compare_reference(problem_shapes, batch); passed &= collective_epilogue.compare_reference(problem_shapes, alpha, beta, batch); EXPECT_TRUE(passed); if (!passed) { std::stringstream fname; fname << "error_Gemm_device_" << M << "x" << N << "x" << K << "x" << batch << "_" << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; std::ofstream file(fname.str()); file << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << batch << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; collective_mma_inputs.print_tensors(file, batch); collective_epilogue.print_tensors(file, batch); } return passed; } /// Verifies 
the result is a GEMM bool verify( ProblemShapeType problem_shapes, ElementScalar alpha, ElementScalar beta) { using namespace cute; auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); L = max(problem_shapes.groups(), L); bool passed = true; for (int32_t i = 0; i < L; ++i) { auto mainloop_params = collective_mma_inputs.to_host_args(problem_shapes, i); auto epilogue_params = collective_epilogue.to_host_args(problem_shapes, i); cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); passed &= compare_reference(problem_shapes, alpha, beta, i); } return passed; } /// Determine if the CUDA device is sufficient to run the kernel bool sufficient() { // // Determine SMEM requirements and waive if not satisfied // size_t smem_size = static_cast(Gemm::GemmKernel::SharedStorageSize); int device_idx; cudaError_t result = cudaGetDevice(&device_idx); if (result != cudaSuccess) { throw std::runtime_error("cudaGetDevice() API call failed."); } cudaDeviceProp properties; result = cudaGetDeviceProperties(&properties, device_idx); this->sm_count = properties.multiProcessorCount; if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } if (properties.sharedMemPerBlockOptin < smem_size) { printf("failed due to smem_size\n"); printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); return false; } return true; } /// Executes one test bool run( ProblemShapeType problem_shapes, ElementScalar alpha = ElementScalar(1), ElementScalar beta = ElementScalar(0), detail::Iterations iterations = detail::Iterations{} ) { // Fail test if insufficient CUDA device if (!sufficient()) { std::cout << "Test failed due to insufficient CUDA device." 
<< std::endl; return false; } if (!this->initialize(problem_shapes, alpha, beta)) { std::cerr << "Initialization failed \n"; return false; } // // Initialize the GEMM operator // typename Gemm::Arguments arguments; cutlass::KernelHardwareInfo hw_info; hw_info.device_id = 0; this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); hw_info.sm_count = this->sm_count; typename HostCollectiveMainloopType::Arguments mainloop_args; mainloop_args = collective_mma_inputs.to_args(problem_shapes); if constexpr (IsGroupGemm) { arguments = { cutlass::gemm::GemmUniversalMode::kGrouped, problem_shapes, mainloop_args, collective_epilogue.to_args(problem_shapes), hw_info }; } else { arguments = { cutlass::gemm::GemmUniversalMode::kArray, problem_shapes, mainloop_args, collective_epilogue.to_args(problem_shapes), hw_info }; } Gemm gemm_op; size_t workspace_size = Gemm::get_workspace_size(arguments); cutlass::device_memory::allocation workspace(workspace_size); cutlass::Status status = gemm_op.can_implement(arguments); if (status != cutlass::Status::kSuccess) { cudaError_t error = cudaGetLastError(); std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; return false; } // // Run the GEMM // cudaError_t result; status = gemm_op.initialize(arguments, workspace.get()); status = gemm_op.run(); result = cudaDeviceSynchronize(); if (result != cudaSuccess) { EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; return false; } EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); // // Verify // bool passed = this->verify(problem_shapes, alpha, beta); if (!passed) { std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta << "\n"; } return passed; } }; } // namespace detail ///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// template < 
typename Gemm,
  template class ActivationFunctor = cutlass::epilogue::thread::Identity,
  bool force_legacy_epilogue = false,
  typename ElementA = typename Gemm::GemmKernel::ElementA,
  typename ElementB = typename Gemm::GemmKernel::ElementB
>
struct Testbed3x {
  // Public test-bed facade. All state and logic live in detail::TestbedImpl,
  // which is instantiated with the same template parameters. This struct
  // re-exports its types and forwards construction and run() to it.
  using TestBedImpl = typename detail::TestbedImpl<
    Gemm,
    ActivationFunctor,
    force_legacy_epilogue,
    ElementA,
    ElementB
  >;

  using Kernel = typename Gemm::GemmKernel;
  using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue;

  // Element types come from TestbedImpl. It resolves compute/scalar types
  // through the ElementComputeType / ElementScalarType helpers, which supply a
  // default when Gemm::EpilogueOutputOp does not define them.
  using ElementAccumulator = typename TestBedImpl::ElementAccumulator;
  using ElementCompute = typename TestBedImpl::ElementCompute;
  using ElementScalar = typename TestBedImpl::ElementScalar;

  using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
  using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;

  // Mirrors TestbedImpl::IsGroupGemm. When true, the kernel is launched in
  // GemmUniversalMode::kGrouped; otherwise it uses kArray.
  static constexpr bool IsGroupGemm = TestBedImpl::IsGroupGemm;

  // Detail Implementation
  TestBedImpl impl_;

  //
  // Methods
  //

  /// Constructs the underlying TestbedImpl with the given comparison mode,
  /// scalar location, vector-beta setting, init distributions and seed.
  /// The default scalar location here is ScalarLoc::ON_DEVICE, whereas
  /// TestbedImpl's own constructor defaults to ScalarLoc::ON_HOST.
  Testbed3x(
    CheckEquality check_relative_equality_ = CheckEquality::EXACT,
    ScalarLoc use_device_scalars_ = ScalarLoc::ON_DEVICE,
    VectorBeta disable_vector_beta_ = VectorBeta::DISABLED,
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = TestBedImpl::kDefaultSeed)
    : impl_(check_relative_equality_, use_device_scalars_, disable_vector_beta_, init_A_, init_B_, init_C_, init_scale_, init_bias_, seed_) {}

  /// Executes one test
  ///
  /// Forwards all arguments to TestbedImpl::run and returns its result.
  bool run(
    typename TestBedImpl::ProblemShapeType problem_shapes,
    ElementScalar alpha = ElementScalar(1),
    ElementScalar beta = ElementScalar(0),
    detail::Iterations iterations = detail::Iterations{}
    )
  {
    return impl_.run(
        problem_shapes, alpha, beta, iterations);
  }
};

/// Runs Testbed3x over a fixed sweep of problem sizes and returns true only if
/// every case passes.
///
/// Every combination of the following is tested:
///   - batch (group count for grouped GEMM, L extent otherwise): 5, 10
///   - M: max_alignment, 512 - 3 * max_alignment
///   - N: max_alignment, 512 - 2 * max_alignment
///   - K: max_alignment, TileShapeK * (Stages + 1) - max_alignment
/// where max_alignment = max(Gemm::kAlignmentA, Gemm::kAlignmentB).
///
/// @param alpha, beta  Epilogue scalars, converted with cutlass::from_real.
/// @param check_relative_equality  Comparison mode passed to the testbed.
/// @return false (after printing file, line and MNKL) on the first failing case.
template <
  typename Gemm,
  template class ActivationFunctor = cutlass::epilogue::thread::Identity
>
bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) {
  // NOTE(review): this reads Gemm::EpilogueOutputOp::ElementScalar directly.
  // Testbed3x::ElementScalar instead falls back to a default (via
  // ElementScalarType) when that member is absent. Consider using the
  // testbed's alias for consistency.
  using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar;
  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;

  // Scalars live on the device; vector beta is disabled.
  Testbed3x testbed(check_relative_equality, ScalarLoc::ON_DEVICE, VectorBeta::DISABLED);

  int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB);
  std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment};
  std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment};

  constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages;
  constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{});

  // The larger K covers more K tiles than the mainloop has stages (Stages + 1
  // tiles, minus a partial tile). Presumably this exercises pipeline
  // wrap-around and a K residue.
  std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment};

  int batches[] = {5, 10};

  bool passed = true;

  for (int batch : batches) {
    for (int m : problem_size_m) {
      for (int n : problem_size_n) {
        for (int k : problem_size_k) {
          if constexpr (Testbed3x::IsGroupGemm) {
            // Grouped GEMM: `batch` groups, all with the same (m, n, k).
            // ProblemShapeType is built from the group count plus device and
            // host pointers to the per-group sizes.
            std::vector problem_sizes_host;
            cutlass::DeviceAllocation problem_sizes_device;

            for (int i = 0; i < batch; ++i) {
              problem_sizes_host.push_back({m, n, k});
            }

            problem_sizes_device.reset(problem_sizes_host.size());
            problem_sizes_device.copy_from_host(problem_sizes_host.data());

            passed = testbed.run(
              ProblemShapeType{static_cast(problem_sizes_host.size()), problem_sizes_device.get(), problem_sizes_host.data()},
              cutlass::from_real(alpha),
              cutlass::from_real(beta)
            );
          }
          else {
            // Ptr-array GEMM: one (m, n, k) shape with L = batch.
            ProblemShapeType problem_size{{m, n, k, batch}};

            passed = testbed.run(
              problem_size,
              cutlass::from_real(alpha),
              cutlass::from_real(beta)
            );
          }

          if (!passed) {
            std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNKL " << m << " " << n << " " << k << " " << batch << " FAILED.\n";
            return false;
          }
        } // k
      } // n
    } // m
  } // batch

  return passed;
}

} // namespace device
} // namespace gemm
} // namespace test

/////////////////////////////////////////////////////////////////////////////////////////////////