cutlass/examples/13_fused_two_gemms/threadblock/b2b_mma_base.h

231 lines
7.7 KiB
C
Raw Normal View History

/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Base class template for a threadblock-scoped kernel that fuses two GEMMs, with a configurable number of stages.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Common base for a threadblock-scoped matrix multiply built from two GEMMs,
/// distinguished throughout by the suffixes 0 and 1. It derives the warp-level
/// types and counts for each GEMM, declares the storage for their operands, and
/// owns the warp-scoped operand iterators.
template <
    /// Size of Gemm 0 - concept: gemm::GemmShape<>
    typename Shape0_,
    /// Size of Gemm 1 - concept: gemm::GemmShape<>
    typename Shape1_,
    /// Policy describing tuning details of Gemm 0 (concept: MmaPolicy)
    typename Policy0_,
    /// Policy describing tuning details of Gemm 1 (concept: MmaPolicy)
    typename Policy1_,
    /// Number of stages
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class B2bMmaBase {
 public:
  /// Size of each Gemm problem - concept: gemm::GemmShape<>
  using Shape0 = Shape0_;
  using Shape1 = Shape1_;

  /// Policy describing tuning details of each Gemm
  using Policy0 = Policy0_;
  using Policy1 = Policy1_;

  //
  // Dependent types
  //

  /// Warp-level Mma
  using Operator0 = typename Policy0::Operator;
  using Operator1 = typename Policy1::Operator;

  /// Shape describing the overall GEMM computed from shared memory
  /// by each warp.
  using WarpGemm0 = typename Policy0::Operator::Shape;
  using WarpGemm1 = typename Policy1::Operator::Shape;

  /// Shape describing the number of warps filling the CTA: the threadblock
  /// tile extent divided by the warp tile extent along each of M, N and K.
  using WarpCount0 = GemmShape<Shape0::kM / WarpGemm0::kM,
                               Shape0::kN / WarpGemm0::kN,
                               Shape0::kK / WarpGemm0::kK>;
  using WarpCount1 = GemmShape<Shape1::kM / WarpGemm1::kM,
                               Shape1::kN / WarpGemm1::kN,
                               Shape1::kK / WarpGemm1::kK>;

  /// Number of warp-level GEMM operations: the K extent of the warp tile
  /// divided by the K extent of the operator's MmaShape.
  static int const kWarpGemmIterations0 =
      (WarpGemm0::kK / Operator0::Policy::MmaShape::kK);
  static int const kWarpGemmIterations1 =
      (WarpGemm1::kK / Operator1::Policy::MmaShape::kK);

  /// Number of stages
  static int const kStages = Stages;

  //
  // Nested structs
  //

  /// Shared storage object needed by threadblock-scoped GEMM.
  ///
  /// Parameterized on the threadblock shape and policy so that the same
  /// definition serves both Gemm 0 and Gemm 1.
  template <
      /// Threadblock-level Gemm shape - concept: gemm::GemmShape<>
      typename Shape_,
      /// Policy supplying the warp-level operator and the padding
      typename Policy_>
  class SharedStorage {
   public:
    //
    // Type definitions
    //

    using Shape = Shape_;
    using Policy = Policy_;
    using Operator = typename Policy::Operator;

    /// Tensor reference to the A operand
    using TensorRefA =
        TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

    /// Tensor reference to the B operand
    using TensorRefB =
        TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

    /// Shape of the A matrix operand in shared memory: kM rows by
    /// (kK * kStages) columns, each extended by the policy's padding.
    using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
                               Shape::kK * kStages +
                                   Policy::SmemPaddingA::kColumn>;

    /// Shape of the B matrix operand in shared memory: (kK * kStages) rows by
    /// kN columns, each extended by the policy's padding.
    using ShapeB =
        MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
                    Shape::kN + Policy::SmemPaddingB::kColumn>;

    //
    // Data members
    //

    /// Buffer for A operand
    AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

    /// Buffer for B operand
    AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

    //
    // Methods
    //

    /// Returns a layout object for the A matrix.
    ///
    /// Host-device, like LayoutB(): operand_A_ref() below is host-device and
    /// calls this function, so it must not be restricted to device code.
    CUTLASS_HOST_DEVICE
    static typename Operator::LayoutA LayoutA() {
      return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
    }

    /// Returns a layout object for the B matrix
    CUTLASS_HOST_DEVICE
    static typename Operator::LayoutB LayoutB() {
      return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
    }

    /// Returns a TensorRef to the A operand
    CUTLASS_HOST_DEVICE
    TensorRefA operand_A_ref() {
      return TensorRefA{operand_A.data(), LayoutA()};
    }

    /// Returns a TensorRef to the B operand
    CUTLASS_HOST_DEVICE
    TensorRefB operand_B_ref() {
      return TensorRefB{operand_B.data(), LayoutB()};
    }
  };

  /// Storage specialized for Gemm 0 and Gemm 1 respectively
  using SharedStorage0 = SharedStorage<Shape0, Policy0>;
  using SharedStorage1 = SharedStorage<Shape1, Policy1>;

  /// Storage handed to the constructor. As a union, sharedStorage0 and
  /// sharedStorage1 occupy the same memory rather than being laid out
  /// back to back.
  union B2bMmaSharedStorage {
    SharedStorage0 sharedStorage0;
    SharedStorage1 sharedStorage1;
  };

 protected:
  //
  // Data members
  //

  // NOTE(review): no iterator for an A1 operand is declared here; presumably
  // Gemm 1 obtains its A operand elsewhere - confirm against the derived class.

  /// Iterator to load a warp-scoped tile of A0 operand from shared memory
  typename Operator0::IteratorA warp_tile_iterator_A0_;

  /// Iterator to load a warp-scoped tile of B0 operand from shared memory
  typename Operator0::IteratorB warp_tile_iterator_B0_;

  /// Iterator to load a warp-scoped tile of B1 operand from shared memory
  typename Operator1::IteratorB warp_tile_iterator_B1_;

 public:
  /// Construct from tensor references.
  ///
  /// Only lane_idx is consumed here; thread_idx and warp_idx are accepted but
  /// not used by this base class.
  CUTLASS_DEVICE
  B2bMmaBase(
      /// Shared storage needed for internal use by threadblock-scoped GEMM
      B2bMmaSharedStorage &shared_storage,
      /// ID within the threadblock
      int thread_idx,
      /// ID of warp
      int warp_idx,
      /// ID of each thread within a warp
      int lane_idx)
      : warp_tile_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), lane_idx),
        warp_tile_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), lane_idx),
        // NOTE(review): sharedStorage1 overlaps sharedStorage0 in the union, so
        // this reference may alias bytes also used by Gemm 0's operands.
        warp_tile_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), lane_idx) {
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////