458 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
		
		
			
		
	
	
			458 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
|   | /***************************************************************************************************
 | ||
|  |  * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
|  |  * SPDX-License-Identifier: BSD-3-Clause | ||
|  |  * | ||
|  |  * Redistribution and use in source and binary forms, with or without | ||
|  |  * modification, are permitted provided that the following conditions are met: | ||
|  |  * | ||
|  |  * 1. Redistributions of source code must retain the above copyright notice, this | ||
|  |  * list of conditions and the following disclaimer. | ||
|  |  * | ||
|  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | ||
|  |  * this list of conditions and the following disclaimer in the documentation | ||
|  |  * and/or other materials provided with the distribution. | ||
|  |  * | ||
|  |  * 3. Neither the name of the copyright holder nor the names of its | ||
|  |  * contributors may be used to endorse or promote products derived from | ||
|  |  * this software without specific prior written permission. | ||
|  |  * | ||
|  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
|  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
|  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
|  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
|  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
|  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
|  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
|  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
|  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
|  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|  |  * | ||
|  |  **************************************************************************************************/ | ||
|  | /*! \file
 | ||
|  |     \brief Host reference and operations for Sm90 EVT unit test  | ||
|  | */ | ||
|  | #pragma once
 | ||
|  | #include "gemm_testbed_3x_evt.hpp"
 | ||
|  | 
 | ||
|  | //////////////////////////////////////////////////////////////////////////////
 | ||
|  | /// Host references used for testing
 | ||
