/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/dispatch_policy.hpp"

#include "cute/tensor.hpp"

#include "gather_tensor.hpp"

namespace cutlass::gemm::kernel {

///////////////////////////////////////////////////////////////////////////////
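// Kernel-layer GEMM that fuses a gather of the A and/or B operands into the
// mainloop's global memory loads. GatherA_ and GatherB_ are index-mapping
// functors (for example, a lookup into a row-index array) that make_gather_tensor
// from gather_tensor.hpp composes into the layout of the corresponding operand,
// so the collective mainloop reads the gathered rows/columns directly.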
template <
  class ProblemShape_,
  class CollectiveMainloop_,
  class CollectiveEpilogue_,
  class GatherA_,
  class GatherB_,
  class TileScheduler_ = void
>
class GemmGather {
public:
  //
  // Type Aliases
  //
  using ProblemShape = ProblemShape_;
  using TileSchedulerTag = TileScheduler_;
  using TileScheduler = TileScheduler_;
  static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
    "ProblemShape{} should be <M,N,K> or <M,N,K,L>");

  // Mainloop derived types
  using CollectiveMainloop = CollectiveMainloop_;
  using TileShape = typename CollectiveMainloop::TileShape;
  using TiledMma  = typename CollectiveMainloop::TiledMma;
  using ArchTag   = typename CollectiveMainloop::ArchTag;
  using ElementA  = typename CollectiveMainloop::ElementA;
  using StrideA   = typename CollectiveMainloop::StrideA;
  using ElementB  = typename CollectiveMainloop::ElementB;
  using StrideB   = typename CollectiveMainloop::StrideB;
  using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
  using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
  using MainloopArguments = typename CollectiveMainloop::Arguments;
  using MainloopParams = typename CollectiveMainloop::Params;

  // Epilogue derived types
  using CollectiveEpilogue = CollectiveEpilogue_;
  using ElementC = typename CollectiveEpilogue::ElementC;
  using StrideC  = typename CollectiveEpilogue::StrideC;
  using ElementD = typename CollectiveEpilogue::ElementD;
  using StrideD  = typename CollectiveEpilogue::StrideD;
  using EpilogueArguments = typename CollectiveEpilogue::Arguments;
  using EpilogueParams = typename CollectiveEpilogue::Params;
  static_assert(std::is_same_v<ElementAccumulator, typename CollectiveEpilogue::ElementAccumulator>,
    "Mainloop and epilogue do not agree on accumulator value type.");

  using GatherA = GatherA_;
  using GatherB = GatherB_;

  static constexpr int SharedStorageSize = static_cast<int>(cute::max(
      sizeof(typename CollectiveMainloop::SharedStorage),
      sizeof(typename CollectiveEpilogue::SharedStorage)));

  static constexpr uint32_t MaxThreadsPerBlock = cute::size(TiledMma{});
  static constexpr uint32_t MinBlocksPerMultiprocessor = 1;

  // Device side arguments
  struct Arguments {
    GemmUniversalMode mode{};
    ProblemShape problem_shape{};
    MainloopArguments mainloop{};
    EpilogueArguments epilogue{};
    KernelHardwareInfo hw_info{};
    GatherA gather_A{};
    GatherB gather_B{};
  };

  // Kernel entry point API
  struct Params {
    GemmUniversalMode mode;
    ProblemShape problem_shape;
    MainloopParams mainloop;
    EpilogueParams epilogue;
    GatherA gather_A{};
    GatherB gather_B{};
  };

  //
  // Methods
  //

  // Convert to underlying arguments.
  static Params
  to_underlying_arguments(Arguments const& args, void* workspace) {
    (void) workspace;
    return {
      args.mode,
      args.problem_shape,
      CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
      CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
      args.gather_A,
      args.gather_B
    };
  }

  static Status
  initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    return Status::kSuccess;
  }

  static bool
  can_implement(Arguments const& args) {
    return args.mode == GemmUniversalMode::kGemm or
          (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
  }

  static int
  get_workspace_size(Arguments const& args) {
    return 0;
  }

  static constexpr dim3
  get_grid_shape(Params const& params) {
    int batch_count = 1;
    if constexpr (rank(ProblemShape{}) == 4) {
      batch_count = cute::size<3>(params.problem_shape);
    }

    return dim3(
      cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
      cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
      batch_count
    );
  }

  static constexpr dim3
  get_block_shape() {
    return dim3(MaxThreadsPerBlock, 1, 1);
  }

  CUTLASS_DEVICE
  void
  operator()(Params const& params, char* smem_buf) {
    using namespace cute;
    using X = Underscore;

    // Preconditions
    CUTE_STATIC_ASSERT(is_static<TileShape>::value);

    // Separate out problem shape for convenience
    // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
    auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
    auto M = get<0>(problem_shape_MNKL);
    auto N = get<1>(problem_shape_MNKL);
    auto K = get<2>(problem_shape_MNKL);
    auto L = get<3>(problem_shape_MNKL);

    // Preconditions
    static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
    static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
    static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
    static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");

    // Get the appropriate blocks for this thread block -- potential for thread block locality
    int thread_idx = int(threadIdx.x);
    auto blk_shape = TileShape{};                                                                // (BLK_M,BLK_N,BLK_K)
    auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);
    auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord);                              // (m,n,k,l)
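    // The tensors built below wrap the raw A/B pointers in layouts whose M/N index is
    // remapped through the gather functors carried in Params (see make_gather_tensor in
    // gather_tensor.hpp), so the tiling and loads that follow are written exactly as for
    // a plain GEMM while actually reading the gathered rows/columns.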
    // Represent the full tensors
    Tensor mA_mkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA, params.gather_A);  // (m,k,l)
    Tensor mB_nkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB, params.gather_B);  // (n,k,l)

    // Get batch slice
    Tensor mA_mk = mA_mkl(_,_,l_coord);                                                          // (m,k)
    Tensor mB_nk = mB_nkl(_,_,l_coord);                                                          // (n,k)

    // Slice to get the tiles this thread block is responsible for
    Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{});       // (BLK_M,BLK_K,k)
    Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{});       // (BLK_N,BLK_K,k)

    // Compute tile residues for predication
    auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl);                                 // M - BLK_M * m_coord
    auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl);                                 // N - BLK_N * n_coord
    auto k_residue   = K - size<1>(gA) * size<2>(gA);                                            // K - BLK_K * k_coord_max
    auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);

    // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape
    TiledMma tiled_mma;
    Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape));                 // (MMA,MMA_M,MMA_N)
    clear(accumulators);

    auto k_tile_iter  = cute::make_coord_iterator(shape<2>(gA));
    int  k_tile_count = size<2>(gA);

    // Perform the collective scoped MMA
    CollectiveMainloop collective_mma;
    collective_mma(
      accumulators,
      gA,
      gB,
      accumulators,
      k_tile_iter, k_tile_count,
      residue_mnk,
      thread_idx,
      smem_buf
    );

    // Epilogue and write to gD
    CollectiveEpilogue epilogue{params.epilogue};
    epilogue(
      problem_shape_MNKL,
      blk_shape,
      blk_coord_mnkl,
      accumulators,
      tiled_mma,
      residue_mnk,
      thread_idx,
      smem_buf
    );
  }
};

///////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::kernel
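///////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative only). The collective types and the IndexedGather
// functor below are assumptions standing in for whatever the surrounding example
// builds with CUTLASS's CollectiveBuilder and defines in gather_tensor.hpp; only
// the GemmGather API itself comes from this header.
//
//   using Kernel = cutlass::gemm::kernel::GemmGather<
//       cute::Shape<int, int, int, int>,   // ProblemShape (M, N, K, L)
//       CollectiveMainloop,                // e.g. built via cutlass::gemm::collective::CollectiveBuilder
//       CollectiveEpilogue,                // e.g. built via cutlass::epilogue::collective::CollectiveBuilder
//       IndexedGather,                     // gather functor for A (assumed index-array lookup)
//       IndexedGather>;                    // gather functor for B
//
//   Kernel::Arguments args{
//     cutlass::gemm::GemmUniversalMode::kGemm,
//     {M, N, K, L},
//     mainloop_args, epilogue_args, hw_info,
//     gather_A, gather_B
//   };
//
//   auto params = Kernel::to_underlying_arguments(args, /*workspace=*/nullptr);
//   dim3 grid   = Kernel::get_grid_shape(params);
//   dim3 block  = Kernel::get_block_shape();
//   cutlass::device_kernel<Kernel><<<grid, block, Kernel::SharedStorageSize, stream>>>(params);
//
// A real launch may also need cudaFuncSetAttribute to raise the dynamic shared
// memory limit when SharedStorageSize exceeds the default 48 KB.
//
///////////////////////////////////////////////////////////////////////////////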