/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Host reference and operations for Sm90 EVT unit test
*/

#pragma once

#include "gemm_testbed_3x_evt.hpp"

//////////////////////////////////////////////////////////////////////////////

/// Host references used for testing
namespace test::gemm::device {

template <class... Ops>
using HEVT = HostTreeVisitor<Ops...>;

template <class... Ops>
using HDAG = HostTopoVisitor<Ops...>;

template <class... Ops>
using HST = HostSplitTreeVisitor<Ops...>;

/// D = alpha * acc + beta * C + AuxLoad
template <class Gemm>
class HostEVTAuxLoad {
public:
  using ScalarAlpha = HostScalarBroadcast;
  using AccFetchNode = HostAccumulator;
  using AuxLoadNode = HostAuxLoad;
  // alpha * acc + aux
  using TernaryCompute0 = HEVT<HostCompute, ScalarAlpha, AccFetchNode, AuxLoadNode>;
  using ScalarBeta = HostScalarBroadcast;
  using CLoadNode = HostAuxLoad;
  // beta * C + (alpha * acc + aux)
  using TernaryCompute1 = HEVT<HostCompute, ScalarBeta, CLoadNode, TernaryCompute0>;
  using EVTModule = HEVT<HostAuxStore, TernaryCompute1>;
};

/// D = alpha * acc + beta * C + per-column bias
template <class Gemm>
class HostPerColBias {
public:
  using ScalarAlpha = HostScalarBroadcast;
  using AccFetchNode = HostAccumulator;
  using RowBroadcastNode = HostRowBroadcast;
  // alpha * acc + bias
  using TernaryCompute0 = HEVT<HostCompute, ScalarAlpha, AccFetchNode, RowBroadcastNode>;
  using ScalarBeta = HostScalarBroadcast;
  using CLoadNode = HostAuxLoad;
  // beta * C + (alpha * acc + bias)
  using TernaryCompute1 = HEVT<HostCompute, ScalarBeta, CLoadNode, TernaryCompute0>;
  using EVTModule = HEVT<HostAuxStore, TernaryCompute1>;
};

/// D = beta * C + Graph(relu(alpha * acc + aux) + aux)
/// Testing EVT - DAG structure
template <class Gemm>
class HostEVTDAG {
public:
  using ScalarAlpha = HostScalarBroadcast;
  using AccFetchNode = HostAccumulator;
  using AuxLoadNode = HostAuxLoad;
  using DAGNode = HDAG<
    Gemm,
    cute::tuple<
      cute::tuple<>,   // 0. alpha
      cute::tuple<>,   // 1. acc
      cute::tuple<>,   // 2. aux load
      cute::tuple,     // 3. alpha * acc + aux load
      cute::tuple,     //    relu(alpha * acc + aux load)
      cute::tuple      //    relu(alpha * acc + aux load) + aux load
    >,
    ScalarAlpha,
    AccFetchNode,
    AuxLoadNode,
    HostCompute,
    HostCompute,
    HostCompute
  >;
  using ScalarBeta = HostScalarBroadcast;
  using CLoadNode = HostAuxLoad;
  // beta * C + DAG
  using TernaryCompute1 = HEVT<HostCompute, ScalarBeta, CLoadNode, DAGNode>;
  using EVTModule = HEVT<HostAuxStore, TernaryCompute1>;
};
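// Illustrative sketch (not used by the testbed above): the elementwise math that
// HostEVTAuxLoad and HostPerColBias encode, written out as a plain loop. The
// function name, the row-major pointer interface, and the use of float throughout
// are assumptions made purely for illustration; the real host references walk the
// visitor trees defined in gemm_testbed_3x_evt.hpp instead.
inline void reference_lincomb_aux_sketch(
    int M, int N,
    float alpha, float const* acc,  // M x N accumulator, row-major
    float beta,  float const* C,    // M x N source tensor, row-major
    float const* aux,               // M x N auxiliary input (or a broadcast per-column bias)
    float* D) {                     // M x N output
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      int idx = m * N + n;
      // Innermost tree first, mirroring TernaryCompute0 then TernaryCompute1 above
      D[idx] = beta * C[idx] + (alpha * acc[idx] + aux[idx]);
    }
  }
}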
/// EVT = alpha * acc + C
/// D = Graph(maximum(EVT + per-row bias, EVT))
/// Testing DAG - EVT
template <class Gemm>
class HostDAGEVT {
public:
  using EVTNode = HEVT<
    HostAuxStore,
    HEVT<
      HostCompute,
      HostScalarBroadcast,
      HostAccumulator,
      HostAuxLoad
    >
  >;
  using EVTModule = HEVT<
    HostAuxStore,
    HDAG<
      Gemm,
      cute::tuple<
        cute::tuple<>,   // 0. EVT
        cute::tuple<>,   // 1. per-row bias
        cute::tuple,     // 2. EVT + per-row bias
        cute::tuple      // 3. maximum(EVT + per-row bias, EVT)
      >,
      EVTNode,
      HostColBroadcast,
      HostCompute,
      HostCompute
    >
  >;
};

/// Xreduce(alpha * acc + beta * C)
template <class Gemm, template <class, class> class ReduceOp>
class HostReduce {
public:
  using ScalarAlpha = HostScalarBroadcast;
  using AccFetchNode = HostAccumulator;
  // alpha * acc
  using BinaryCompute0 = HEVT<HostCompute, ScalarAlpha, AccFetchNode>;
  using ScalarBeta = HostScalarBroadcast;
  using CLoadNode = HostAuxLoad;
  // beta * C + alpha * acc
  using TernaryCompute1 = HEVT<HostCompute, ScalarBeta, CLoadNode, BinaryCompute0>;
  using ReduceNode = HEVT<ReduceOp, TernaryCompute1>;
  using EVTModule = HEVT<HostAuxStore, ReduceNode>;
};

// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias
// if D is fp8
//   D = scale_d * activation(Z)
// else
//   D = activation(Z)
template <class Gemm, template <class T> class ActivationFn, class ElementD>
class HostScaledLinCombPerRowBiasEltAct {
public:
  using EVTModule = HEVT<
    HostAuxStore,
    HEVT<
      HostCompute::Op>,               // activation(Z) * scale_d
      HEVT<
        HostCompute,                  // activation(Z)
        HEVT<
          HostCompute,
          HostScalarBroadcast,        // scale_c * beta
          HostAuxLoad,                // C
          HEVT<
            HostCompute,
            HostScalarBroadcast,      // scale_a * scale_b * alpha
            HostAccumulator,
            HostColBroadcast,
          >
        >
      >,
      HostScalarBroadcast,            // scale_d
    >
  >;
};

// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
// if D is fp8
//   amax_d = max(abs(elements in activation(Z)))
//   D = scale_d * activation(Z)
// else
//   D = activation(Z)
// if Aux is fp8
//   amax_aux = max(abs(elements in Z))
//   Aux = scale_aux * Z
// else
//   Aux = Z
template <class Gemm, template <class T> class ActivationFn, class ElementD>
class HostScaledLinCombPerRowBiasEltActAmaxAux {
public:
  template <class T>
  using amax = cutlass::maximum_absolute_value_reduction<T, true>;

  using EVTModule = HEVT<
    HostAuxStore,
    HST<Gemm,
      // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
      HEVT<
        HostCompute,
        HostScalarBroadcast,          // scale_c * beta
        HostAuxLoad,                  // C
        HEVT<
          HostCompute,
          HostScalarBroadcast,        // scale_a * scale_b * alpha
          HostAccumulator,
          HostColBroadcast,
        >
      >,
      // D = activation(Z) * scale_d, amax_d = max(abs(elements in D))
      HEVT<
        HostCompute::Op>,
        HEVT<
          HostScalarReduce,
          HEVT<
            HostCompute,              // activation(Z) * scale_d
            HostAccumulator,          // Z
          >
        >,
        HostScalarBroadcast,          // scale_d
      >,
      // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux))
      HEVT<
        HostAuxStore,
        HEVT<
          HostCompute::Op>,
          HEVT<
            HostScalarReduce,
            HostAccumulator
          >,
          HostScalarBroadcast
        >
      >
    >
  >;
};
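// Illustrative sketch (not part of the testbed): the per-element math that
// HostScaledLinCombPerRowBiasEltActAmaxAux verifies, for a generic activation
// callable. Everything here (the function name, plain float types, the
// running-amax in/out parameters) is a simplification for illustration only;
// the real reference is assembled from the visitor nodes above.
template <class ActivationFn>
inline void scaled_lincomb_amax_element_sketch(
    ActivationFn activation,
    float scale_a, float scale_b, float scale_c, float scale_d, float scale_aux,
    float alpha, float beta,
    float acc, float c, float bias,
    bool d_is_fp8, bool aux_is_fp8,
    float& d, float& aux, float& amax_d, float& amax_aux) {
  float z = scale_a * scale_b * alpha * acc + scale_c * beta * c + bias;
  float act = activation(z);
  if (d_is_fp8) {
    float abs_act = act < 0.0f ? -act : act;
    amax_d = abs_act > amax_d ? abs_act : amax_d;   // amax over activation(Z)
    d = scale_d * act;
  }
  else {
    d = act;
  }
  if (aux_is_fp8) {
    float abs_z = z < 0.0f ? -z : z;
    amax_aux = abs_z > amax_aux ? abs_z : amax_aux; // amax over Z
    aux = scale_aux * z;
  }
  else {
    aux = z;
  }
}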
} // namespace test::gemm::device

//////////////////////////////////////////////////////////////////////////////

namespace cutlass::epilogue {
namespace fusion {

namespace detail {

template <typename T>
struct maximum_with_default_nan_propagation : maximum<T> {};

} // namespace detail

//////////////////////////////////////////////////////////////////////////////

/// D = alpha * acc + beta * C + AuxLoad
template<
  class EpilogueDescriptor,
  class AuxLoadDescriptor,
  class ElementOutput,
  class ElementCompute,
  class ElementScalar = ElementCompute,
  FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombAuxLoad =
  Sm90EVT<Sm90Compute<cutlass::homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
    Sm90ScalarBroadcast<ElementScalar>, // beta
    Sm90SrcFetch, // C
    Sm90EVT<Sm90Compute<cutlass::homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
      Sm90ScalarBroadcast<ElementScalar>, // alpha
      Sm90AccFetch, // acc
      Sm90AuxLoad<
        AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, typename AuxLoadDescriptor::Element,
        typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom, typename AuxLoadDescriptor::CopyOpS2R
      > // aux load
    >
  >;

//////////////////////////////////////////////////////////////////////////////

/// Example DAG
/// beta * C + Graph(relu(alpha * acc + aux) + aux)
template<
  typename EpilogueDescriptor,
  typename AuxLoadDescriptor,
  class ElementOutput,
  class ElementCompute,
  class ElementScalar = ElementCompute,
  FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombEVTDAG =
  Sm90EVT<Sm90Compute<cutlass::homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + aux)
    Sm90ScalarBroadcast<ElementScalar>, // beta
    Sm90SrcFetch, // C
    Sm90TopologicalVisitor<
      ElementCompute,
      cute::tuple<
        cute::seq<>,        // 0. alpha
        cute::seq<>,        // 1. acc
        cute::seq<>,        // 2. aux load
        cute::seq<1, 0, 2>, // 3. alpha * acc + aux load
        cute::seq<3>,       //    relu(alpha * acc + aux load)
        cute::seq<2, 4>     //    relu(alpha * acc + aux load) + aux load
      >,
      Sm90ScalarBroadcast<ElementScalar>, // alpha
      Sm90AccFetch, // acc
      Sm90AuxLoad<
        AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, typename AuxLoadDescriptor::Element,
        typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom, typename AuxLoadDescriptor::CopyOpS2R>,
      Sm90Compute<cutlass::homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>,
      Sm90Compute<cutlass::epilogue::thread::ReLu, ElementCompute, ElementCompute, RoundStyle>,
      Sm90Compute<cutlass::plus, ElementCompute, ElementCompute, RoundStyle>
    >
  >;

//////////////////////////////////////////////////////////////////////////////

/// Example DAG
/// EVT = alpha * acc + C
/// D = Graph(maximum(EVT + per-row bias, EVT))
template<
  class EpilogueDescriptor,
  class AuxStoreDescriptor,
  class ElementOutput,
  class ElementCompute,
  class ElementBias = ElementOutput,
  class ElementScalar = ElementCompute,
  FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombDAGEVT =
  Sm90TopologicalVisitor<
    ElementCompute,
    cute::tuple<
      cute::seq<>,
      cute::seq<>,
      cute::seq<1, 0>,
      cute::seq<0, 2>
    >,
    Sm90EVT<
      Sm90AuxStore<
        AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, typename AuxStoreDescriptor::Element,
        RoundStyle, typename AuxStoreDescriptor::Stride, typename AuxStoreDescriptor::SmemLayoutAtom,
        typename AuxStoreDescriptor::CopyOpR2S>,
      Sm90EVT<Sm90Compute<cutlass::homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>,
        Sm90ScalarBroadcast<ElementScalar>,
        Sm90AccFetch,
        Sm90SrcFetch
      >
    >,
    Sm90ColBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias>,
    Sm90Compute<cutlass::plus, ElementCompute, ElementCompute, RoundStyle>,
    Sm90Compute<detail::maximum_with_default_nan_propagation, ElementOutput, ElementCompute, RoundStyle>
  >;

//////////////////////////////////////////////////////////////////////////////

/// D = alpha * acc + beta * C + per-column bias
template<
  class EpilogueDescriptor,
  class ElementOutput,
  class ElementCompute,
  class ElementBias = ElementOutput,
  class ElementScalar = ElementCompute,
  FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombPerColumnBias =
  Sm90EVT<Sm90Compute<cutlass::homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
    Sm90ScalarBroadcast<ElementScalar>, // beta
    Sm90SrcFetch, // C
    Sm90EVT<Sm90Compute<cutlass::homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
      Sm90ScalarBroadcast<ElementScalar>, // alpha
      Sm90AccFetch, // acc
      Sm90RowBroadcast<
        ceil_div(
          EpilogueDescriptor::StagesC,
          size(shape_div(take<0, 2>(typename EpilogueDescriptor::TileShape{}), typename EpilogueDescriptor::EpilogueTile{}))
        ) + 1,
        typename EpilogueDescriptor::TileShape,
        ElementBias
      >
    >
  >;
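//////////////////////////////////////////////////////////////////////////////
// Usage note (illustrative only, not part of the original test plumbing).
// These visitor trees are consumed as the custom-EVT / fusion argument of the
// SM90 epilogue CollectiveBuilder; the tile-shape, layout, element, and schedule
// names below are placeholders, and the descriptor types are assumed to come
// from the epilogue collective helpers used by the unit tests. Sketch:
//
//   using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombPerColumnBias<
//     EpilogueDescriptor, ElementD, ElementCompute>;
//
//   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
//     cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
//     TileShape_MNK, ClusterShape_MNK,
//     cutlass::epilogue::collective::EpilogueTileAuto,
//     ElementAccumulator, ElementCompute,
//     ElementC, LayoutC, AlignmentC,
//     ElementD, LayoutD, AlignmentD,
//     EpilogueSchedule,
//     FusionCallbacks
//   >::CollectiveOp;
//
// The matching Host* class from test::gemm::device above then provides the
// reference EVTModule that the testbed compares against.
//////////////////////////////////////////////////////////////////////////////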
//////////////////////////////////////////////////////////////////////////////

/// D = per-column reduce(alpha * acc + beta * C)
template<
  template <class> class RegReduceFn,
  template <class> class GmemReduceFn,
  class ElementReduce,
  class CtaTileShapeMNK,
  class ElementOutput,
  class ElementCompute,
  class ElementScalar = ElementCompute,
  FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombPerColumnReduce =
  Sm90EVT<Sm90RowReduction<RegReduceFn, GmemReduceFn, CtaTileShapeMNK, ElementReduce, ElementCompute, RoundStyle>, // per-column reduce
    Sm90EVT<Sm90Compute<cutlass::homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
      Sm90ScalarBroadcast<ElementScalar>, // beta
      Sm90SrcFetch, // C
      Sm90EVT<Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
        Sm90ScalarBroadcast<ElementScalar>, // alpha
        Sm90AccFetch // acc
      >
    >
  >;

//////////////////////////////////////////////////////////////////////////////

/// D = per-row reduce(alpha * acc + beta * C)
template<
  template <class> class RegReduceFn,
  template <class> class GmemReduceFn,
  class ElementReduce,
  class CtaTileShapeMNK,
  class ElementOutput,
  class ElementCompute,
  class ElementScalar = ElementCompute,
  FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombPerRowReduce =
  Sm90EVT<Sm90ColReduction<RegReduceFn, GmemReduceFn, CtaTileShapeMNK, ElementReduce, ElementCompute, RoundStyle>, // per-row reduce
    Sm90EVT<Sm90Compute<cutlass::homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
      Sm90ScalarBroadcast<ElementScalar>, // beta
      Sm90SrcFetch, // C
      Sm90EVT<Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
        Sm90ScalarBroadcast<ElementScalar>, // alpha
        Sm90AccFetch // acc
      >
    >
  >;

//////////////////////////////////////////////////////////////////////////////

/// D = scalar reduce(alpha * acc + beta * C)
template<
  template <class> class RegReduceFn,
  template <class> class GmemReduceFn,
  class ElementReduce,
  class ElementOutput,
  class ElementCompute,
  class ElementScalar = ElementCompute,
  FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombScalarReduce =
  Sm90EVT<Sm90ScalarReduction<RegReduceFn, GmemReduceFn, ElementReduce, ElementCompute, RoundStyle>, // scalar reduce
    Sm90EVT<Sm90Compute<cutlass::homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
      Sm90ScalarBroadcast<ElementScalar>, // beta
      Sm90SrcFetch, // C
      Sm90EVT<Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
        Sm90ScalarBroadcast<ElementScalar>, // alpha
        Sm90AccFetch // acc
      >
    >
  >;

} // namespace fusion
} // namespace cutlass::epilogue
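//////////////////////////////////////////////////////////////////////////////
// Illustrative sketch (not part of the unit test): reference semantics of the
// three reduction fusions above over an M x N row-major output. "Per-column"
// keeps one value per column (reduced over M), "per-row" keeps one value per
// row (reduced over N). The reduction operator is fixed to a running sum here
// purely for illustration; the aliases above take it as a template parameter.
//////////////////////////////////////////////////////////////////////////////
inline void reference_lincomb_reduce_sketch(
    int M, int N,
    float alpha, float const* acc,  // M x N accumulator, row-major
    float beta,  float const* C,    // M x N source tensor, row-major
    float* per_column,              // length N, zero-initialized by the caller
    float* per_row,                 // length M, zero-initialized by the caller
    float* scalar) {                // single value, zero-initialized by the caller
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float d = alpha * acc[m * N + n] + beta * C[m * N + n];
      per_column[n] += d;
      per_row[m]    += d;
      *scalar       += d;
    }
  }
}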