/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief GEMM kernel with an epilogue defined under the epilogue visitor concept
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/kernel/gemm_universal.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GEMM kernel that computes the epilogue through an epilogue visitor functor
template <
  typename Mma,                 ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue,            ///! Epilogue
  typename ThreadblockSwizzle_  ///! Threadblock swizzling function
>
class GemmWithEpilogueVisitor : GemmUniversal<Mma, Epilogue, ThreadblockSwizzle_> {
public:

  using ThreadblockSwizzle = ThreadblockSwizzle_;

  using Base = GemmUniversal<Mma, Epilogue, ThreadblockSwizzle>;
  using Base::Base;

  using FusionCallbacks = typename Epilogue::FusionCallbacks;

  using ElementA = typename Base::ElementA;
  using LayoutA = typename Base::LayoutA;
  using ElementB = typename Base::ElementB;
  using LayoutB = typename Base::LayoutB;
  using ElementC = typename Base::ElementC;
  using LayoutC = typename Base::LayoutC;

  using ThreadblockShape = typename Mma::Shape;

  //
  // Structures
  //

  using SharedStorage = typename Base::SharedStorage;
  using Arguments = typename Base::Arguments;

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params : UniversalParamsBase<
    ThreadblockSwizzle,
    ThreadblockShape,
    ElementA,
    ElementB,
    ElementC,
    LayoutA,
    LayoutB>
  {
    using ParamsBase = UniversalParamsBase<
      ThreadblockSwizzle,
      ThreadblockShape,
      ElementA,
      ElementB,
      ElementC,
      LayoutA,
      LayoutB>;

    //
    // Data members
    //

    cute::Shape<int32_t,int32_t,int32_t> problem_shape;

    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorB::Params params_B;
    typename FusionCallbacks::Params output_op;

    void * ptr_A;
    void * ptr_B;

    int64_t batch_stride_A;
    int64_t batch_stride_B;

    int * ptr_gather_A_indices;
    int * ptr_gather_B_indices;

    //
    // Host dispatch API
    //

    /// Default constructor
    Params() = default;

    /// Constructor
    Params(
      Arguments const &args,  /// GEMM application arguments
      int device_sms,         /// Number of SMs on the device
      int sm_occupancy)       /// Kernel SM occupancy (in thread blocks)
    :
      ParamsBase(args, device_sms, sm_occupancy),
      params_A(args.lda ? make_Coord_with_padding<LayoutA::kStrideRank>(args.lda) : args.stride_a),
      params_B(args.ldb ? make_Coord_with_padding<LayoutB::kStrideRank>(args.ldb) : args.stride_b),
      output_op(FusionCallbacks::to_underlying_arguments(args.problem_size, args.epilogue, nullptr /*workspace*/)),
      problem_shape({args.problem_size.m(), args.problem_size.n(), args.batch_count}),
      ptr_A(const_cast<void *>(args.ptr_A)),
      ptr_B(const_cast<void *>(args.ptr_B)),
      batch_stride_A(args.batch_stride_A),
      batch_stride_B(args.batch_stride_B),
      ptr_gather_A_indices(const_cast<int *>(args.ptr_gather_A_indices)),
      ptr_gather_B_indices(const_cast<int *>(args.ptr_gather_B_indices))
    {
      // Reject unsupported modes (debug-only asserts)
      assert(args.mode != GemmUniversalMode::kGemmSplitKParallel && "Sm80 EVT does not support SplitKParallel.");
      assert(!(args.mode == GemmUniversalMode::kGemm && this->grid_tiled_shape.k() > 1)
        && "Sm80 EVT does not support SplitKSerial.");
      assert(args.mode != GemmUniversalMode::kArray && "Sm80 EVT does not support Array Gemm.");
    }

    /// Lightweight update given a subset of arguments.
    void update(Arguments const &args)
    {
      CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::Params::update()");

      // Update input pointers
      ptr_A = const_cast<void *>(args.ptr_A);
      ptr_B = const_cast<void *>(args.ptr_B);

      batch_stride_A = args.batch_stride_A;
      batch_stride_B = args.batch_stride_B;
      this->batch_stride_D = args.batch_stride_D;

      ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
      ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);

      output_op = FusionCallbacks::to_underlying_arguments(args.problem_size, args.epilogue, nullptr /*workspace*/);
      problem_shape = make_shape(args.problem_size.m(), args.problem_size.n(), args.batch_count);
    }
  };
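
  // Usage sketch (illustrative): the host builds Params once per problem and can
  // refresh pointers between launches. `args`, `new_args`, `device_sms`, and
  // `sm_occupancy` are hypothetical host-side values:
  //
  //   typename GemmWithEpilogueVisitor::Params params(args, device_sms, sm_occupancy);
  //   ...
  //   params.update(new_args);  // lightweight pointer/stride refresh, no re-planning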

public:

  //
  // Device-only API
  //

  // Factory invocation
  CUTLASS_DEVICE
  static void invoke(
    Params const &params,
    SharedStorage &shared_storage)
  {
    GemmWithEpilogueVisitor op;
    op(params, shared_storage);
  }
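
  // Note (assumption): this static factory matches the entry point expected by
  // CUTLASS device-side launch wrappers such as cutlass::Kernel2, which construct
  // shared storage in the kernel and forward it here along with Params.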

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {
    ThreadblockSwizzle threadblock_swizzle;
    run_with_swizzle(params, shared_storage, threadblock_swizzle);
  }

  /// Executes one GEMM with an externally-provided swizzling function
  CUTLASS_DEVICE
  void run_with_swizzle(Params const &params, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) {

    cutlass::gemm::GemmCoord threadblock_tile_offset =
        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
        params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

      return;
    }

    int offset_k = 0;
    int problem_size_k = params.problem_size.k();

    ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
    ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);

    //
    // Fetch pointers based on mode.
    //
    if (params.mode == GemmUniversalMode::kGemm) {

      if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {

        problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
      }

      offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
    }
    else if (params.mode == GemmUniversalMode::kBatched) {
      ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
      ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
    }
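
    // Worked example (illustrative): in kGemm mode with gemm_k_size = 1024, K = 3000,
    // and grid_tiled_shape.k() = 3, the CTA at k = 0 covers K-range [0, 1024), k = 1
    // covers [1024, 2048), and the last CTA keeps problem_size_k = 3000 and covers the
    // 952-element remainder [2048, 3000). In kBatched mode the k tile index selects a
    // batch instead, advancing A and B by whole batch strides.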

    __syncthreads();

    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A{
      threadblock_tile_offset.m() * Mma::Shape::kM,
      offset_k,
    };

    cutlass::MatrixCoord tb_offset_B{
      offset_k,
      threadblock_tile_offset.n() * Mma::Shape::kN
    };

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
      params.params_A,
      ptr_A,
      {params.problem_size.m(), problem_size_k},
      thread_idx,
      tb_offset_A,
      params.ptr_gather_A_indices);

    typename Mma::IteratorB iterator_B(
      params.params_B,
      ptr_B,
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B,
      params.ptr_gather_B_indices);

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = canonical_warp_idx_sync();

    int lane_idx = threadIdx.x % 32;

    //
    // Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Number of threadblock-scoped mainloop iterations over this CTA's K range (ceiling division)
    int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
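    // e.g. a 952-element K range with Mma::Shape::kK = 32 gives (952 + 31) / 32 = 30 iterations.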

    // Compute threadblock-scoped matrix multiply-add
    mma(
      gemm_k_iterations,
      accumulators,
      iterator_A,
      iterator_B,
      accumulators);

    //
    // Epilogue
    //

    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    Epilogue epilogue(
      params.output_op,
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Execute the epilogue operator to update the destination tensor.
    epilogue(accumulators, threadblock_tile_offset, params.problem_shape, thread_idx);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
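
// Launch sketch (illustrative; `GemmKernel`, `args`, `device_sms`, `sm_occupancy`,
// and `stream` are hypothetical host-side names for a fully specialized instance of
// this kernel and its inputs). A host launch in the style of CUTLASS's device-layer
// adapters might look like:
//
//   typename GemmKernel::Params params(args, device_sms, sm_occupancy);
//   typename GemmKernel::ThreadblockSwizzle swizzle;
//   dim3 grid = swizzle.get_grid_shape(params.grid_tiled_shape);
//   dim3 block(GemmKernel::kThreadCount, 1, 1);
//   int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
//   cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);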