/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
\file \brief Testbed and host reference for EVT unittest */ #pragma once #include "gemm_testbed_3x.hpp" namespace test { namespace gemm { namespace device { /// Host-side tapply, tapply in cute is HOST_DEVICE template constexpr auto tapply(T&& t, F&& f, G&& g, cute::seq) { return g(f(std::get(static_cast(t)))...); } ///////////////////////////////////////////////////////////////////////////////////////////////// /// EVT: Base class for EVT Node template < typename Gemm_ > class HostEVTNodeBase { public: using Gemm = Gemm_; using TestBedImpl = typename detail::TestbedImpl; using Kernel = typename Gemm::GemmKernel; using Epilogue = typename Kernel::CollectiveEpilogue; using ElementCompute = typename TestBedImpl::ElementCompute; using ElementScalar = typename TestBedImpl::ElementScalar; using ElementAccumulator = typename Kernel::ElementAccumulator; using ElementC = typename Kernel::ElementC; using ElementD = typename Kernel::ElementD; using LayoutTagC = typename TestBedImpl::LayoutTagC; using LayoutTagD = typename TestBedImpl::LayoutTagD; private: bool _check_relative_equality; // Factors used for calculating relative equality. These default // values are borrowed from those used by default in the CUTLASS // profiler for performing relative equality checks. 
float _epsilon = 0.05f; float _nonzero_floor = 1.0f / 256.0f; public: HostEVTNodeBase(){} HostEVTNodeBase(bool check_relative_equality): _check_relative_equality(check_relative_equality) { } template < class Element, class Layout > bool equality_check( cutlass::TensorView const& lhs, cutlass::TensorView const& rhs) const { if (_check_relative_equality) { return cutlass::reference::host::TensorRelativelyEquals( lhs, rhs, Element(_epsilon), Element(_nonzero_floor) ); } else { return cutlass::reference::host::TensorEquals(lhs, rhs); } } void* get_tensor_C_ptr() { return nullptr; } void* get_tensor_D_ptr() { return nullptr; } bool compare_reference(std::stringstream& error_ss) { return true; } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// EVT - Accumulator template < typename Gemm > class HostAccumulator: public HostEVTNodeBase { public: using Base = HostEVTNodeBase; using TestBedImpl = typename Base::TestBedImpl; using ElementAccumulator = typename Base::ElementAccumulator; using ElementCompute = typename Base::ElementCompute; struct Arguments { }; private: cutlass::NumericConverter accumulator_converter; public: HostAccumulator(){} template HostAccumulator(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) :Base(check_relative_equality) {} ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc) { return accumulator_converter(acc); } Arguments get_arguments() { return Arguments{}; } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// EVT - Scalar Broadcast template < typename Gemm, int Value, int BroadcastCount = 1, template class ReductionFn = cutlass::multiplies > class HostScalarBroadcast : public HostEVTNodeBase { public: using Base = HostEVTNodeBase; using ElementCompute = typename Base::ElementCompute; struct Arguments { ElementCompute scalar[BroadcastCount] = {0}; 
ElementCompute const* scalar_ptrs[BroadcastCount] = { nullptr }; cute::Stride dScalar{}; }; private: ElementCompute _scalar{}; public: HostScalarBroadcast(){} template HostScalarBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) : Base(check_relative_equality), _scalar(ElementCompute(Value)) {} template ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc) { return _scalar; } bool compare_reference(std::stringstream& error_ss) { error_ss << "Scalar: " << float(_scalar) << "\n\n"; return true; } Arguments get_arguments() { if constexpr (BroadcastCount == 1) return Arguments{{_scalar}, {nullptr}}; else if constexpr (BroadcastCount == 2) return Arguments{{_scalar, _scalar}, {nullptr, nullptr}}; else if constexpr (BroadcastCount == 3) return Arguments{{_scalar, _scalar, _scalar}, {nullptr, nullptr, nullptr}}; else return Arguments{{_scalar}, {nullptr}}; } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// EVT - Row Broadcast template < typename Gemm, typename ElementBias_=void > class HostRowBroadcast: public HostEVTNodeBase { public: using Base = HostEVTNodeBase; using ElementBias = std::conditional_t, typename Base::ElementC, ElementBias_>; using TestBedImpl = typename Base::TestBedImpl; using ElementCompute = typename Base::ElementCompute; using LayoutTagVector = cutlass::layout::PackedVectorLayout; struct Arguments { ElementBias const* ptr_row = nullptr; ElementBias null_default = ElementBias(0); cute::Stride dRow = {}; }; private: cutlass::NumericConverter _bias_converter; cutlass::HostTensor _bias; int _N; TestBedImpl impl_; public: HostRowBroadcast(){} template HostRowBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) : Base(check_relative_equality), impl_(impl) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); _N = cute::get<1>(problem_shape_MNKL); 
_bias.resize(cutlass::Coord<1>(_N)); EXPECT_TRUE( detail::initialize_tensor( _bias.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2023 ) ); _bias.sync_device(); } template ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc) { auto TensorBias = cute::make_tensor(_bias.host_data(), cute::make_layout(cute::make_shape(cute::_1{}, _N))); return _bias_converter(TensorBias(1, n + n_b)); } bool compare_reference(std::stringstream& error_ss) { error_ss << "PerColumnBias = \n" << _bias.host_view() << "\n\n"; return true; } Arguments get_arguments() { return {_bias.device_data()}; } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// EVT - Column Broadcast template < typename Gemm, typename ElementBias_=void > class HostColBroadcast: public HostEVTNodeBase { public: using Base = HostEVTNodeBase; using ElementBias = std::conditional_t, typename Base::ElementC, ElementBias_>; using TestBedImpl = typename Base::TestBedImpl; using ElementCompute = typename Base::ElementCompute; using LayoutTagVector = cutlass::layout::PackedVectorLayout; struct Arguments { ElementBias const* ptr_row = nullptr; ElementBias null_default = ElementBias(0); cute::Stride dRow = {}; }; private: cutlass::NumericConverter _bias_converter; cutlass::HostTensor _bias; int _M; TestBedImpl impl_; public: HostColBroadcast(){} template HostColBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) : Base(check_relative_equality), impl_(impl) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); _M = cute::get<0>(problem_shape_MNKL); _bias.resize(cutlass::Coord<1>(_M)); EXPECT_TRUE( detail::initialize_tensor( _bias.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2023 ) ); _bias.sync_device(); } template ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc) { auto 
TensorBias = cute::make_tensor(_bias.host_data(), cute::make_layout(cute::make_shape(_M, cute::_1{}))); return _bias_converter(TensorBias(m + m_b, 1)); } bool compare_reference(std::stringstream& error_ss) { error_ss << "PerRowBias = \n" << _bias.host_view() << "\n\n"; return true; } Arguments get_arguments() { return {_bias.device_data()}; } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// EVT - Aux Load template < typename Gemm, bool isC=false, typename ElementAuxLoad_=void, typename LayoutTagAux_=void > class HostAuxLoad: public HostEVTNodeBase { public: using ElementAuxLoad = std::conditional_t, typename HostEVTNodeBase::ElementC, ElementAuxLoad_>; using LayoutTagAux = std::conditional_t, typename HostEVTNodeBase::LayoutTagC, LayoutTagAux_>; using Base = HostEVTNodeBase; using TestBedImpl = typename Base::TestBedImpl; using ElementCompute = typename Base::ElementCompute; using StrideAux = cutlass::gemm::TagToStrideC_t; struct Arguments_Aux { ElementAuxLoad const *ptr_aux = nullptr; ElementAuxLoad null_default = ElementAuxLoad(0); StrideAux dAux = {}; }; struct Arguments_C {}; using Arguments = cute::conditional_t; private: cutlass::NumericConverter _aux_load_converter; cutlass::HostTensor _tensor_aux_load; int _M, _N, _L; TestBedImpl impl_; StrideAux _stride_aux; public: HostAuxLoad(){} template HostAuxLoad(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) : Base(check_relative_equality), impl_(impl) { auto problem_shape_NMKL = cute::append<4>(problem_size, 1); auto [_M, _N, K, _L] = problem_shape_NMKL; auto aux_coord = cutlass::make_Coord(_M * _L, _N); _tensor_aux_load.resize( aux_coord, cutlass::layout::Affine2Layout_Factory::layout_factory( aux_coord, typename LayoutTagAux::Stride() ) ); EXPECT_TRUE( detail::initialize_tensor( _tensor_aux_load.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2023 ) ); _tensor_aux_load.sync_device(); 
_stride_aux = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(_M, _N, _L)); } template ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc) { auto TensorAuxLoad = cute::make_tensor(_tensor_aux_load.host_data(), cute::make_layout(cute::make_shape(_M, _N, _L), _stride_aux)); return _aux_load_converter(TensorAuxLoad(m + m_b, n + n_b, l)); } bool compare_reference(std::stringstream& error_ss) { if constexpr (!isC) { error_ss << "AuxLoad = \n" << _tensor_aux_load.host_view()<< "\n\n"; } return true; } void* get_tensor_C_ptr() { if constexpr (isC) { return static_cast(_tensor_aux_load.device_data()); } else { return nullptr; } } Arguments get_arguments() { if constexpr (isC) return {}; else return {_tensor_aux_load.device_data(), ElementAuxLoad(0), _stride_aux}; } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// EVT - Compute template T* findNonNullPtr(T* first_ptr) { return first_ptr; } template T* findNonNullPtr(T* first_ptr, Args... args) { if (first_ptr) { return first_ptr; } return findNonNullPtr(args...); } template < typename Gemm, template class ComputeOp_ > class HostCompute: public HostEVTNodeBase { public: using Base = HostEVTNodeBase; using ElementCompute = typename Base::ElementCompute; using ComputeOp = ComputeOp_; struct Arguments { struct OpArgs {} op; }; private: ComputeOp _op; public: HostCompute(){} template HostCompute(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): Base(check_relative_equality) { } template ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc, Args... 
frg_inputs) { return _op(frg_inputs...); } Arguments get_arguments(){ return {}; } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// EVT - Unary Compute template < typename Gemm, template class ComputeOp_, typename Child0 > class HostUnaryCompute: public HostEVTNodeBase { public: using Base = HostEVTNodeBase; using ElementCompute = typename Base::ElementCompute; using ComputeOp = ComputeOp_; struct Arguments { typename Child0::Arguments child_0_args; struct OpArgs {} op; }; private: ComputeOp _op; Child0 _child_0; public: HostUnaryCompute(){} template HostUnaryCompute(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): Base(check_relative_equality), _child_0(problem_size, impl, check_relative_equality) { } template ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc) { ElementCompute child_0_result = _child_0.visit(m, n, l, m_b, n_b, acc); return _op(child_0_result); } Arguments get_arguments(){ return { _child_0.get_arguments(), {}, }; } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// EVT - Aux Store template < typename Gemm, bool isD=false, class ElementAuxStore_=void, typename LayoutTagAux_=void > class HostAuxStore: public HostEVTNodeBase { public: using ElementAuxStore = std::conditional_t, typename HostEVTNodeBase::ElementD, ElementAuxStore_>; using LayoutTagAux = std::conditional_t, typename HostEVTNodeBase::LayoutTagD, LayoutTagAux_>; using Base = HostEVTNodeBase; using TestBedImpl = typename Base::TestBedImpl; using ElementCompute = typename Base::ElementCompute; using StrideAux = cutlass::gemm::TagToStrideC_t; struct Arguments_Aux { struct OpArgs { ElementAuxStore* ptr_aux = nullptr; StrideAux dAux = {}; } op; }; struct Arguments_D {}; using Arguments = cute::conditional_t; private: cutlass::NumericConverter destination_converter; cutlass::HostTensor _tensor_aux_store; 
cutlass::HostTensor _reference_aux_store; int _M, _N, _L; TestBedImpl impl_; StrideAux _stride_aux; public: HostAuxStore(){} template HostAuxStore(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): Base(check_relative_equality), impl_(impl) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [_M, _N, K, _L] = problem_shape_MNKL; auto aux_coord = cutlass::make_Coord(_M * _L, _N); _tensor_aux_store.resize( aux_coord, cutlass::layout::Affine2Layout_Factory::layout_factory( aux_coord, typename LayoutTagAux::Stride() ) ); _reference_aux_store.resize( aux_coord, cutlass::layout::Affine2Layout_Factory::layout_factory( aux_coord, typename LayoutTagAux::Stride() ) ); _tensor_aux_store.sync_device(); _stride_aux = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(_M, _N, _L)); } template ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc, ElementCompute child_0_result) { auto TensorAuxStore = cute::make_tensor(static_cast(_reference_aux_store.host_data()), cute::make_layout(cute::make_shape(_M, _N, _L), _stride_aux)); TensorAuxStore(m + m_b, n + n_b, l) = destination_converter(child_0_result); return child_0_result; } bool compare_reference(std::stringstream& error_ss) { // Verify the store node _tensor_aux_store.sync_host(); bool equal = this->equality_check(_reference_aux_store.host_view(), _tensor_aux_store.host_view()); if (!equal) { error_ss << "\n\nReference =\n" << _reference_aux_store.host_view() << "\n\nComputed =\n" << _tensor_aux_store.host_view() << "\n\n"; } return equal; } void* get_tensor_D_ptr() { if constexpr (isD) return static_cast(_tensor_aux_store.device_data()); else return nullptr; } Arguments get_arguments() { if constexpr (isD) { return {}; } else { return {_tensor_aux_store.device_data(), _stride_aux}; } } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// EVT - Row Reduce template < 
typename Gemm, template class ReduceFn, typename ElementReduce > class HostRowReduce: public HostEVTNodeBase { public: using Base = HostEVTNodeBase; using TestBedImpl = typename Base::TestBedImpl; using ElementCompute = typename Base::ElementCompute; using ElementOutput = typename Base::ElementD; using LayoutTagVector = cutlass::layout::PackedVectorLayout; struct Arguments { struct OpArgs { ElementReduce* ptr_row = nullptr; ElementCompute reduce_identity = 0; cute::Stride dRow = {}; } op; }; private: cutlass::NumericConverter destination_converter; cutlass::HostTensor _tensor_row_reduce; cutlass::HostTensor _reduce_buffer; cutlass::HostTensor _reference_row_reduce; int _N; TestBedImpl impl_; ReduceFn reduce_fn; public: HostRowReduce(){} template HostRowReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): Base(check_relative_equality), impl_(impl) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); _N = cute::get<1>(problem_shape_MNKL); _tensor_row_reduce.resize(cutlass::Coord<1>(_N)); _reference_row_reduce.resize(cutlass::Coord<1>(_N)); _reduce_buffer.resize(cutlass::Coord<1>(_N)); _tensor_row_reduce.sync_device(); } template ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc, ElementCompute child_0_result) { auto TensorRowReduce = cute::make_tensor(_reduce_buffer.host_data(), cute::make_layout(cute::make_shape(cute::_1{}, _N))); TensorRowReduce(1, n + n_b) = reduce_fn(TensorRowReduce(1, n + n_b), child_0_result); return child_0_result; } bool compare_reference(std::stringstream& error_ss) { // Verify the store node _tensor_row_reduce.sync_host(); auto TensorRowReduce = cute::make_tensor(_reference_row_reduce.host_data(), cute::make_layout(cute::make_shape(cute::_1{}, _N))); auto TensorReduceBuffer = cute::make_tensor(_reduce_buffer.host_data(), cute::make_layout(cute::make_shape(cute::_1{}, _N))); // Filling the reference tensor with the reduce buffer for (int n = 0; n < 
_N; n ++) { TensorRowReduce(1, n) = destination_converter(TensorReduceBuffer(1, n)); } bool equal = this->equality_check(_reference_row_reduce.host_view(), _tensor_row_reduce.host_view()); if (!equal) { error_ss << "\n\nRow Reduce Reference =\n" << _reference_row_reduce.host_view() << "\n\nRow Reduce Computed =\n" << _tensor_row_reduce.host_view() << "\n\n"; } return equal; } Arguments get_arguments() { return {_tensor_row_reduce.device_data()}; } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// EVT - Column Reduce template < typename Gemm, template class ReduceFn, typename ElementReduce > class HostColumnReduce: public HostEVTNodeBase { public: using Base = HostEVTNodeBase; using TestBedImpl = typename Base::TestBedImpl; using ElementCompute = typename Base::ElementCompute; using ElementOutput = typename Base::ElementD; using LayoutTagVector = cutlass::layout::PackedVectorLayout; struct Arguments { struct OpArgs { ElementReduce* ptr_col = nullptr; ElementCompute reduce_identity = 0; cute::Stride dRow = {}; } op; }; private: cutlass::NumericConverter destination_converter; cutlass::HostTensor _tensor_column_reduce; cutlass::HostTensor _reduce_buffer; cutlass::HostTensor _reference_column_reduce; int _M; TestBedImpl impl_; ReduceFn reduce_fn; public: HostColumnReduce(){} template HostColumnReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): Base(check_relative_equality), impl_(impl) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); _M = cute::get<0>(problem_shape_MNKL); _tensor_column_reduce.resize(cutlass::Coord<1>(_M)); _reference_column_reduce.resize(cutlass::Coord<1>(_M)); _reduce_buffer.resize(cutlass::Coord<1>(_M)); _tensor_column_reduce.sync_device(); } template ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc, ElementCompute child_0_result) { auto TensorColReduce = 
cute::make_tensor(_reduce_buffer.host_data(), cute::make_layout(cute::make_shape(_M, cute::_1{}))); TensorColReduce(m + m_b, 1) = reduce_fn(TensorColReduce(m + m_b, 1), child_0_result); return child_0_result; } bool compare_reference(std::stringstream& error_ss) { // Verify the store node _tensor_column_reduce.sync_host(); auto TensorColReduce = cute::make_tensor(_reference_column_reduce.host_data(), cute::make_layout(cute::make_shape(_M, cute::_1{}))); auto TensorReduceBuffer = cute::make_tensor(_reduce_buffer.host_data(), cute::make_layout(cute::make_shape(_M, cute::_1{}))); // Filling the reference tensor with the reduce buffer for (int m = 0; m < _M; m ++) { TensorColReduce(m, 1) = destination_converter(TensorReduceBuffer(m, 1)); } bool equal = this->equality_check(_reference_column_reduce.host_view(), _tensor_column_reduce.host_view()); if (!equal) { error_ss << "\n\nColumn Reduce Reference =\n" << _reference_column_reduce.host_view() << "\n\nColumn Reduce Computed =\n" << _tensor_column_reduce.host_view() << "\n\n"; } return equal; } Arguments get_arguments() { return {_tensor_column_reduce.device_data()}; } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// EVT - Scalar Reduce template < typename Gemm, template class ReduceFn, typename ElementReduce > class HostScalarReduce: public HostEVTNodeBase { public: using Base = HostEVTNodeBase; using TestBedImpl = typename Base::TestBedImpl; using ElementCompute = typename Base::ElementCompute; using ElementOutput = typename Base::ElementD; using LayoutTagVector = cutlass::layout::PackedVectorLayout; struct Arguments { struct OpArgs { ElementReduce* ptr_scalar = nullptr; ElementCompute reduce_identity = 0; cute::Stride dScalar = {}; } op; }; private: cutlass::NumericConverter destination_converter; cutlass::HostTensor _tensor_scalar_reduce; cutlass::HostTensor _reduce_buffer; cutlass::HostTensor _reference_scalar_reduce; ReduceFn reduce_fn; TestBedImpl impl_; 
public: HostScalarReduce(){} template HostScalarReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): Base(check_relative_equality), impl_(impl) { _tensor_scalar_reduce.resize(cutlass::Coord<1>(1)); _reference_scalar_reduce.resize(cutlass::Coord<1>(1)); _reduce_buffer.resize(cutlass::Coord<1>(1)); _tensor_scalar_reduce.sync_device(); } template ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc, ElementCompute child_0_result) { auto TensorRowReduce = cute::make_tensor(_reduce_buffer.host_data(), cute::make_layout(cute::make_shape(cute::_1{}))); TensorRowReduce(0) = reduce_fn(TensorRowReduce(0), child_0_result); return child_0_result; } bool compare_reference(std::stringstream& error_ss) { // Verify the store node _tensor_scalar_reduce.sync_host(); auto TensorRowReduce = cute::make_tensor(_reference_scalar_reduce.host_data(), cute::make_layout(cute::make_shape(cute::_1{}))); auto TensorReduceBuffer = cute::make_tensor(_reduce_buffer.host_data(), cute::make_layout(cute::make_shape(cute::_1{}))); // Filling the reference tensor with the reduce buffer TensorRowReduce(0) = destination_converter(TensorReduceBuffer(0)); bool equal = this->equality_check(_reference_scalar_reduce.host_view(), _tensor_scalar_reduce.host_view()); if (!equal) { error_ss << "\n\nScalar Reduce Reference =\n" << _reference_scalar_reduce.host_view() << "\n\nScalar Reduce Computed =\n" << _tensor_scalar_reduce.host_view() << "\n\n"; } return equal; } Arguments get_arguments() { return {_tensor_scalar_reduce.device_data()}; } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// Host EVT wrapper /// The ArgumentPack is used to model the alignment when num ops <= 4 template struct ArgumentPack; template struct ArgumentPack { T arg; ArgumentPack(T first): arg(first) {} }; template struct ArgumentPack { First arg; ArgumentPack rest_args; ArgumentPack(First first, 
Rest... rest) : arg(first), rest_args(rest...) {} }; /// Base class for Host Visitor template struct HostVisitorBase: public HostEVTNodeBase { public: using Base = HostEVTNodeBase; using ElementCompute = typename Base::ElementCompute; using Arguments_struct = ArgumentPack; using Arguments_tuple = cute::tuple; constexpr static int Rm1 = sizeof...(Ops); constexpr static bool cond = Rm1 > 4; using Arguments = cute::conditional_t; std::tuple ops; HostVisitorBase(){} template HostVisitorBase(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) :Base(check_relative_equality), ops(test::gemm::device::tapply(std::tuple{}, [&] (auto&& op) { using Op = cute::remove_cvref_t; return Op(problem_size, impl, check_relative_equality); }, [] (auto&&... _ops) { return std::make_tuple(_ops...); }, cute::make_seq{} )){ } bool compare_reference(std::stringstream& error_ss) { return cute::detail::tapply(ops, [&](auto& op) { return op.compare_reference(error_ss); }, [&] (auto&&... inputs) { return arrayAnd(inputs...); }, cute::make_seq{} ); } void* get_tensor_C_ptr() { return cute::detail::tapply(ops, [&](auto& op) { return op.get_tensor_C_ptr(); }, [&] (auto&&... inputs) { return findNonNullPtr(inputs...); }, cute::make_seq{} ); } void* get_tensor_D_ptr() { return cute::detail::tapply(ops, [&](auto& op) { return op.get_tensor_D_ptr(); }, [&] (auto&&... inputs) { return findNonNullPtr(inputs...); }, cute::make_seq{} ); } Arguments get_arguments() { return test::gemm::device::tapply(ops, [&](auto& op) { return op.get_arguments(); }, [&] (auto&&... args) { if constexpr (Rm1 > 4) { return cute::make_tuple(args...); } else { return Arguments(args...); } }, cute::make_seq{} ); } bool arrayAnd(bool passed) { return passed; } template bool arrayAnd(bool first_passed, Args... 
passed) { if (first_passed) { return arrayAnd(passed...); } return first_passed; } }; /// Tree-struct visitor template struct HostTreeVisitor: public HostVisitorBase { public: using Gemm = typename NodeOp::Base::Gemm; using Base = HostVisitorBase; using ElementCompute = typename Base::ElementCompute; using Arguments = typename Base::Arguments; constexpr static int Rm1 = sizeof...(ChildOps); HostTreeVisitor(){} template HostTreeVisitor(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) :Base(problem_size, impl, check_relative_equality){ } template ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc) { return cute::detail::tapply(this->ops, [&] (auto& op) { return op.visit(m, n, l, m_b, n_b, acc); }, [&] (auto&&... frg_inputs) { return std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); }, cute::make_seq{} ); } }; /// General Graph visitor template struct HostTopoVisitor: public HostVisitorBase { public: using Base = HostVisitorBase; using ElementCompute = typename Base::ElementCompute; constexpr static int Rm1 = Base::Rm1; using Arguments = typename Base::Arguments; private: ElementCompute frg_outputs[Rm1]; public: HostTopoVisitor(){} template HostTopoVisitor(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) :Base(problem_size, impl, check_relative_equality) { } template ElementCompute visit_( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc) { frg_outputs[I] = cute::transform_apply(cute::get(EdgeTuple{}), [&] (auto&& _E) { constexpr int e = cute::remove_cvref_t::value; return frg_outputs[e]; }, [&] (auto const&... 
frg_inputs) { ElementCompute res = std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); return res; } ); if constexpr (I < Rm1 - 1) { return visit_(m, n, l, m_b, n_b, acc); } else { return frg_outputs[I]; } } template ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc) { return visit_(m, n, l, m_b, n_b, acc); } }; /// SplitTree visitor template struct HostSplitTreeVisitor: public HostVisitorBase { public: using Base = HostVisitorBase; using ElementCompute = typename Base::ElementCompute; using Arguments = typename Base::Arguments; constexpr static int Rm2 = sizeof...(AuxOutTrees); private: ElementCompute frg_input; public: HostSplitTreeVisitor(){} template HostSplitTreeVisitor(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) :Base(problem_size, impl, check_relative_equality) { } template void visitAux( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator frag) { std::get(this->ops).visit(m, n, l, m_b, n_b, frag); if constexpr (I < Rm2 - 1) { return visitAux(m, n, l, m_b, n_b, frag); } else { return; } } template ElementCompute visit( int64_t m, int64_t n, int64_t l, int m_b, int n_b, ElementAccumulator acc) { /// Compute the input tree frg_input = std::get<0>(this->ops).visit(m, n, l, m_b, n_b, acc); /// Compute the aux out tree visitAux(m, n, l, m_b, n_b, frg_input); /// Visit the output tree return std::get(this->ops).visit(m, n, l, m_b, n_b, frg_input); } }; /// Universal testbed for EVT template class Testbed3xEVT { public: // The EVT Module to test using EVTModule = typename EVT::EVTModule; using TestBedImpl = typename detail::TestbedImpl; using Kernel = typename Gemm::GemmKernel; using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; using ElementAccumulator = typename Kernel::ElementAccumulator; using ElementC = typename Kernel::ElementC; using ElementD = typename Kernel::ElementD; using ProblemShapeType = typename Kernel::ProblemShape; 
using LayoutTagA = typename TestBedImpl::LayoutTagA; using LayoutTagB = typename TestBedImpl::LayoutTagB; using LayoutTagC = typename TestBedImpl::LayoutTagC; using LayoutTagD = typename TestBedImpl::LayoutTagD; // // Methods // Testbed3xEVT( bool check_relative_equality_, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, uint64_t seed_ = TestBedImpl::kDefaultSeed ) : impl_((check_relative_equality_ ? CheckEquality::RELATIVE : CheckEquality::EXACT), ScalarLoc::ON_DEVICE, VectorBeta::ENABLED, init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), check_relative_equality(check_relative_equality_) { } Testbed3xEVT( cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, uint64_t seed_ = TestBedImpl::kDefaultSeed ) : impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorBeta::ENABLED, init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), check_relative_equality(false) { } Testbed3xEVT( typename LayoutTagA::Stride stride_factor_A_, typename LayoutTagB::Stride stride_factor_B_, typename LayoutTagC::Stride stride_factor_C_, typename LayoutTagD::Stride stride_factor_D_, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, uint64_t seed_ = TestBedImpl::kDefaultSeed ) : impl_(stride_factor_A_, stride_factor_B_, stride_factor_C_, stride_factor_D_, CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorBeta::ENABLED, init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), 
check_relative_equality(false) { } /// Initializes data structures void initialize(ProblemShapeType problem_size) { // // Allocate the GEMM workspace for A/B tensor // impl_.initialize(problem_size); } // Detail Implementation TestBedImpl impl_; // Whether to use relative equality checks bool check_relative_equality; bool verify(ProblemShapeType problem_size, EVTModule& host_reference) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto M = cute::get<0>(problem_shape_MNKL); auto N = cute::get<1>(problem_shape_MNKL); auto K = cute::get<2>(problem_shape_MNKL); auto L = cute::get<3>(problem_shape_MNKL); auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d); cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; /// Reference Kernel static int constexpr kBlockM = 64; static int constexpr kBlockN = 64; #if defined(_OPENMP) #pragma omp parallel for collapse(3) #endif for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { ElementAccumulator acc[kBlockM][kBlockN]; gett_mainloop(mainloop_params, m, n, l, acc); /// Epilogue EVT for (int n_b = 0; n_b < kBlockN; ++n_b) { for (int m_b = 0; m_b < kBlockM; ++m_b) { if (m + m_b < cute::size<0>(LayoutD) && n + n_b < cute::size<1>(LayoutD)) { host_reference.visit(m, n, l, m_b, n_b, acc[m_b][n_b]); } } } } } } std::stringstream error_ss; bool passed = host_reference.compare_reference(error_ss); if (!passed) { std::stringstream fname; fname << "error_Gemm_device_" << 
M << "x" << N << "x" << K << "x" << L << "_" << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; std::ofstream file(fname.str()); file << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L << "\n\n"; file << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view() << "\nC =\n" << impl_.collective_epilogue.tensor_C.host_view() << "\n\n"; file << error_ss.str(); } return passed; } bool run( ProblemShapeType problem_size, bool profiling = false, int iterations = 20, int splits = 1) { // Fail test if insufficient CUDA device if (!impl_.sufficient()) { std::cout << "Test failed due to insufficient CUDA device." << std::endl; return false; } // // Initialize the Gemm operator // typename Gemm::Arguments arguments; cutlass::KernelHardwareInfo hw_info; hw_info.device_id = 0; if (not profiling) { impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); hw_info.sm_count = impl_.sm_count; } else { impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); hw_info.sm_count = impl_.sm_count; } typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; if constexpr (cute::is_same_v) { scheduler_args = { splits }; } /// Initializes data structures /// A/B/C/D Tensor initialize(problem_size); /// Initialize the epilogue arguments EVTModule host_reference(problem_size, impl_, check_relative_equality); arguments = typename Gemm::Arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, { impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b }, { // Epilogue arguments {}, // thread 
static_cast(host_reference.get_tensor_C_ptr()), impl_.collective_epilogue.stride_c, static_cast(host_reference.get_tensor_D_ptr()), impl_.collective_epilogue.stride_d }, // Epilogue arguments end hw_info, scheduler_args }; // Filling in the thread arguments typename EVTModule::Arguments epilogue_args = host_reference.get_arguments(); std::memcpy(&arguments.epilogue.thread, &epilogue_args.arg, sizeof(epilogue_args.arg)); Gemm gemm_op; size_t workspace_size = Gemm::get_workspace_size(arguments); cutlass::device_memory::allocation workspace(workspace_size); cutlass::Status status = gemm_op.can_implement(arguments); if (status != cutlass::Status::kSuccess) { cudaError_t error = cudaGetLastError(); std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; return true; } // // Run the GEMM // if (profiling) { return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); } else { cudaError_t result; status = gemm_op.initialize(arguments, workspace.get()); status = gemm_op.run(); result = cudaDeviceSynchronize(); if (result != cudaSuccess) { EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; return false; } } EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); // // Verify // bool passed = this->verify(problem_size, host_reference); if (!passed) { std::cout << "Error : Failed \n"; } return passed; } }; template bool TestAllEVT(bool check_relative_equality=false) { using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; if constexpr (cute::is_same_v) { problem_size_m.push_back(768); problem_size_n.push_back(768); } constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); std::vector problem_size_k = 
{max_alignment, TileShapeK * (Stages + 1) - max_alignment}; Testbed3xEVT testbed(check_relative_equality); bool passed = true; for (int m : problem_size_m) { for (int n : problem_size_n) { for (int k : problem_size_k) { ProblemShapeType problem_size; if constexpr (cute::rank(ProblemShapeType{}) == 4) { problem_size = ProblemShapeType{m, n, k, /* l */ 1}; } else { problem_size = ProblemShapeType{m, n, k}; } passed = testbed.run(problem_size); if (!passed) { return false; } } } } // if we do support batched GEMM, just run one test on it to save on test time if constexpr (cute::rank(ProblemShapeType{}) == 4) { auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; passed = testbed.run( problem_size ); if (!passed) { return false; } } return passed; } } // namespace device } // namespace gemm } // namespace test /////////////////////////////////////////////////////////////////////////////////////////////////