/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided,
      and batched array variants.
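
    \note GemmUniversalAdapter wraps a GemmKernel and forwards all work to GemmUniversalBase.
      When the kernel's LayoutC is row-major (see kInternalTranspose below), the adapter maps the
      problem to its transpose: operands A and B, together with their layouts, alignments, and
      complex transforms, are exchanged through kernel::detail::MapArguments, and
      to_underlying_arguments() rewrites the Arguments via transposed_problem(), so the underlying
      operator effectively always sees a column-major output.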
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_base.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmKernel_>
class GemmUniversalAdapter {
public:

  using GemmKernel = GemmKernel_;

  static bool const kInternalTranspose =
    platform::is_same<typename GemmKernel::LayoutC, cutlass::layout::RowMajor>::value;

  using ThreadblockShape = typename GemmKernel::Mma::Shape;
  using WarpShape = typename GemmKernel::WarpShape;
  using InstructionShape = typename GemmKernel::InstructionShape;

  // warp-level, arch-level (instruction), math operator
  using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator;
  using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
  using MathOperator = typename WarpMmaOperator::MathOperator;

  // Operator class and arch tag are extracted bottom-up
  // and set for the top-level device-level GEMM template
  using OperatorClass = typename WarpMmaOperator::OperatorClass;
  using ArchTag = typename WarpMmaOperator::ArchTag;

  // Type, layout, and complex transform deliberately exchanged with B
  using MapArguments = kernel::detail::MapArguments<
    typename GemmKernel::ElementA,
    typename GemmKernel::LayoutA,
    GemmKernel::kTransformA,
    GemmKernel::kAlignmentA,
    typename GemmKernel::ElementB,
    typename GemmKernel::LayoutB,
    GemmKernel::kTransformB,
    GemmKernel::kAlignmentB,
    typename GemmKernel::LayoutC,
    kInternalTranspose
  >;

  using ElementA = typename MapArguments::ElementA;
  using LayoutA = typename MapArguments::LayoutA;
  static ComplexTransform const kTransformA = MapArguments::kTransformA;
  static int const kAlignmentA = MapArguments::kAlignmentA;

  using ElementB = typename MapArguments::ElementB;
  using LayoutB = typename MapArguments::LayoutB;
  static ComplexTransform const kTransformB = MapArguments::kTransformB;
  static int const kAlignmentB = MapArguments::kAlignmentB;

  using ElementC = typename GemmKernel::ElementC;
  using LayoutC = typename MapArguments::LayoutC;
  static int const kAlignmentC = GemmKernel::kAlignmentC;

  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;

  static int const kStages = GemmKernel::Mma::kStages;

  using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
  using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
  using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;

  using UnderlyingOperator = GemmUniversalBase<GemmKernel>;
  using Arguments = typename UnderlyingOperator::Arguments;

private:

  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM.
  GemmUniversalAdapter() { }

  /// Helper to construct a transposed equivalent for the underlying GEMM operator
  static Arguments to_underlying_arguments(Arguments const &args) {
    if (kInternalTranspose) {
      return args.transposed_problem();
    }
    else {
      return args;
    }
  }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    return UnderlyingOperator::can_implement(to_underlying_arguments(args));
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
  }

  /// Computes the grid shape
  static dim3 get_grid_shape(Arguments const &args) {
    return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {
    return UnderlyingOperator::maximum_active_blocks(smem_capacity);
  }

  /// Initializes GEMM state from arguments.
  Status initialize(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    return underlying_operator_.update(to_underlying_arguments(args), workspace);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Initializes the GEMM state from arguments, then runs the kernel.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
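
// Usage sketch: the commented sequence below illustrates the typical host-side flow through
// GemmUniversalAdapter. `MyGemmKernel` is a placeholder for a kernel type composed elsewhere
// (for example with cutlass::gemm::kernel::DefaultGemmUniversal), `stream` is assumed to be an
// existing cudaStream_t, and the Arguments constructor parameters are elided because they depend
// on that kernel.
//
//   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<MyGemmKernel>;
//
//   Gemm gemm_op;
//   typename Gemm::Arguments args(/* mode, problem size, pointers, strides, epilogue params... */);
//
//   cutlass::Status status = Gemm::can_implement(args);
//   if (status == cutlass::Status::kSuccess) {
//
//     // Allocate any workspace the underlying operator requires (e.g. for split-K reductions).
//     void *workspace = nullptr;
//     size_t workspace_bytes = Gemm::get_workspace_size(args);
//     if (workspace_bytes) {
//       cudaMalloc(&workspace, workspace_bytes);
//     }
//
//     status = gemm_op.initialize(args, workspace, stream);
//     if (status == cutlass::Status::kSuccess) {
//       status = gemm_op.run(stream);   // or, in one call: gemm_op(args, workspace, stream)
//     }
//   }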