|  | namespace test::gemm::device { | ||
|  | template<class Gemm, class NodeOp, class ...ChildOp> | ||
|  | using HEVT = HostTreeVisitor<Gemm, NodeOp, ChildOp...>; | ||
|  | 
 | ||
|  | template<class Gemm, class EdgeTuple, class ...Ops> | ||
|  | using HDAG = HostTopoVisitor<Gemm, EdgeTuple, Ops...>; | ||
|  | 
 | ||
|  | template<class Gemm, class InputTree, class OutputTree, class... AuxOutTrees> | ||
|  | using HST = HostSplitTreeVisitor<Gemm, InputTree, OutputTree, AuxOutTrees...>; | ||
|  | 
 | ||
|  | /// D = alpha * acc + beta * C + AuxLoad
 | ||
|  | template<class Gemm, class ElementAux, class LayoutAux> | ||
|  | class HostEVTAuxLoad { | ||
|  | public: | ||
|  |   using ScalarAlpha = HostScalarBroadcast<Gemm, 1>; | ||
|  |   using AccFetchNode = HostAccumulator<Gemm>; | ||
|  |   using AuxLoadNode = HostAuxLoad<Gemm, false, ElementAux, LayoutAux>; | ||
|  |   using TernaryCompute0 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarAlpha, AccFetchNode, AuxLoadNode>; | ||
|  |   using ScalarBeta = HostScalarBroadcast<Gemm, 1>; | ||
|  |   using CLoadNode = HostAuxLoad<Gemm, true>; | ||
|  |   using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarBeta, CLoadNode, TernaryCompute0>; | ||
|  |   using EVTModule = HEVT<HostAuxStore<Gemm, true>, TernaryCompute1>; | ||
|  | }; | ||
|  | 
 | ||
|  | /// D = alpha * acc + beta * C + per-column bias
 | ||
|  | template<class Gemm, class ElementBias> | ||
|  | class HostPerColBias { | ||
|  | public: | ||
|  |   using ScalarAlpha = HostScalarBroadcast<Gemm, 1>; | ||
|  |   using AccFetchNode = HostAccumulator<Gemm>; | ||
|  |   using RowBroadcastNode = HostRowBroadcast<Gemm, ElementBias>; | ||
|  |   using TernaryCompute0 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarAlpha, AccFetchNode, RowBroadcastNode>; | ||
|  |   using ScalarBeta = HostScalarBroadcast<Gemm, 1>; | ||
|  |   using CLoadNode = HostAuxLoad<Gemm, true>; | ||
|  |   using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarBeta, CLoadNode, TernaryCompute0>; | ||
|  |   using EVTModule = HEVT<HostAuxStore<Gemm, true>, TernaryCompute1>; | ||
|  | }; | ||
|  | 
 | ||
|  | /// D = beta * C + Graph(relu(alpha * acc + aux) + aux)
 | ||
|  | /// Testing EVT - DAG structure
 | ||
|  | template<class Gemm> | ||
|  | class HostEVTDAG { | ||
|  | public: | ||
|  |   using ScalarAlpha = HostScalarBroadcast<Gemm, 1>; | ||
|  |   using AccFetchNode = HostAccumulator<Gemm>; | ||
|  |   using AuxLoadNode = HostAuxLoad<Gemm, false, cutlass::half_t, cutlass::layout::RowMajor>; | ||
|  |   using DAGNode = HDAG< | ||
|  |     Gemm, | ||
|  |     cute::tuple< | ||
|  |       cute::tuple<>, // 0. alpha
 | ||
|  |       cute::tuple<>, // 1. acc
 | ||
|  |       cute::tuple<>, // 2. aux load
 | ||
|  |       cute::tuple<cute::_0, cute::_1, cute::_2>, // 3. alpha * acc + aux load
 | ||
|  |       cute::tuple<cute::_3>, // relu(alpha * acc + aux load)
 | ||
|  |       cute::tuple<cute::_2, cute::_4> // relu(alpha * acc + aux load) + aux load
 | ||
|  |     >, | ||
|  |     ScalarAlpha, | ||
|  |     AccFetchNode, | ||
|  |     AuxLoadNode, | ||
|  |     HostCompute<Gemm, cutlass::multiply_add>, | ||
|  |     HostCompute<Gemm, cutlass::epilogue::thread::ReLu>, | ||
|  |     HostCompute<Gemm, cutlass::plus> | ||
|  |   >; | ||
|  |   using ScalarBeta = HostScalarBroadcast<Gemm, 1>; | ||
|  |   using CLoadNode = HostAuxLoad<Gemm, true>; | ||
|  |   using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarBeta, CLoadNode, DAGNode>; | ||
|  |   using EVTModule = HEVT<HostAuxStore<Gemm, true>, TernaryCompute1>; | ||
|  | }; | ||
|  | 
 | ||
|  | /// EVT = alpha * acc + C
 | ||
|  | /// D = Graph(maximum(EVT + per-row bias, EVT))
 | ||
|  | /// Testing DAG - EVT
 | ||
|  | template<class Gemm> | ||
|  | class HostDAGEVT { | ||
|  | public: | ||
|  |   using EVTNode = HEVT< | ||
|  |     HostAuxStore<Gemm, false, cutlass::half_t, cutlass::layout::RowMajor>, | ||
|  |     HEVT< | ||
|  |       HostCompute<Gemm, cutlass::multiply_add>, | ||
|  |       HostScalarBroadcast<Gemm, 2>, | ||
|  |       HostAccumulator<Gemm>, | ||
|  |       HostAuxLoad<Gemm, true> | ||
|  |     > | ||
|  |   >; | ||
|  |   using EVTModule = HEVT< | ||
|  |     HostAuxStore<Gemm, true>, | ||
|  |     HDAG< | ||
|  |       Gemm, | ||
|  |       cute::tuple< | ||
|  |       cute::tuple<>, // 0. EVT
 | ||
|  |       cute::tuple<>, // 1. per-row bias
 | ||
|  |       cute::tuple<cute::_0, cute::_1>, // 2. EVT + per-row bias
 | ||
|  |       cute::tuple<cute::_0, cute::_2> // 3. maximum(EVT + per-row bias, EVT)
 | ||
|  |       >, | ||
|  |       EVTNode, | ||
|  |       HostColBroadcast<Gemm, cutlass::half_t>, | ||
|  |       HostCompute<Gemm, cutlass::plus>, | ||
|  |       HostCompute<Gemm, cutlass::maximum>  | ||
|  |     > | ||
|  |   >; | ||
|  | }; | ||
|  | 
 | ||
|  | /// Xreduce(alpha * acc + beta * C)
 | ||
|  | template<class Gemm, template<class, template <class> class, class> class ReduceOp> | ||
|  | class HostReduce { | ||
|  | public: | ||
|  |   using ScalarAlpha = HostScalarBroadcast<Gemm, 1>; | ||
|  |   using AccFetchNode = HostAccumulator<Gemm>; | ||
|  |   using BinaryCompute0 = HEVT<HostCompute<Gemm, cutlass::multiplies>, ScalarAlpha, AccFetchNode>; | ||
|  |   using ScalarBeta = HostScalarBroadcast<Gemm, 1>; | ||
|  |   using CLoadNode = HostAuxLoad<Gemm, true>; | ||
|  |   using TernaryCompute1 = HEVT<HostCompute<Gemm, cutlass::multiply_add>, ScalarBeta, CLoadNode, BinaryCompute0>; | ||
|  |   using ReduceNode = HEVT<ReduceOp<Gemm, cutlass::plus, float>, TernaryCompute1>; | ||
|  |   using EVTModule = HEVT<HostAuxStore<Gemm, true>, ReduceNode>; | ||
|  | }; | ||
|  | 
 | ||
|  | // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias
 | ||
|  | // if D is fp8 
 | ||
|  | //   D = scale_d * activation(Z)
 | ||
|  | // else
 | ||
|  | //   D = activation(Z)
 | ||
|  | template <class Gemm, template <class> class ActivationFn, class ElementD> | ||
|  | class HostScaledLinCombPerRowBiasEltAct { | ||
|  | public: | ||
|  |   using EVTModule = HEVT< | ||
|  |   HostAuxStore<Gemm, true>, | ||
|  |   HEVT< | ||
|  |     HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>,  // activation(Z) * scaled_d
 | ||
|  |     HEVT< | ||
|  |       HostCompute<Gemm, ActivationFn>, // activation(Z)
 | ||
|  |       HEVT< | ||
|  |         HostCompute<Gemm, cutlass::multiply_add>, | ||
|  |         HostScalarBroadcast<Gemm, 1, 2>, // scale_c * beta
 | ||
|  |         HostAuxLoad<Gemm, true>, // C
 | ||
|  |         HEVT< | ||
|  |           HostCompute<Gemm, cutlass::multiply_add>, | ||
|  |           HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha
 | ||
|  |           HostAccumulator<Gemm>, | ||
|  |           HostColBroadcast<Gemm, ElementD>, | ||
|  |         > | ||
|  |       > | ||
|  |     >, | ||
|  |     HostScalarBroadcast<Gemm, 1>, // scale_d
 | ||
|  |   > | ||
|  |   >; | ||
|  | }; | ||
|  | 
 | ||
|  | // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
 | ||
|  | // if D is fp8 
 | ||
|  | //   amax_d = max(abs(elements in activation(Z)))
 | ||
|  | //   D = scale_d * activation(Z)
 | ||
|  | // else
 | ||
|  | //   D = activation(Z)
 | ||
|  | // if Aux is fp8 
 | ||
|  | //   amax_aux = max(abs(elements in Z))
 | ||
|  | //   Aux = scale_aux * Z
 | ||
|  | // else
 | ||
|  | //   Aux = Z
 | ||
|  | template <class Gemm, template <class> class ActivationFn, class ElementD> | ||
|  | class HostScaledLinCombPerRowBiasEltActAmaxAux { | ||
|  | public: | ||
|  |   template <typename T> | ||
|  |   using amax = cutlass::maximum_absolute_value_reduction<T, true>; | ||
|  |   using EVTModule = HEVT< | ||
|  |     HostAuxStore<Gemm, true>, | ||
|  |     HST<Gemm, | ||
|  |       // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
 | ||
|  |       HEVT< | ||
|  |         HostCompute<Gemm, cutlass::multiply_add>, | ||
|  |         HostScalarBroadcast<Gemm, 1, 2>, // scale_c * beta
 | ||
|  |         HostAuxLoad<Gemm, true>, // C
 | ||
|  |         HEVT< | ||
|  |           HostCompute<Gemm, cutlass::multiply_add>, | ||
|  |           HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha
 | ||
|  |           HostAccumulator<Gemm>, | ||
|  |           HostColBroadcast<Gemm, ElementD>, | ||
|  |         > | ||
|  |       >, | ||
|  |       // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D))
 | ||
|  |       HEVT< | ||
|  |         HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>, | ||
|  |         HEVT< | ||
|  |           HostScalarReduce<Gemm, amax, float>,  | ||
|  |           HEVT< | ||
|  |             HostCompute<Gemm, ActivationFn>, //activation(Z) * scaled_d
 | ||
|  |             HostAccumulator<Gemm>, // Z
 | ||
|  |           > | ||
|  |         >, | ||
|  |         HostScalarBroadcast<Gemm, 1>, // scale_d
 | ||
|  |       >, | ||
|  |       // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux))
 | ||
|  |       HEVT< | ||
|  |         HostAuxStore<Gemm, false, ElementD, cutlass::layout::RowMajor>, | ||
|  |         HEVT< | ||
|  |           HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>, | ||
|  |           HEVT< | ||
|  |             HostScalarReduce<Gemm, amax, float>, | ||
|  |             HostAccumulator<Gemm> | ||
|  |             >, | ||
|  |           HostScalarBroadcast<Gemm, 1> | ||
|  |         > | ||
|  |       > | ||
|  |     > | ||
|  |   >; | ||
|  | }; | ||
|  | } // namespace test::gemm::device
 | ||
