/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"

#include "threadblock/b2b_mma_base.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////////////////////

/// Compile-time debugging helper: instantiating chk_val<N> with any nonzero N trips the
/// static_assert, so the compiler's error message reveals the value of N.
template<int a>
struct chk_val {
  static_assert(a==0, "check value");
};
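
// Illustrative use of chk_val as a compile-time probe (hypothetical; keep commented out,
// since instantiating it with a nonzero argument intentionally breaks the build):
//
//   chk_val<Shape0_::kK> probe;   // the compile error text prints the value of Shape0_::kK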

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
  /// Size of the Gemm problem - concept: gemm::GemmShape<>
  typename Shape0_,
  /// Iterates over tiles of A operand in global memory
  //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
  typename IteratorA0_,
  /// Iterates over tiles of A operand in shared memory
  /// (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorA0_,
  /// Iterates over tiles of B operand in global memory
  //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
  typename IteratorB0_,
  /// Iterates over tiles of B operand in shared memory
  /// (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorB0_,
  /// Size of the Gemm problem - concept: gemm::GemmShape<>
  typename Shape1_,
  /// Iterates over the intermediate accumulator tile
  //  (concept: MmaTensorOpFragmentIterator)
  typename FragmentIteratorA1_,
  /// Iterates over tiles of B operand in global memory
  //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
  typename IteratorB1_,
  /// Iterates over tiles of B operand in shared memory
  /// (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorB1_,
  /// Data type of accumulator matrix
  typename ElementC_,
  /// Layout of accumulator matrix
  typename LayoutC_,
  /// Output operator for 1st Gemm (concept: epilogue::thread::LinearCombinationClamp, etc...)
  typename OutputOp_,
  /// Policy describing tuning details (concept: MmaPipelinedPolicy)
  typename Policy0_,
  /// Policy describing tuning details (concept: MmaPipelinedPolicy)
  typename Policy1_,
  /// Transformation applied to A0 operand
  typename TransformA0_ = NumericArrayConverter<
    typename SmemIteratorA0_::Element,
    typename IteratorA0_::Element,
    IteratorA0_::Fragment::kElements>,
  ///
  /// Transformation applied to B0 operand
  typename TransformB0_ = NumericArrayConverter<
    typename SmemIteratorB0_::Element,
    typename IteratorB0_::Element,
    IteratorB0_::Fragment::kElements>,
  ///
  /// Transformation applied to B1 operand
  typename TransformB1_ = NumericArrayConverter<
    typename SmemIteratorB1_::Element,
    typename IteratorB1_::Element,
    IteratorB1_::Fragment::kElements>,
  /// Used for partial specialization
  typename Enable = bool
>
class B2bMmaPipelined : public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
public:

  ///< Base class
  using Base = B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2>;

  using Shape0 = Shape0_;                   ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA0 = IteratorA0_;           ///< Iterates over tiles of A operand in global memory
  using IteratorB0 = IteratorB0_;           ///< Iterates over tiles of B operand in global memory
  using Policy0 = Policy0_;                 ///< Policy describing tuning details

  using SmemIteratorA0 = SmemIteratorA0_;
  using SmemIteratorB0 = SmemIteratorB0_;

  using Shape1 = Shape1_;                   ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile
  using IteratorB1 = IteratorB1_;           ///< Iterates over tiles of B operand in global memory
  using Policy1 = Policy1_;                 ///< Policy describing tuning details

  using SmemIteratorB1 = SmemIteratorB1_;

  using ElementC = ElementC_;               ///< Data type of accumulator matrix
  using LayoutC = LayoutC_;                 ///< Layout of accumulator matrix

  using OutputOp = OutputOp_;               ///< Epilogue after 1st Gemm

  using TransformA0 = TransformA0_;
  using TransformB0 = TransformB0_;
  using TransformB1 = TransformB1_;

  //
  // Dependent types
  //

  /// Fragment of operand A loaded from global memory
  using FragmentA0 = typename IteratorA0::Fragment;

  /// Fragment of operand B loaded from global memory
  using FragmentB0 = typename IteratorB0::Fragment;

  /// Fragment of accumulator tile
  using FragmentC0 = typename Policy0::Operator::FragmentC;

  /// Warp-level Mma
  using Operator0 = typename Policy0::Operator;

  /// Fragment of operand B loaded from global memory
  using FragmentB1 = typename IteratorB1::Fragment;

  /// Fragment of accumulator tile
  using FragmentC1 = typename Policy1::Operator::FragmentC;

  /// Warp-level Mma
  using Operator1 = typename Policy1::Operator;

  /// Obtain the arch tag from the warp-level operator
  using ArchTag = typename Policy0::Operator::ArchTag;

  /// Complex transform on A0 operand
  static ComplexTransform const kTransformA0 = Operator0::kTransformA;

  /// Complex transform on B0 operand
  static ComplexTransform const kTransformB0 = Operator0::kTransformB;

  /// Complex transform on B1 operand
  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  // Statically assert that kStages for MmaPipelined is two (double-buffered pipeline)
  static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");

private:

  using WarpFragmentA0 = typename Operator0::FragmentA;
  using WarpFragmentB0 = typename Operator0::FragmentB;

  /// Warp Fragment of operand A1 loaded from the accumulator tile
  using WarpFragmentA1 = typename FragmentIteratorA1::Fragment;
  using WarpFragmentB1 = typename Operator1::FragmentB;

protected:

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA0 smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B0 operand to shared memory
  SmemIteratorB0 smem_iterator_B0_;

  /// Iterator to write threadblock-scoped tile of B1 operand to shared memory
  SmemIteratorB1 smem_iterator_B1_;

public:

  /// Construct from tensor references
  CUTLASS_DEVICE
  B2bMmaPipelined(
    typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
    int thread_idx,                                     ///< ID within the threadblock
    int warp_idx,                                       ///< ID of warp
    int lane_idx                                        ///< ID of each thread within a warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),
    smem_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), thread_idx),
    smem_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), thread_idx) {

    // Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    // These should stay the same across different GEMM layers
    int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;
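
    // Worked example (illustrative): with WarpCount0 = <2, 2, 1> and warp_idx = 3,
    //   warp_idx_mn = 3 % (2*2) = 3,  warp_idx_k = 3 / (2*2) = 0,
    //   warp_idx_m  = 3 % 2     = 1,  warp_idx_n  = 3 / 2    = 1,
    // i.e. this warp covers the (m=1, n=1) warp tile in the only k partition.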

    // These may change across different GEMM layers
    int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k;
    int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0});
    this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n});
    this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n});
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
    int gemm_k_iterations_0,                    ///< number of iterations of the mainloop
    FragmentC1 &accum,                          ///< destination accumulator tile
    IteratorA0 iterator_A,                      ///< iterator over A operand in global memory
    IteratorB0 iterator_B0,                     ///< iterator over B0 operand in global memory
    IteratorB1 iterator_B1,                     ///< iterator over B1 operand in global memory
    FragmentC0 const &src_accum,                ///< source accumulator tile
    OutputOp output_op_0,                       ///< epilogue operation after 1st Gemm
    TransformA0 transform_A0 = TransformA0(),   ///< transformation applied to A0 fragment
    TransformB0 transform_B0 = TransformB0(),   ///< transformation applied to B0 fragment
    TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment

    //
    // Prologue
    //

    // Perform accumulation in the 'd' output operand
    FragmentC0 accum0 = src_accum;

    FragmentA0 tb_frag_A;
    FragmentB0 tb_frag_B0;

    tb_frag_A.clear();
    tb_frag_B0.clear();

    // The last kblock is loaded in the prologue
    iterator_A.load(tb_frag_A);
    iterator_B0.load(tb_frag_B0);

    ++iterator_A;
    ++iterator_B0;

    // Apply the operand transformations (e.g. numeric conversion) as fragments are staged
    // to shared memory
    this->smem_iterator_A_.store(transform_A0(tb_frag_A));
    this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));

    ++this->smem_iterator_A_;
    ++this->smem_iterator_B0_;

    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math instructions
    WarpFragmentA0 warp_frag_A0[2];
    WarpFragmentB0 warp_frag_B0[2];

    this->warp_tile_iterator_A0_.set_kgroup_index(0);
    this->warp_tile_iterator_B0_.set_kgroup_index(0);

    this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);
    this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);

    ++this->warp_tile_iterator_A0_;
    ++this->warp_tile_iterator_B0_;

    Operator0 warp_mma0;

    int smem_write_stage_idx = 1;

    // Avoid reading out of bounds
    if (gemm_k_iterations_0 <= 1) {
      iterator_A.clear_mask();
      iterator_B0.clear_mask();
    }
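
    // At this point the pipeline is primed: stage 0 of shared memory holds the first
    // k-block and the first warp-level fragments are already in registers, so the math
    // below can overlap the global->shared transfer of the next k-block.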

    // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
    // shared memory loads (which have the tightest latency requirement).
    iterator_A.load(tb_frag_A);

    //
    // Mainloop
    //

    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {

      //
      // Loop over GEMM K dimension
      //

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {

        // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
        // as the case may be.

        if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {

          // Write fragments to shared memory
          this->smem_iterator_A_.store(transform_A0(tb_frag_A));
          this->smem_iterator_B0_.store(transform_B0(tb_frag_B0));

          __syncthreads();

          // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
          // shared memory loads (which have the tightest latency requirement).
          iterator_A.load(tb_frag_A);

          ++this->smem_iterator_B0_;
          ++this->smem_iterator_A_;

          // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
          if (smem_write_stage_idx == 1) {
            this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
          }
          else {
            this->warp_tile_iterator_A0_.add_tile_offset(
                {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0});
            this->warp_tile_iterator_B0_.add_tile_offset(
                {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0,
                 0});
          }

          smem_write_stage_idx ^= 1;
        }
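
        // Stage bookkeeping for the two-stage circular buffer (illustrative): the prologue
        // wrote stage 0 and set smem_write_stage_idx = 1. When the writeback above lands in
        // stage 1 (idx == 1), the *write* iterators are rewound by kStages tiles so the next
        // writeback targets stage 0 again; when it lands in stage 0 (idx == 0), the *read*
        // (warp tile) iterators are rewound instead, since they have walked both stages.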

        this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
        this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);

        this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);
        this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);

        ++this->warp_tile_iterator_A0_;
        ++this->warp_tile_iterator_B0_;

        if (warp_mma_k == 0) {

          iterator_B0.load(tb_frag_B0);

          ++iterator_A;
          ++iterator_B0;

          // Avoid reading out of bounds if this was the last loop iteration
          if (gemm_k_iterations_0 <= 2) {
            iterator_A.clear_mask();
            iterator_B0.clear_mask();
          }
        }

        warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0);
      }
    }

    //
    // 2nd Gemm
    //

    /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
    FragmentIteratorA1 warp_tile_iterator_A1_(accum0);

    //
    // Prologue
    //

    FragmentB1 tb_frag_B1;

    tb_frag_B1.clear();

    // The last kblock is loaded in the prologue
    iterator_B1.load(tb_frag_B1);

    ++iterator_B1;

    this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));

    ++this->smem_iterator_B1_;

    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math instructions
    WarpFragmentA1 warp_frag_A1[2];
    WarpFragmentB1 warp_frag_B1[2];

    //warp_tile_iterator_A1_.set_kgroup_index(0);
    this->warp_tile_iterator_B1_.set_kgroup_index(0);

    // A1 fragments are read directly from the 0th GEMM's accumulators, with the first
    // epilogue's output operator applied on the fly.
    warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0);
    this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);

    ++warp_tile_iterator_A1_;
    ++this->warp_tile_iterator_B1_;

    Operator1 warp_mma1;

    smem_write_stage_idx = 1;

    int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;

    // Avoid reading out of bounds
    if (gemm_k_iterations_1 <= 1) {
      iterator_B1.clear_mask();
    }
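
    // Worked example (illustrative): if FragmentIteratorA1::Policy::kIterations == 8 and
    // Base::kWarpGemmIterations1 == 4, the intermediate accumulator supplies
    // 8 / 4 = 2 k-block iterations for the second GEMM, with no global loads of A1 at all.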

    //
    // Mainloop
    //

    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
    CUTLASS_PRAGMA_UNROLL
    for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) {

      //
      // Loop over GEMM K dimension
      //

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) {

        // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
        // as the case may be.

        if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {

          // Write fragments to shared memory
          this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));

          __syncthreads();

          ++this->smem_iterator_B1_;

          // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
          if (smem_write_stage_idx == 1) {
            this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
          }
          else {
            this->warp_tile_iterator_B1_.add_tile_offset(
                {-Base::kStages * Policy1::kPartitionsK *
                     Base::kWarpGemmIterations1,
                 0});
          }

          smem_write_stage_idx ^= 1;
        }

        this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);

        warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
        this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);

        ++warp_tile_iterator_A1_;
        ++this->warp_tile_iterator_B1_;

        if (warp_mma_k == 0) {

          iterator_B1.load(tb_frag_B1);
          ++iterator_B1;

          // Avoid reading out of bounds if this was the last loop iteration
          if (gemm_k_iterations_1 <= 2) {
            iterator_B1.clear_mask();
          }
        }

        warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], warp_frag_B1[warp_mma_k % 2], accum);
      }
    }

  }
};
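
// Illustrative instantiation sketch (hypothetical type names; in the B2B GEMM example the
// concrete iterator and policy types are normally assembled by a kernel-level default such
// as a DefaultB2bMma-style helper):
//
//   using B2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
//       ThreadblockShape0, IteratorA0, SmemIteratorA0, IteratorB0, SmemIteratorB0,
//       ThreadblockShape1, FragmentIteratorA1, IteratorB1, SmemIteratorB1,
//       ElementC, LayoutC, EpilogueOutputOp0, MmaPolicy0, MmaPolicy1>;
//
//   // Inside the kernel:
//   //   __shared__ typename B2bMma::Base::B2bMmaSharedStorage shared_storage;
//   //   B2bMma mma(shared_storage, thread_idx, warp_idx, lane_idx);
//   //   mma(gemm_k_iterations_0, accum, iterator_A, iterator_B0, iterator_B1,
//   //       src_accum, output_op_0);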

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass