| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | /***************************************************************************************************
 | 
					
						
							| 
									
										
										
										
											2024-01-17 03:37:22 +08:00
										 |  |  |  * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |  * SPDX-License-Identifier: BSD-3-Clause | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  |  * modification, are permitted provided that the following conditions are met: | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  |  * list of conditions and the following disclaimer. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  |  * this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  |  * and/or other materials provided with the distribution. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 3. Neither the name of the copyright holder nor the names of its | 
					
						
							|  |  |  |  * contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  |  * this software without specific prior written permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							|  |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | /*! \file
 | 
					
						
							|  |  |  |   \brief Functor performing elementwise operations used by epilogues. | 
					
						
							|  |  |  | */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/cutlass.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/dispatch_policy.hpp"
 | 
					
						
							|  |  |  | #include "cutlass/epilogue/collective/detail.hpp"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cute/tensor.hpp"
 | 
					
						
							| 
									
										
										
										
											2024-03-20 05:51:04 +08:00
										 |  |  | #include "cute/numeric/numeric_types.hpp"
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | #include "gather_tensor.hpp"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace cutlass::epilogue::collective { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Applies an element wise operation to all elements within the fragment
 | 
					
						
							|  |  |  | /// and scatter-writes them out to destination storage.
 | 
					
						
							|  |  |  | /// GatherC and ScatterD are types of user-defined functions that apply the
 | 
					
						
							|  |  |  | /// transoformation of the strided coordinate (e.g. through an index array).
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |   class StrideC_, | 
					
						
							|  |  |  |   class StrideD_, | 
					
						
							|  |  |  |   class ThreadEpilogueOp_, | 
					
						
							|  |  |  |   class EpilogueSchedule_, | 
					
						
							|  |  |  |   class GatherC_, | 
					
						
							|  |  |  |   class ScatterD_ | 
					
						
							|  |  |  | > | 
					
						
							|  |  |  | class EpilogueGatherScatter { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   // Type Aliases
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   using EpilogueSchedule = EpilogueSchedule_; | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   // derived types of output thread level operator
 | 
					
						
							|  |  |  |   using ThreadEpilogueOp = ThreadEpilogueOp_; | 
					
						
							|  |  |  |   using ElementOutput = typename ThreadEpilogueOp::ElementOutput; | 
					
						
							|  |  |  |   using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; | 
					
						
							|  |  |  |   using ElementCompute = typename ThreadEpilogueOp::ElementCompute; | 
					
						
							|  |  |  |   using ElementScalar = ElementCompute; | 
					
						
							|  |  |  |   using ElementC = typename ThreadEpilogueOp::ElementC; | 
					
						
							|  |  |  |   using StrideC = StrideC_; | 
					
						
							|  |  |  |   using ElementD = typename ThreadEpilogueOp::ElementD; | 
					
						
							|  |  |  |   using StrideD = StrideD_; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Every epilogue needs these two GmemTiledCopy{C,D} aliases.
 | 
					
						
							|  |  |  |   // If you don't know what they should be, just use void.
 | 
					
						
							|  |  |  |   using GmemTiledCopyC = void; | 
					
						
							|  |  |  |   using GmemTiledCopyD = void; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GatherC = GatherC_; | 
					
						
							|  |  |  |   using ScatterD = ScatterD_; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   static const int kOutputAlignment = ThreadEpilogueOp::kCount; | 
					
						
							|  |  |  |   using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-09 03:42:12 +08:00
										 |  |  |   static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); | 
					
						
							|  |  |  |   static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   struct SharedStorage { }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Host side epilogue arguments
 | 
					
						
							|  |  |  |   struct Arguments { | 
					
						
							|  |  |  |     typename ThreadEpilogueOp::Params thread_params{}; | 
					
						
							|  |  |  |     ElementC const* ptr_C = nullptr; | 
					
						
							|  |  |  |     StrideC dC{}; | 
					
						
							|  |  |  |     ElementD* ptr_D = nullptr; | 
					
						
							|  |  |  |     StrideD dD{}; | 
					
						
							|  |  |  |     GatherC gather_C{}; | 
					
						
							|  |  |  |     ScatterD scatter_D{}; | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Device side epilogue params
 | 
					
						
							|  |  |  |   using Params = Arguments; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   // Methods
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   template <class ProblemShape> | 
					
						
							|  |  |  |   static constexpr Params | 
					
						
							|  |  |  |   to_underlying_arguments( | 
					
						
							|  |  |  |       [[maybe_unused]] ProblemShape const& _, | 
					
						
							|  |  |  |       Arguments const& args, | 
					
						
							|  |  |  |       [[maybe_unused]] void* workspace) { | 
					
						
							|  |  |  |     return args; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   template<class ProblemShape> | 
					
						
							| 
									
										
										
										
											2024-07-29 20:46:24 +08:00
										 |  |  |   static bool | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   can_implement( | 
					
						
							|  |  |  |       [[maybe_unused]] ProblemShape const& problem_shape, | 
					
						
							|  |  |  |       [[maybe_unused]] Arguments const& args) { | 
					
						
							|  |  |  |     return true; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTLASS_HOST_DEVICE | 
					
						
							|  |  |  |   EpilogueGatherScatter(Params const& params_) : params(params_) { } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   template< | 
					
						
							|  |  |  |     class ProblemShapeMNKL, | 
					
						
							|  |  |  |     class BlockShapeMNK, | 
					
						
							|  |  |  |     class BlockCoordMNKL, | 
					
						
							|  |  |  |     class FrgEngine, class FrgLayout, | 
					
						
							|  |  |  |     class TiledMma, | 
					
						
							|  |  |  |     class ResidueMNK | 
					
						
							|  |  |  |   > | 
					
						
							|  |  |  |   CUTLASS_DEVICE void | 
					
						
							|  |  |  |   operator()( | 
					
						
							|  |  |  |       ProblemShapeMNKL problem_shape_mnkl, | 
					
						
							|  |  |  |       BlockShapeMNK blk_shape_MNK, | 
					
						
							|  |  |  |       BlockCoordMNKL blk_coord_mnkl, | 
					
						
							|  |  |  |       cute::Tensor<FrgEngine, FrgLayout> const& accumulators, | 
					
						
							|  |  |  |       TiledMma tiled_mma, | 
					
						
							|  |  |  |       ResidueMNK residue_mnk, | 
					
						
							|  |  |  |       int thread_idx, | 
					
						
							|  |  |  |       char* smem_buf) | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     using namespace cute; | 
					
						
							|  |  |  |     using X = Underscore; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-09 03:42:12 +08:00
										 |  |  |     static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static"); | 
					
						
							| 
									
										
										
										
											2023-12-09 03:42:12 +08:00
										 |  |  |     static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); | 
					
						
							|  |  |  |     static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     (void) smem_buf; | 
					
						
							|  |  |  |     ThreadEpilogueOp epilogue_op{params.thread_params}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Separate out problem shape for convenience
 | 
					
						
							|  |  |  |     auto M = get<0>(problem_shape_mnkl); | 
					
						
							|  |  |  |     auto N = get<1>(problem_shape_mnkl); | 
					
						
							|  |  |  |     auto L = get<3>(problem_shape_mnkl); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC); | 
					
						
							|  |  |  |     auto stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Represent the full output tensor
 | 
					
						
							|  |  |  |     Tensor mC_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c, params.gather_C);  // (m,n,l)
 | 
					
						
							|  |  |  |     Tensor mD_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d, params.scatter_D); // (m,n,l)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{});    // (BLK_M,BLK_N,m,n,l)
 | 
					
						
							|  |  |  |     Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{});    // (BLK_M,BLK_N,m,n,l)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Slice to get the tile this CTA is responsible for
 | 
					
						
							|  |  |  |     auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; | 
					
						
							|  |  |  |     Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord);                                                 // (BLK_M,BLK_N)
 | 
					
						
							|  |  |  |     Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord);                                                 // (BLK_M,BLK_N)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Partition source and destination tiles to match the accumulator partitioning
 | 
					
						
							|  |  |  |     auto thr_mma = tiled_mma.get_thread_slice(thread_idx); | 
					
						
							|  |  |  |     Tensor tCgD = thr_mma.partition_C(gD);                                       // (VEC,THR_M,THR_N)
 | 
					
						
							|  |  |  |     Tensor tCgC = thr_mma.partition_C(gC);                                       // (VEC,THR_M,THR_N)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     static_assert(is_static<FrgLayout>::value, "Accumulator layout must be static"); | 
					
						
							|  |  |  |     CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), | 
					
						
							|  |  |  |         "Source and destination must have the same number of elements."); | 
					
						
							|  |  |  |     CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), | 
					
						
							|  |  |  |         "Accumulator count must have the same destination element count."); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Make an identity coordinate tensor for predicating our output MN tile
 | 
					
						
							|  |  |  |     auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); | 
					
						
							|  |  |  |     Tensor tCcD = thr_mma.partition_C(cD); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // source is needed
 | 
					
						
							|  |  |  |     if (epilogue_op.is_source_needed()) { | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |       for (int i = 0; i < size(accumulators); ++i) { | 
					
						
							|  |  |  |         if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { | 
					
						
							|  |  |  |           tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     // source is not needed, avoid load
 | 
					
						
							|  |  |  |     else { | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |       for (int i = 0; i < size(accumulators); ++i) { | 
					
						
							|  |  |  |         if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { | 
					
						
							|  |  |  |           tCgD(i) = epilogue_op(accumulators(i)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | private: | 
					
						
							|  |  |  |   Params params; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace cutlass::epilogue::collective
 | 
					
						
							|  |  |  | 
 |