/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Functor performing elementwise operations used by epilogues.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/numeric/int.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

using namespace cute;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies an elementwise operation to all elements within the fragment
/// and writes them out to destination storage.
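///
/// The "transposed" variant differs from a default epilogue only in how it interprets the
/// C/D strides: it swaps their M and N modes before constructing the global-memory views,
/// so a value written at logical coordinate (m,n,l) lands at the address the caller's
/// strides assign to (n,m,l).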
template <
  class StrideC_,
  class StrideD_,
  class ThreadEpilogueOp_
>
class DefaultTransposedEpilogue {
public:
  //
  // Type Aliases
  //

  // derived types of output thread level operator
  using ThreadEpilogueOp = ThreadEpilogueOp_;
  using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
  using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
  using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
  using ElementScalar = ElementCompute;
  using ElementC = typename ThreadEpilogueOp::ElementC;
  using StrideC = StrideC_;
  using ElementD = typename ThreadEpilogueOp::ElementD;
  using StrideD = StrideD_;

  static const int kOutputAlignment = ThreadEpilogueOp::kCount;
  using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;

  static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
  static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");

  struct SharedStorage { };

  // Params of epilogue::collective contain the epilogue::thread params
  struct Params {
    ElementC const* ptr_C = nullptr;
    StrideC dC{};
    ElementD* ptr_D = nullptr;
    StrideD dD{};
    typename ThreadEpilogueOp::Params thread_params{};
  };

  //
  // Methods
  //

  template <class Args>
  static constexpr Params
  to_underlying_arguments(Args const& args, void* workspace) {
    (void) workspace;
    return {args.epilogue_params};
  }

  CUTLASS_HOST_DEVICE
  DefaultTransposedEpilogue(Params const& params_) : params(params_) { }

  template<
    class ProblemShapeMNKL,
    class BlockShapeMNK,
    class BlockCoordMNKL,
    class FrgEngine, class FrgLayout,
    class TiledMma,
    class ResidueMNK
  >
  CUTLASS_HOST_DEVICE
  void
  operator()(
      ProblemShapeMNKL problem_shape_mnkl,
      BlockShapeMNK blk_shape_MNK,
      BlockCoordMNKL blk_coord_mnkl,
      cute::Tensor<FrgEngine, FrgLayout> const& accumulators,
      TiledMma tiled_mma,
      ResidueMNK residue_mnk,
      int thread_idx,
      char* smem_buf) {
    using namespace cute;
    using X = Underscore;

    static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
    static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
    static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
    static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");

    (void) smem_buf;
    ThreadEpilogueOp epilogue_op{params.thread_params};

    // Separate out problem shape for convenience
    auto M = get<0>(problem_shape_mnkl);
    auto N = get<1>(problem_shape_mnkl);
    auto L = get<3>(problem_shape_mnkl);

    // Transpose stride C/D.
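    // Swapping the M and N modes of the C/D strides reinterprets the (M,N,L) output as its
    // transpose without moving data: element (m,n,l) of the swapped view resolves to the
    // address that the caller-provided strides assign to (n,m,l).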
    auto stride_c = make_stride(get<1>(params.dC), get<0>(params.dC), get<2>(params.dC));
    auto stride_d = make_stride(get<1>(params.dD), get<0>(params.dD), get<2>(params.dD));

    // Represent the full output tensor
    Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c);    // (m,n,l)
    Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d);    // (m,n,l)
    Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{});   // (BLK_M,BLK_N,m,n,l)
    Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{});   // (BLK_M,BLK_N,m,n,l)

    // Slice to get the tile this CTA is responsible for
    auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
    Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord);                                          // (BLK_M,BLK_N)
    Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord);                                          // (BLK_M,BLK_N)

    // Partition source and destination tiles to match the accumulator partitioning
    auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
    Tensor tCgD = thr_mma.partition_C(gD);                                                    // (VEC,THR_M,THR_N)
    Tensor tCgC = thr_mma.partition_C(gC);                                                    // (VEC,THR_M,THR_N)

    static_assert(is_static<FrgLayout>::value, "Accumulator layout must be static");
    CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD),
        "Source and destination must have the same number of elements.");
    CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators),
        "Accumulator count must equal the destination element count.");

    auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
    Tensor tCcD = thr_mma.partition_C(cD);

    // source is needed
    if (epilogue_op.is_source_needed()) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(accumulators); ++i) {
        if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
          tCgD(i) = epilogue_op(accumulators(i), tCgC(i));
        }
      }
    }
    // source is not needed, avoid load
    else {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(accumulators); ++i) {
        if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
          tCgD(i) = epilogue_op(accumulators(i));
        }
      }
    }
  }

private:
  Params params;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace collective
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
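
// A minimal usage sketch (illustrative, not part of this header's interface): composing the
// collective epilogue above with a thread-level linear-combination functor from
// "cutlass/epilogue/thread/linear_combination.h". The element types and stride choices below
// are assumptions for the sake of the example, not requirements.
//
//   using ThreadOp = cutlass::epilogue::thread::LinearCombination<
//       float,    // ElementOutput
//       1,        // kCount: elements accessed per operation
//       float,    // ElementAccumulator
//       float>;   // ElementCompute
//
//   // Rank-3 (M,N,L) strides for C and D, e.g. row-major per batch. The epilogue swaps the
//   // first two stride modes internally to produce the transposed output.
//   using StrideC = cute::Stride<int64_t, cute::Int<1>, int64_t>;
//   using StrideD = cute::Stride<int64_t, cute::Int<1>, int64_t>;
//
//   using Epilogue = cutlass::epilogue::collective::DefaultTransposedEpilogue<
//       StrideC, StrideD, ThreadOp>;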