203 lines
7.3 KiB
C++
203 lines
7.3 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
/*!
  \file
  \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
  batched array variants.
*/
|
|
|
|
#pragma once
|
|
|
|
#include "cutlass/cutlass.h"
|
|
#include "cutlass/gemm/device/gemm_universal_base.h"
|
|
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace cutlass {
|
|
namespace gemm {
|
|
namespace device {
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename GemmKernel_>
|
|
class GemmUniversalAdapter {
|
|
public:
|
|
|
|
using GemmKernel = GemmKernel_;
|
|
|
|
static bool const kInternalTranspose =
|
|
platform::is_same<typename GemmKernel::LayoutC, cutlass::layout::RowMajor>::value;
|
|
|
|
using ThreadblockShape = typename GemmKernel::Mma::Shape;
|
|
using WarpShape = typename GemmKernel::WarpShape;
|
|
using InstructionShape = typename GemmKernel::InstructionShape;
|
|
|
|
// warp-level, arch-level (instruction), math operator
|
|
using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator;
|
|
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
|
|
using MathOperator = typename WarpMmaOperator::MathOperator;
|
|
|
|
// Operator class and arch tag extract bottom-up
|
|
// set it for top-level gemm device-level template
|
|
using OperatorClass = typename WarpMmaOperator::OperatorClass;
|
|
using ArchTag = typename WarpMmaOperator::ArchTag;
|
|
|
|
// Type, layout, and complex transform deliberately exchanged with B
|
|
using MapArguments = kernel::detail::MapArguments<
|
|
typename GemmKernel::ElementA,
|
|
typename GemmKernel::LayoutA,
|
|
GemmKernel::kTransformA,
|
|
GemmKernel::kAlignmentA,
|
|
typename GemmKernel::ElementB,
|
|
typename GemmKernel::LayoutB,
|
|
GemmKernel::kTransformB,
|
|
GemmKernel::kAlignmentB,
|
|
typename GemmKernel::LayoutC,
|
|
kInternalTranspose
|
|
>;
|
|
|
|
using ElementA = typename MapArguments::ElementA;
|
|
using LayoutA = typename MapArguments::LayoutA;
|
|
static ComplexTransform const kTransformA = MapArguments::kTransformA;
|
|
static int const kAlignmentA = MapArguments::kAlignmentA;
|
|
|
|
using ElementB = typename MapArguments::ElementB;
|
|
using LayoutB = typename MapArguments::LayoutB;
|
|
static ComplexTransform const kTransformB = MapArguments::kTransformB;
|
|
static int const kAlignmentB = MapArguments::kAlignmentB;
|
|
|
|
using ElementC = typename GemmKernel::ElementC;
|
|
using LayoutC = typename MapArguments::LayoutC;
|
|
static int const kAlignmentC = GemmKernel::kAlignmentC;
|
|
|
|
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
|
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
|
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
|
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
|
|
|
static int const kStages = GemmKernel::Mma::kStages;
|
|
|
|
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
|
|
using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
|
|
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
|
|
|
|
using UnderlyingOperator = GemmUniversalBase<GemmKernel>;
|
|
using Arguments = typename UnderlyingOperator::Arguments;
|
|
|
|
private:
|
|
|
|
UnderlyingOperator underlying_operator_;
|
|
|
|
public:
|
|
|
|
/// Constructs the GEMM.
|
|
GemmUniversalAdapter() { }
|
|
|
|
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
|
static Arguments to_underlying_arguments(Arguments const &args) {
|
|
if (kInternalTranspose) {
|
|
return args.transposed_problem();
|
|
}
|
|
else {
|
|
return args;
|
|
}
|
|
}
|
|
|
|
/// Determines whether the GEMM can execute the given problem.
|
|
static Status can_implement(Arguments const &args) {
|
|
|
|
return UnderlyingOperator::can_implement(to_underlying_arguments(args));
|
|
}
|
|
|
|
/// Gets the workspace size
|
|
static size_t get_workspace_size(Arguments const &args) {
|
|
|
|
return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
|
|
}
|
|
|
|
/// Computes the grid shape
|
|
static dim3 get_grid_shape(Arguments const &args) {
|
|
return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
|
|
}
|
|
|
|
/// Computes the maximum number of active blocks per multiprocessor
|
|
static int maximum_active_blocks(int smem_capacity = -1) {
|
|
return UnderlyingOperator::maximum_active_blocks(smem_capacity);
|
|
}
|
|
|
|
/// Initializes GEMM state from arguments.
|
|
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
|
|
|
|
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
|
|
}
|
|
|
|
/// Lightweight update given a subset of arguments
|
|
Status update(Arguments const &args, void *workspace = nullptr) {
|
|
|
|
return underlying_operator_.update(to_underlying_arguments(args), workspace);
|
|
}
|
|
|
|
/// Runs the kernel using initialized state.
|
|
Status run(cudaStream_t stream = nullptr) {
|
|
|
|
return underlying_operator_.run(stream);
|
|
}
|
|
|
|
/// Runs the kernel using initialized state.
|
|
Status operator()(cudaStream_t stream = nullptr) {
|
|
return run(stream);
|
|
}
|
|
|
|
/// Runs the kernel using initialized state.
|
|
Status operator()(
|
|
Arguments const &args,
|
|
void *workspace = nullptr,
|
|
cudaStream_t stream = nullptr) {
|
|
|
|
Status status = initialize(args, workspace, stream);
|
|
|
|
if (status == Status::kSuccess) {
|
|
status = run(stream);
|
|
}
|
|
|
|
return status;
|
|
}
|
|
};
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace device
|
|
} // namespace gemm
|
|
} // namespace cutlass
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|