cutlass/include/cutlass/gemm/kernel/rank_2k_grouped.h

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Grouped Rank2K kernel.
*/
#pragma once
#include "cutlass/blas3.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/rank_2k_transpose_operands.h"
#include "cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma1_, ///! Threadblock-scoped matrix multiply-accumulate (A*B^T)
typename Mma2_, ///! Threadblock-scoped matrix multiply-accumulate (B*A^T)
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
ComplexTransform OriginalTransformA_, ///! Public-facing transformation on A
ComplexTransform OriginalTransformB_, ///! Public-facing transformation on B
FillMode FillModeC_, ///! Fill Mode for C (kLower or kUpper)
BlasMode BlasMode_, ///! Blas3 computation mode
GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform
bool Transposed = false
>
struct Rank2KGrouped {
public:
using Mma1 = Mma1_;
using Mma2 = Mma2_;
static_assert(platform::is_same<typename Mma1::LayoutC, cutlass::layout::RowMajor>::value &&
platform::is_same<typename Mma2::LayoutC, cutlass::layout::RowMajor>::value,
"Kernel-level grouped Rank2K requires that LayoutC be row major.");
// Define a generic Mma for use cases that reference Kernel::Mma
using Mma = Mma1_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
static bool const kTransposed = Transposed;
// Public-facing type definitions related to operand element type, layout, and complex conjugate
// operation. Must interact with the 'kTransposed' notion to reflect the original layout,
// fill mode, etc. passed in.
//
// Recall that a Rank2K operation performs (A x BT) + (B x AT)
// This is performed via:
// Mma1 = (A x BT)
// Mma2 = (B x AT)
//
// However, if C needs to be transposed, then this is changed to the following:
// Mma1 = (B x AT)
// Mma2 = (A x BT)
//
// The transformation above is achieved by swapping the Layouts/Elements/Transforms/etc.
// of A and B as they are passed into the instantiations of Mma1 and Mma2.
//
// Now, given access to only Mma1 and Mma2, as well as whether a transposition has occurred,
// we wish to retrieve the original Layouts/Elements/etc. for A and B that were passed into
// the device-level call.
//
// The logic to do this (which is made clearer by referencing the above instantiations) is as follows:
// LayoutA = kTransposed ? Mma2::LayoutA : Mma1::LayoutA
// LayoutB = kTransposed ? Mma1::LayoutA : Mma2::LayoutA
//
// We achieve this swapping by passing the A-operand types of both Mma1 and Mma2
// to Rank2KMapArguments (Mma1's A operand is A itself, while Mma2's A operand is B):
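//
// For example (an illustrative instantiation): if the original call used
// LayoutA = ColumnMajor and LayoutB = RowMajor with kTransposed == true, then
// Mma1 was built from B's operand types and Mma2 from A's, so
// Mma2::IteratorA::Layout recovers the original ColumnMajor LayoutA and
// Mma1::IteratorA::Layout recovers the original RowMajor LayoutB, matching
// the selectors described above.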
using MapArgumentsA = kernel::detail::Rank2KMapArguments<
typename Mma1::IteratorA::Element,
typename Mma1::IteratorA::Layout,
Mma1::kTransformA,
Mma1::IteratorA::AccessType::kElements,
typename Mma2::IteratorA::Element,
typename Mma2::IteratorA::Layout,
Mma2::kTransformA,
Mma2::IteratorA::AccessType::kElements,
typename Mma1::LayoutC,
FillModeC_,
kTransposed
>;
using ElementA = typename MapArgumentsA::ElementA;
using LayoutA = typename MapArgumentsA::LayoutA;
static int const kAlignmentA = MapArgumentsA::kAlignmentA;
using MapArgumentsB = kernel::detail::Rank2KMapArguments<
typename Mma2::IteratorA::Element,
typename Mma2::IteratorA::Layout,
Mma2::kTransformA,
Mma2::IteratorA::AccessType::kElements,
typename Mma1::IteratorA::Element,
typename Mma1::IteratorA::Layout,
Mma1::kTransformA,
Mma1::IteratorA::AccessType::kElements,
typename Mma2::LayoutC,
FillModeC_,
kTransposed
>;
using ElementB = typename MapArgumentsB::ElementA;
using LayoutB = typename MapArgumentsB::LayoutA;
static int const kAlignmentB = MapArgumentsB::kAlignmentA;
// Use the user-provided TransformA and TransformB, rather than those
// resulting from MapArguments, because Mma1 and Mma2 may have different
// complex transforms than those passed in by the user.
// (See kernel/rank_2k_complex.h for an example of this)
static cutlass::ComplexTransform const kTransformA = OriginalTransformA_;
static cutlass::ComplexTransform const kTransformB = OriginalTransformB_;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename MapArgumentsA::LayoutC;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
static FillMode const kFillModeC = MapArgumentsA::kFillModeC;
// Common type definitions for Mma1 and Mma2
using Operator = typename Mma1::Operator;
using OperatorClass = typename Mma1::Operator::OperatorClass;
using ThreadblockShape = typename Mma1::Shape;
using WarpShape = typename Mma1::Operator::Shape;
using InstructionShape = typename Mma1::Policy::Operator::InstructionShape;
using ArchTag = typename Mma1::ArchTag;
static int const kStages = Mma1::kStages;
static BlasMode const kBlasMode = BlasMode_;
private:
static FillMode const kInternalFillModeC = FillModeC_;
public:
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma1::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ProblemVisitor = Rank2KGroupedProblemVisitor<
ThreadblockShape,
kGroupScheduleMode,
kThreadCount,
kThreadCount,
kInternalFillModeC>;
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord *problem_sizes;
int problem_count;
int threadblock_count;
typename EpilogueOutputOp::Params epilogue;
ElementA ** ptr_A;
ElementB ** ptr_B;
ElementC ** ptr_C;
ElementC ** ptr_D;
typename LayoutA::Stride::LongIndex *lda;
typename LayoutB::Stride::LongIndex *ldb;
typename LayoutC::Stride::LongIndex *ldc;
typename LayoutC::Stride::LongIndex *ldd;
// Only used by device-level operator
GemmCoord *host_problem_sizes;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments():
mode(GemmUniversalMode::kGemm),
problem_count(0),
threadblock_count(0),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
lda(nullptr),
ldb(nullptr),
ldc(nullptr),
ldd(nullptr),
host_problem_sizes(nullptr)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(
GemmUniversalMode mode,
GemmCoord *problem_sizes,
int problem_count,
int threadblock_count,
typename EpilogueOutputOp::Params epilogue,
ElementA ** ptr_A,
ElementB ** ptr_B,
ElementC ** ptr_C,
ElementC ** ptr_D,
typename LayoutA::Stride::LongIndex *lda,
typename LayoutB::Stride::LongIndex *ldb,
typename LayoutC::Stride::LongIndex *ldc,
typename LayoutC::Stride::LongIndex *ldd,
GemmCoord *host_problem_sizes=nullptr
):
mode(mode),
problem_sizes(problem_sizes),
problem_count(problem_count),
threadblock_count(threadblock_count),
epilogue(epilogue),
ptr_A(ptr_A),
ptr_B(ptr_B),
ptr_C(ptr_C),
ptr_D(ptr_D),
lda(lda),
ldb(ldb),
ldc(ldc),
ldd(ldd),
host_problem_sizes(host_problem_sizes)
{
}
};
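// Illustrative host-side construction of Arguments (a sketch only; the
// device-level grouped Rank2K wrapper normally assembles this, and names such
// as problem_sizes_dev, ptr_A_dev, and lda_dev are hypothetical caller-owned
// device allocations):
//
//   Arguments args(
//     GemmUniversalMode::kGemm,
//     problem_sizes_dev,            // GemmCoord* in device memory
//     problem_count,
//     threadblock_count,            // e.g., derived from occupancy
//     {alpha, beta},                // EpilogueOutputOp::Params
//     ptr_A_dev, ptr_B_dev, ptr_C_dev, ptr_D_dev,
//     lda_dev, ldb_dev, ldc_dev, ldd_dev,
//     problem_sizes_host);          // optional host-side copy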
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
typename EpilogueOutputOp::Params output_op;
GemmUniversalMode mode;
int batch_count;
ElementA ** ptr_A;
ElementB ** ptr_B;
ElementC ** ptr_C;
ElementC ** ptr_D;
typename LayoutA::Stride::LongIndex *lda;
typename LayoutB::Stride::LongIndex *ldb;
typename LayoutC::Stride::LongIndex *ldc;
typename LayoutC::Stride::LongIndex *ldd;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
lda(nullptr),
ldb(nullptr),
ldc(nullptr),
ldd(nullptr)
{ }
CUTLASS_HOST_DEVICE
Params(Arguments const &args, void *workspace = nullptr, int tile_count = 0):
problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count),
threadblock_count(args.threadblock_count),
output_op(args.epilogue),
mode(args.mode),
ptr_A(args.ptr_A),
ptr_B(args.ptr_B),
ptr_C(args.ptr_C),
ptr_D(args.ptr_D),
lda(args.lda),
ldb(args.ldb),
ldc(args.ldc),
ldd(args.ldd)
{
}
CUTLASS_HOST_DEVICE
void update(
Arguments const &args,
void *workspace = nullptr,
int tile_count = 0) {
problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count);
threadblock_count = args.threadblock_count;
output_op = args.epilogue;
mode = args.mode;
ptr_A = args.ptr_A;
ptr_B = args.ptr_B;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
lda = args.lda;
ldb = args.ldb;
ldc = args.ldc;
ldd = args.ldd;
}
};
/// Shared memory storage structure
struct SharedStorage {
union {
typename Mma1::SharedStorage mma1_main_loop;
typename Mma2::SharedStorage mma2_main_loop;
typename Epilogue::SharedStorage epilogue;
} kernel;
// ProblemVisitor shared storage can't be overlapped with others
typename ProblemVisitor::SharedStorage problem_visitor;
};
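// Note: Mma1, Mma2, and the epilogue may alias the same shared memory because
// a threadblock uses them in strictly sequential phases of processing a tile;
// only the problem visitor's state must remain live across tiles.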
public:
//
// Methods
//
CUTLASS_DEVICE
Rank2KGrouped() { }
/// Determines whether the kernel can execute the problem; alignment checks
/// are not currently performed, so both overloads always return success.
static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return Status::kSuccess;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
//
// Problem visitor.
//
ProblemVisitor problem_visitor(
params.problem_visitor,
shared_storage.problem_visitor,
blockIdx.x);
// Outer 'persistent' loop to iterate over tiles
while (problem_visitor.next_tile()) {
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
cutlass::gemm::GemmCoord threadblock_tile_offset = problem_visitor.threadblock_offset(threadblock_idx);
//
// Perform checks to determine whether the results of this threadblock will be needed.
// An example of an unneeded threadblock is one that is assigned to compute in the upper
// portion of a Rank2K kernel filled with mode kLower.
//
// TODO: Consider pushing these checks into ProblemVisitor to avoid spuriously
// returning from `next_tile()`.
//
// Early exit if threadblock is out of range
if (grid_shape.m() <= threadblock_tile_offset.m() ||
grid_shape.n() <= threadblock_tile_offset.n()) {
// Next tile
problem_visitor.advance(gridDim.x);
continue;
}
// Skip this tile if Fill Mode is Lower and
// if the entire tile is above the main diagonal (bottom-left corner is at or above the diagonal)
if (kInternalFillModeC == cutlass::FillMode::kLower &&
(threadblock_tile_offset.m() + 1) * Mma1::Shape::kM <= threadblock_tile_offset.n() * Mma1::Shape::kN) {
// Next tile
problem_visitor.advance(gridDim.x);
continue;
}
// Skip this tile if Fill Mode is Upper and
// if the entire tile is below the main diagonal (top-right corner is at or below the diagonal)
if (kInternalFillModeC == cutlass::FillMode::kUpper &&
threadblock_tile_offset.m() * Mma1::Shape::kM >= (threadblock_tile_offset.n() + 1) * Mma1::Shape::kN) {
// Next tile
problem_visitor.advance(gridDim.x);
continue;
}
bool tile_on_diagonal = false;
// Mark tiles that are being crossed by the main diagonal
// (top-right and bottom-left corners are on either side of the diagonal)
if ((threadblock_tile_offset.m() + 1) * Mma1::Shape::kM > threadblock_tile_offset.n() * Mma1::Shape::kN
&& threadblock_tile_offset.m() * Mma1::Shape::kM < (threadblock_tile_offset.n() + 1) * Mma1::Shape::kN) {
tile_on_diagonal = true;
}
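// Illustrative classification (assuming square 128x128 threadblock tiles and
// FillMode kLower): tile (m=0, n=2) spans rows [0,128) x cols [256,384) and
// lies entirely above the diagonal, so it is skipped; tile (m=2, n=0) lies
// entirely below it and is computed in full; tile (m=1, n=1) straddles the
// diagonal, so tile_on_diagonal is set and the epilogue masks the elements on
// the wrong side.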
int offset_k = 0;
int problem_size_k = problem_size.k();
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm ||
params.mode == GemmUniversalMode::kGemmSplitKParallel) {
if (threadblock_tile_offset.k() + 1 < grid_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * problem_size.k();
}
offset_k = threadblock_tile_offset.k() * problem_size.k();
}
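// Note: the grouped problem visitor tiles only M and N (grid_shape.k() == 1
// and threadblock_tile_offset.k() == 0), so in practice these branches leave
// offset_k == 0 and problem_size_k == problem_size.k(): each threadblock
// covers the full K extent of its problem.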
ElementA *ptr_A = reinterpret_cast<ElementA *>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
typename LayoutA::Stride::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);
ElementB *ptr_B = reinterpret_cast<ElementB *>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
typename LayoutB::Stride::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_MxK{
threadblock_tile_offset.m() * Mma1::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_KxN{
offset_k,
threadblock_tile_offset.n() * Mma1::Shape::kN
};
// Assume identity swizzle
MatrixCoord tb_offset(
threadblock_tile_offset.m() * Mma1::Shape::kM,
threadblock_tile_offset.n() * Mma1::Shape::kN
);
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands for Mma1
typename Mma1::IteratorA iterator_A(
Mma1::IteratorA::Params(ldm_A),
ptr_A,
{problem_size.m(), problem_size_k},
thread_idx,
tb_offset_MxK);
typename Mma1::IteratorB iterator_BT(
Mma1::IteratorB::Params(ldm_B),
ptr_B,
{problem_size_k, problem_size.n()},
thread_idx,
tb_offset_KxN);
// Construct iterators to A and B operands for Mma2
typename Mma2::IteratorA iterator_B(
Mma2::IteratorA::Params(ldm_B),
ptr_B,
{problem_size.m(), problem_size_k},
thread_idx,
tb_offset_MxK);
typename Mma2::IteratorB iterator_AT(
Mma2::IteratorB::Params(ldm_A),
ptr_A,
{problem_size_k, problem_size.n()},
thread_idx,
tb_offset_KxN);
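// Note that Mma2 reuses the same threadblock offsets as Mma1: B is read as the
// M-by-K operand and A as the K-by-N operand, which is valid because both
// products target the same output tile of the (square) rank-2k result.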
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = canonical_warp_idx();
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply for Mma1 (A x BT)
Mma1 mma1(shared_storage.kernel.mma1_main_loop, thread_idx, warp_idx, lane_idx);
// Construct thread-scoped matrix multiply for Mma2 (B x AT)
Mma2 mma2(shared_storage.kernel.mma2_main_loop, thread_idx, warp_idx, lane_idx);
typename Mma1::FragmentC accumulators;
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - offset_k + Mma1::Shape::kK - 1) / Mma1::Shape::kK;
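// For example (illustrative): with problem_size_k == 1000, offset_k == 0, and
// Mma1::Shape::kK == 32, this computes ceil(1000 / 32) == 32 mainloop
// iterations, the last of which covers the 8-element remainder.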
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add (A x BT)
mma1(
gemm_k_iterations,
accumulators,
iterator_A,
iterator_BT,
accumulators);
// The HER2K kernel requires alpha to be complex; conj(alpha) is applied to the second HERK.
if (kBlasMode == BlasMode::kHermitian) {
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * grid_shape.m();
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C[problem_idx]);
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D[problem_idx]);
// If the threadblock is not on the diagonal, FillMode does not apply.
FillMode kFillModeTB = tile_on_diagonal ? kInternalFillModeC : FillMode::kNone;
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]),
ptr_C,
problem_size.mn(),
thread_idx,
tb_offset,
kFillModeTB
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]),
ptr_D,
problem_size.mn(),
thread_idx,
tb_offset,
kFillModeTB
);
Epilogue epilogue(
shared_storage.kernel.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(
output_op,
iterator_D,
accumulators,
iterator_C);
__syncthreads();
accumulators.clear();
}
// Compute threadblock-scoped matrix multiply-add (B x AT)
mma2(
gemm_k_iterations,
accumulators,
iterator_B,
iterator_AT,
accumulators);
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
/* Needed for HER2K where the second HERK is multiplied by conj(alpha) */
typename EpilogueOutputOp::Params second_her2k_params(conj(params.output_op.alpha), 1);
EpilogueOutputOp output_op_her2k(second_her2k_params);
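// In Hermitian mode the full rank-2k update is
//   D = alpha * (A x B^H) + conj(alpha) * (B x A^H) + beta * C.
// The first pass above already wrote alpha * (A x B^H) + beta * C into D, so
// the second epilogue applies conj(alpha) with unit beta, accumulating onto
// the intermediate result re-read through iterator_C.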
//
// Masked tile iterators constructed from members
//
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * grid_shape.m();
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C[problem_idx]);
// For HER2K, the first pass has already written its partial result to D, so
// the second pass reads its source operand from D rather than C.
if (kBlasMode == BlasMode::kHermitian) {
ptr_C = static_cast<ElementC *>(params.ptr_D[problem_idx]);
}
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D[problem_idx]);
// If the threadblock is not on the diagonal, FillMode does not apply.
FillMode kFillModeTB = tile_on_diagonal ? kInternalFillModeC : FillMode::kNone;
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]),
ptr_C,
problem_size.mn(),
thread_idx,
tb_offset,
kFillModeTB
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]),
ptr_D,
problem_size.mn(),
thread_idx,
tb_offset,
kFillModeTB
);
Epilogue epilogue(
shared_storage.kernel.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Execute the epilogue operator to update the destination tensor.
if (kBlasMode == BlasMode::kSymmetric) {
epilogue(
output_op,
iterator_D,
accumulators,
iterator_C);
} else {
epilogue(
output_op_her2k,
iterator_D,
accumulators,
iterator_C);
}
// Next tile
problem_visitor.advance(gridDim.x);
}
}
};
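// Illustrative launch sketch (a sketch only, under the assumption that the
// caller has built a fully specialized kernel type and initialized `params`;
// in practice the device-level grouped Rank2K wrapper performs this launch):
//
//   using Rank2KKernel = cutlass::gemm::kernel::Rank2KGrouped<...>;
//   dim3 grid(params.threadblock_count, 1, 1);
//   dim3 block(Rank2KKernel::kThreadCount, 1, 1);
//   int smem_size = int(sizeof(typename Rank2KKernel::SharedStorage));
//   cutlass::Kernel<Rank2KKernel><<<grid, block, smem_size, stream>>>(params);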
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////