cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
*/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_base.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////
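/// Adapter exposing the standard device-level GEMM interface (can_implement(),
/// get_workspace_size(), initialize(), run()) for an arbitrary universal GEMM kernel.
/// The variants named in the file-level brief correspond to the modes of
/// cutlass::gemm::GemmUniversalMode: kGemm (regular GEMM / serial split-K reduction),
/// kGemmSplitKParallel (parallel reduction), kBatched (batched strided), and
/// kArray (batched array of pointers).
///
/// A minimal usage sketch follows. It is illustrative only: `MyGemmKernel` is a
/// placeholder for a kernel type composed elsewhere (for example with
/// cutlass::gemm::kernel::DefaultGemmUniversal), and the Arguments constructor is
/// shown with the typical parameter order for GemmUniversal-style kernels
/// (mode, problem size, batch count, epilogue params, pointers, batch strides,
/// leading dimensions).
///
///   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<MyGemmKernel>;
///
///   Gemm::Arguments args(
///     cutlass::gemm::GemmUniversalMode::kGemm,  // mode
///     {M, N, K},                                // problem size
///     1,                                        // batch count / split-K slices
///     {alpha, beta},                            // epilogue parameters
///     ptr_A, ptr_B, ptr_C, ptr_D,               // operand pointers
///     0, 0, 0, 0,                               // batch strides (unused for kGemm)
///     lda, ldb, ldc, ldd);                      // leading dimensions
///
///   Gemm gemm_op;
///   if (Gemm::can_implement(args) == cutlass::Status::kSuccess) {
///     size_t workspace_bytes = Gemm::get_workspace_size(args);
///     // ... allocate workspace_bytes of device memory and pass it below ...
///     gemm_op.initialize(args, workspace_ptr, stream);
///     gemm_op.run(stream);                      // or simply gemm_op(stream)
///   }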
template <typename GemmKernel_>
class GemmUniversalAdapter {
public:

  using GemmKernel = GemmKernel_;

  static bool const kInternalTranspose =
    platform::is_same<typename GemmKernel::LayoutC, cutlass::layout::RowMajor>::value;

  using ThreadblockShape = typename GemmKernel::Mma::Shape;
  using WarpShape = typename GemmKernel::WarpShape;
  using InstructionShape = typename GemmKernel::InstructionShape;

  // warp-level, arch-level (instruction), math operator
  using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator;
  using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
  using MathOperator = typename WarpMmaOperator::MathOperator;

  // Operator class and arch tag extract bottom-up
  // set it for top-level gemm device-level template
  using OperatorClass = typename WarpMmaOperator::OperatorClass;
  using ArchTag = typename WarpMmaOperator::ArchTag;

  // Type, layout, and complex transform deliberately exchanged with B
  using MapArguments = kernel::detail::MapArguments<
    typename GemmKernel::ElementA,
    typename GemmKernel::LayoutA,
    GemmKernel::kTransformA,
    GemmKernel::kAlignmentA,
    typename GemmKernel::ElementB,
    typename GemmKernel::LayoutB,
    GemmKernel::kTransformB,
    GemmKernel::kAlignmentB,
    typename GemmKernel::LayoutC,
    kInternalTranspose
  >;
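
  // Note: when kInternalTranspose is true (i.e. the kernel's LayoutC is row-major),
  // MapArguments exchanges the element type, layout, complex transform, and alignment
  // of A with those of B and transposes LayoutC, because C^T = (A * B)^T = B^T * A^T.
  // The aliases below therefore describe the problem as seen by callers of the adapter,
  // while the underlying kernel runs the transposed problem.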
  using ElementA = typename MapArguments::ElementA;
  using LayoutA = typename MapArguments::LayoutA;
  static ComplexTransform const kTransformA = MapArguments::kTransformA;
  static int const kAlignmentA = MapArguments::kAlignmentA;

  using ElementB = typename MapArguments::ElementB;
  using LayoutB = typename MapArguments::LayoutB;
  static ComplexTransform const kTransformB = MapArguments::kTransformB;
  static int const kAlignmentB = MapArguments::kAlignmentB;

  using ElementC = typename GemmKernel::ElementC;
  using LayoutC = typename MapArguments::LayoutC;
  static int const kAlignmentC = GemmKernel::kAlignmentC;

  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;

  static int const kStages = GemmKernel::Mma::kStages;

  using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
  using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
  using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;

  using UnderlyingOperator = GemmUniversalBase<GemmKernel>;
  using Arguments = typename UnderlyingOperator::Arguments;
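
  // The adapter delegates all host-side work to GemmUniversalBase<GemmKernel>; its
  // only state is the wrapped operator below, and every method simply forwards after
  // remapping the arguments.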
private:

  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM.
  GemmUniversalAdapter() { }

  /// Helper to construct a transposed equivalent for the underlying GEMM operator
  static Arguments to_underlying_arguments(Arguments const &args) {
    if (kInternalTranspose) {
      return args.transposed_problem();
    }
    else {
      return args;
    }
  }
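
  // Illustrative example (assuming the usual semantics of Arguments::transposed_problem()):
  // a caller-facing problem {M, N, K} with operands (A, B) becomes the kernel-facing
  // problem {N, M, K} with operands (B, A), so the kernel effectively computes C^T = B^T * A^T.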

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {
    return UnderlyingOperator::can_implement(to_underlying_arguments(args));
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {
    return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
  }

  /// Computes the grid shape
  static dim3 get_grid_shape(Arguments const &args) {
    return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {
    return UnderlyingOperator::maximum_active_blocks(smem_capacity);
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
    return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {
    return underlying_operator_.update(to_underlying_arguments(args), workspace);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {
    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Initializes the GEMM state from arguments, then runs the kernel.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////