|  | 
 | ||
|  | //////////////////////////////////////////////////////////////////////////////
 | ||
|  | namespace cutlass::epilogue { | ||
|  | namespace fusion { | ||
|  | 
 | ||
|  | //////////////////////////////////////////////////////////////////////////////
 | ||
|  | /// D = alpha * acc + beta * C + AuxLoad
 | ||
|  | template< | ||
|  |   class EpilogueDescriptor, | ||
|  |   class AuxLoadDescriptor, | ||
|  |   class ElementOutput, | ||
|  |   class ElementCompute, | ||
|  |   class ElementScalar = ElementCompute, | ||
|  |   FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest | ||
|  | > | ||
|  | using Sm90LinCombAuxLoad = | ||
|  |   Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
 | ||
|  |     Sm90ScalarBroadcast<ElementScalar>, // beta
 | ||
|  |     Sm90SrcFetch, // C
 | ||
|  |     Sm90EVT<Sm90Compute<multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
 | ||
|  |       Sm90ScalarBroadcast<ElementScalar>, // alpha
 | ||
|  |       Sm90AccFetch, // acc
 | ||
|  |       Sm90AuxLoad< | ||
|  |         AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,  | ||
|  |         typename AuxLoadDescriptor::Element,  | ||
|  |         typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom,  | ||
|  |         typename AuxLoadDescriptor::CopyOpS2R // aux load
 | ||
|  |       > | ||
|  |     > | ||
|  |   >; | ||
|  | 
 | ||
|  | 
 | ||
|  | //////////////////////////////////////////////////////////////////////////////
 | ||
|  | /// Example DAG
 | ||
|  | /// beta * C + Graph(alpha * acc + gamma + acc)
 | ||
|  | template< | ||
|  |   typename EpilogueDescriptor, | ||
|  |   typename AuxLoadDescriptor, | ||
|  |   class ElementOutput, | ||
|  |   class ElementCompute, | ||
|  |   class ElementScalar = ElementCompute, | ||
|  |   FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest | ||
|  | > | ||
|  | using Sm90LinCombEVTDAG = | ||
|  |   Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + aux)
 | ||
|  |     Sm90ScalarBroadcast<ElementScalar>, // beta
 | ||
|  |     Sm90SrcFetch, // C
 | ||
|  |     Sm90TopologicalVisitor< | ||
|  |       ElementCompute, | ||
|  |       cute::tuple< | ||
|  |         cute::seq<>, // 0. alpha
 | ||
|  |         cute::seq<>, // 1. acc
 | ||
|  |         cute::seq<>, // 2. aux load
 | ||
|  |         cute::seq<1, 0, 2>, // 3. alpha * acc + aux load
 | ||
|  |         cute::seq<3>, // relu(alpha & acc + aux load)
 | ||
|  |         cute::seq<2, 4> // relu(alpha * acc + aux load) + aux load
 | ||
|  |       >, | ||
|  |       Sm90ScalarBroadcast<ElementScalar>, // alpha
 | ||
|  |       Sm90AccFetch, // acc
 | ||
|  |       Sm90AuxLoad< | ||
|  |         AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,  | ||
|  |         typename AuxLoadDescriptor::Element, typename AuxLoadDescriptor::Stride,  | ||
|  |         typename AuxLoadDescriptor::SmemLayoutAtom, typename AuxLoadDescriptor::CopyOpS2R>, | ||
|  |       Sm90Compute<multiply_add, ElementCompute, ElementCompute, RoundStyle>, | ||
|  |       Sm90Compute<cutlass::epilogue::thread::ReLu, ElementCompute, ElementCompute, RoundStyle>, | ||
|  |       Sm90Compute<plus, ElementCompute, ElementCompute, RoundStyle> | ||
|  |     >   | ||
|  |     >; | ||
|  | 
 | ||
|  | 
 | ||
|  | //////////////////////////////////////////////////////////////////////////////
 | ||
|  | /// Example DAG
 | ||
|  | /// EVT = alpha * acc + C
 | ||
|  | /// D = Graph(maximum(EVT + per-row bias, EVT))
 | ||
|  | template< | ||
|  |   class EpilogueDescriptor, | ||
|  |   class AuxStoreDescriptor, | ||
|  |   class ElementOutput, | ||
|  |   class ElementCompute, | ||
|  |   class ElementBias = ElementOutput, | ||
|  |   class ElementScalar = ElementCompute, | ||
|  |   FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest | ||
|  | > | ||
|  | using Sm90LinCombDAGEVT = | ||
|  |   Sm90TopologicalVisitor< | ||
|  |     ElementCompute, | ||
|  |     cute::tuple< | ||
|  |       cute::seq<>, | ||
|  |       cute::seq<>, | ||
|  |       cute::seq<1, 0>, | ||
|  |       cute::seq<0, 2> | ||
|  |     >, | ||
|  |     Sm90EVT< | ||
|  |       Sm90AuxStore< | ||
|  |         AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,  | ||
|  |         typename AuxStoreDescriptor::Element, RoundStyle, typename AuxStoreDescriptor::Stride, | ||
|  |         typename AuxStoreDescriptor::SmemLayoutAtom, typename AuxStoreDescriptor::CopyOpR2S>, | ||
|  |       Sm90EVT<Sm90Compute<multiply_add, ElementCompute, ElementCompute, RoundStyle>, | ||
|  |         Sm90ScalarBroadcast<ElementScalar>, | ||
|  |         Sm90AccFetch, | ||
|  |         Sm90SrcFetch | ||
|  |       > | ||
|  |     >, | ||
|  |     Sm90ColBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias>, | ||
|  |     Sm90Compute<plus, ElementCompute, ElementCompute, RoundStyle>, | ||
|  |     Sm90Compute<maximum, ElementOutput, ElementCompute, RoundStyle> | ||
|  |   >; | ||
|  | 
 | ||
|  | 
 | ||
|  | //////////////////////////////////////////////////////////////////////////////
 | ||
|  | /// D = alpha * acc + beta * C + per-column bias
 | ||
|  | template< | ||
|  |   class EpilogueDescriptor, | ||
|  |   class ElementOutput, | ||
|  |   class ElementCompute, | ||
|  |   class ElementBias = ElementOutput, | ||
|  |   class ElementScalar = ElementCompute, | ||
|  |   FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest | ||
|  | > | ||
|  | using Sm90LinCombPerColumnBias = | ||
|  |   Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
 | ||
|  |     Sm90ScalarBroadcast<ElementScalar>, // beta
 | ||
|  |     Sm90SrcFetch, // C
 | ||
|  |     Sm90EVT<Sm90Compute<multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
 | ||
|  |       Sm90ScalarBroadcast<ElementScalar>, // alpha
 | ||
|  |       Sm90AccFetch, // acc
 | ||
|  |       Sm90RowBroadcast< | ||
|  |         ceil_div( | ||
|  |           EpilogueDescriptor::StagesC,  | ||
|  |           size(shape_div(take<0, 2>(typename EpilogueDescriptor::TileShape{}), typename EpilogueDescriptor::EpilogueTile{})) | ||
|  |         ) + 1,  | ||
|  |         typename EpilogueDescriptor::TileShape,  | ||
|  |         ElementBias | ||
|  |       > | ||
|  |     > | ||
|  |   >; | ||
|  | 
 | ||
|  | 
 | ||
|  | //////////////////////////////////////////////////////////////////////////////
 | ||
|  | /// D = per-column reduce(alpha * acc + beta * C)
 | ||
|  | template< | ||
|  |   template <class> class RegReduceFn, | ||
|  |   template <class> class GmemReduceFn, | ||
|  |   class ElementReduce,  | ||
|  |   class CtaTileShapeMNK, | ||
|  |   class ElementOutput, | ||
|  |   class ElementCompute, | ||
|  |   class ElementScalar = ElementCompute, | ||
|  |   FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest | ||
|  | > | ||
|  | using Sm90LinCombPerColumnReduce = | ||
|  |   Sm90EVT<Sm90RowReduction<RegReduceFn, GmemReduceFn, 0, CtaTileShapeMNK, ElementReduce, ElementCompute, RoundStyle>, // per column reduce
 | ||
|  |     Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
 | ||
|  |       Sm90ScalarBroadcast<ElementScalar>, // beta
 | ||
|  |       Sm90SrcFetch, // C
 | ||
|  |       Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
 | ||
|  |         Sm90ScalarBroadcast<ElementScalar>, // alpha
 | ||
|  |         Sm90AccFetch // acc
 | ||
|  |       > | ||
|  |     > | ||
|  |   >; | ||
|  | 
 | ||
|  | 
 | ||
|  | //////////////////////////////////////////////////////////////////////////////
 | ||
|  | /// D = per-row reduce(alpha * acc + beta * C)
 | ||
|  | template< | ||
|  |   template <class> class RegReduceFn, | ||
|  |   template <class> class GmemReduceFn, | ||
|  |   class ElementReduce,  | ||
|  |   class CtaTileShapeMNK, | ||
|  |   class ElementOutput, | ||
|  |   class ElementCompute, | ||
|  |   class ElementScalar = ElementCompute, | ||
|  |   FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest | ||
|  | > | ||
|  | using Sm90LinCombPerRowReduce = | ||
|  |   Sm90EVT<Sm90ColReduction<RegReduceFn, GmemReduceFn, 0, CtaTileShapeMNK, ElementReduce, ElementCompute, RoundStyle>, // per column reduce
 | ||
|  |     Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
 | ||
|  |       Sm90ScalarBroadcast<ElementScalar>, // beta
 | ||
|  |       Sm90SrcFetch, // C
 | ||
|  |       Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
 | ||
|  |         Sm90ScalarBroadcast<ElementScalar>, // alpha
 | ||
|  |         Sm90AccFetch // acc
 | ||
|  |       > | ||
|  |     > | ||
|  |   >; | ||
|  | 
 | ||
|  | 
 | ||
|  | //////////////////////////////////////////////////////////////////////////////
 | ||
|  | /// D = scalar reduce(alpha * acc + beta * C)
 | ||
|  | template< | ||
|  |   template <class> class RegReduceFn, | ||
|  |   template <class> class GmemReduceFn, | ||
|  |   class ElementReduce,  | ||
|  |   class ElementOutput, | ||
|  |   class ElementCompute, | ||
|  |   class ElementScalar = ElementCompute, | ||
|  |   FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest | ||
|  | > | ||
|  | using Sm90LinCombScalarReduce = | ||
|  |   Sm90EVT<Sm90ScalarReduction<RegReduceFn, GmemReduceFn, ElementReduce, ElementCompute, RoundStyle>, // per column reduce
 | ||
|  |     Sm90EVT<Sm90Compute<multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + alpha * acc
 | ||
|  |       Sm90ScalarBroadcast<ElementScalar>, // beta
 | ||
|  |       Sm90SrcFetch, // C
 | ||
|  |       Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
 | ||
|  |         Sm90ScalarBroadcast<ElementScalar>, // alpha
 | ||
|  |         Sm90AccFetch // acc
 | ||
|  |       > | ||
|  |     > | ||
|  |   >; | ||
|  | } // namespace fusion
 | ||
|  | 
 | ||
|  | } // namespace cutlass::epilogue
 |