Stream-K with broadcast (#892)
* [WIP] GEMM StreamK w/ Fused Epilogue
  * Adds a GEMM Stream-K with fused epilogue kernel-level struct, mostly based on GEMM with fused epilogue. Requires a new epilogue. Work in progress.
* [WIP] StreamK support for GemmUniversalWithBroadcast
  * Modeled on how Stream-K is enabled in GemmUniversal. Untested and a work in progress.
* Minor fixes
* [WIP] It compiles! It is almost certainly incorrect, but we're past getting the templates to match, so checkpointing.
* Correction to reference kernel
* Fix typo
* Added MSE measurement
* Switch back to reference kernel + host for loop. Still WIP: we now get an even larger MSE, but on both basic Split-K and Stream-K.
* Fix typos
* Fix broadcast vector + requested changes
* Comment typo
* Small int option and more
* Fix incorrect condition on source needed
* Requested changes
* I think I got it?
* Bias vector should be stride 0
* Two source added!
* Typos
* Merge examples
* Bring back vector row offset, to ensure consistency with universal GEMM with fused epilogue
* Base arguments and params structs for Stream-K
* Stream-K epilogue with broadcast now inherits the original
* Undo params_streamk_base.h

Co-authored-by: Ali Hassani <ahassanijr@gmail.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Parent: 6fbc0d3380 · Commit: 13f413493a
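Before the diff, a sketch (editorial, not part of the commit) of what the fused Stream-K broadcast epilogue computes, written as plain host-side C++. It mirrors the reference check in the new example below, where identity activations and plus ops make the bias add commutative; all names here are illustrative.

    // D[i][j] = alpha * sum_k A[i][k]*B[k][j] + beta * C1[i][j] + C2[i][j] + V[j]
    #include <vector>

    void reference_gemm_broadcast(
        int M, int N, int K, float alpha, float beta,
        std::vector<float> const &A,   // M x K, row-major
        std::vector<float> const &B,   // K x N, row-major
        std::vector<float> const &C1,  // M x N, first residual
        std::vector<float> const &C2,  // M x N, second residual
        std::vector<float> const &V,   // N-element bias vector, broadcast over rows
        std::vector<float> &D) {       // M x N output
      for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
          float acc = 0.f;
          for (int k = 0; k < K; ++k) {
            acc += A[i * K + k] * B[k * N + j];
          }
          // Identity activations + plus ops make the bias add commutative, so
          // this matches GEMM(A, B, C1) + V + C2 from the example's host check.
          D[i * N + j] = alpha * acc + beta * C1[i * N + j] + C2[i * N + j] + V[j];
        }
      }
    }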
@@ -33,3 +33,7 @@ cutlass_example_add_executable(
   ampere_gemm_universal_streamk.cu
 )
 
+cutlass_example_add_executable(
+  47_ampere_gemm_universal_streamk_broadcast
+  ampere_gemm_universal_streamk_broadcast.cu
+)
@@ -495,7 +495,7 @@ int main(int argc, const char **argv)
   options.tensor_d.resize(options.problem_size.mn());      // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
   options.tensor_ref_d.resize(options.problem_size.mn());  // <- Create matrix D with dimensions M x N used to store output from reference kernel
 
-  // Fill matrix A on host with uniform-random data [2, -2]
+  // Fill matrix A on host with uniform-random data [-2, 2]
   cutlass::reference::host::TensorFillRandomUniform(
       options.tensor_a.host_view(),
       1,
@@ -503,7 +503,7 @@ int main(int argc, const char **argv)
       ElementA(-2),
       0);
 
-  // Fill matrix B on host with uniform-random data [2, -2]
+  // Fill matrix B on host with uniform-random data [-2, 2]
   cutlass::reference::host::TensorFillRandomUniform(
       options.tensor_b.host_view(),
       1,
@@ -511,7 +511,7 @@ int main(int argc, const char **argv)
       ElementB(-2),
       0);
 
-  // Fill matrix C on host with uniform-random data [2, -2]
+  // Fill matrix C on host with uniform-random data [-2, 2]
   cutlass::reference::host::TensorFillRandomUniform(
       options.tensor_c.host_view(),
       1,
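A note on the comment fixes above: judging from the calls shown in this diff, cutlass::reference::host::TensorFillRandomUniform takes the upper bound before the lower bound, so the filled range really is [-2, 2] even though 2 is passed first. A sketch of the call convention (the trailing-argument semantics follow the `--real`/`_init_bits` usage in the new example below):

    // Call convention as used in this example: (view, seed, max, min, bits).
    cutlass::reference::host::TensorFillRandomUniform(
        options.tensor_a.host_view(),  // destination view
        1,                             // RNG seed
        ElementA(2),                   // maximum value
        ElementA(-2),                  // minimum value
        0);                            // bits: 0 -> whole numbers, -1 -> real values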
@@ -0,0 +1,658 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/***************************************************************************************************
  Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the
  "classic data-parallel" and "Split-K" decompositions + residual add.

  For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition
  for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598)

  Requires NVIDIA Ampere or newer device (SM80+).

  - To lock persistence mode, power (400W), clocks (1005MHz) for evaluation (assumes device 0 and A100)

      cutlass$ sudo nvidia-smi -pm 1 -i 0
      cutlass$ sudo nvidia-smi -i 0 -pl 400
      cutlass$ sudo nvidia-smi -i 0 -lgc 1005

  - Build and run:

      cutlass$ mkdir build
      cutlass$ cd build
      cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80
      cutlass/build$ make 47_ampere_gemm_universal_streamk_broadcast
      cutlass/build$ ./examples/47_ampere_gemm_universal_streamk/47_ampere_gemm_universal_streamk_broadcast

  - Reset clocks when done:

      cutlass$ sudo nvidia-smi -rgc

 **************************************************************************************************/

#include <iostream>
#include <string>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/error_metrics.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_foreach.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "helper.h"


/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations (cutlass_tensorop_h16816gemm_128x128_32x4_nn_align8)
/////////////////////////////////////////////////////////////////////////////////////////////////

// A matrix configuration
using ElementA = cutlass::half_t;                 // Element type for A matrix operand
using LayoutA  = cutlass::layout::RowMajor;       // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;  // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

// B matrix configuration
using ElementB = cutlass::half_t;                 // Element type for B matrix operand
using LayoutB  = cutlass::layout::RowMajor;       // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;  // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// C1/C2/D matrix configuration
using ElementC = cutlass::half_t;                 // Element type for C matrix operands
using LayoutC  = cutlass::layout::RowMajor;       // Layout type for C matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;  // Memory access granularity/alignment of C matrices in units of elements (up to 16 bytes)

// Output matrix configuration
using ElementOutput = cutlass::half_t;            // Element type for output matrix operands
using LayoutOutput  = cutlass::layout::RowMajor;  // Layout type for output matrix operands
// constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;  // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes)

// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = cutlass::half_t;                      // Element type for internal accumulation
using ElementCompute     = cutlass::half_t;                      // Element type for compute
using ArchTag            = cutlass::arch::Sm80;                  // Tag indicating the minimum SM that supports the intended feature
using OperatorClass      = cutlass::arch::OpClassTensorOp;       // Operator class tag
using ThreadblockShape   = cutlass::gemm::GemmShape<128, 128, 32>;  // Threadblock-level tile size (concept: GemmShape)
using WarpShape          = cutlass::gemm::GemmShape<64, 64, 32>;    // Warp-level tile size (concept: GemmShape)
using InstructionShape   = cutlass::gemm::GemmShape<16, 8, 16>;     // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages  = 4;                                    // Number of global->shared pipeline stages used in the GEMM mainloop

// Residual block configuration

// Epilogue output operator
/// Using LinearCombinationResidualBlock
/// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock<
  ElementOutput,                        // Element type for output matrix
  ElementAccumulator,                   // Element type from internal accumulation
  ElementCompute,                       // Element type for compute
  ElementC,                             // Element type for C1/C2/D matrix operands
  AlignmentC,                           // Memory access granularity of C and D matrix in units of elements
  cutlass::epilogue::thread::Identity,  // Activation
  cutlass::plus,                        // Binary operation 1
  cutlass::epilogue::thread::Identity,  // Unary operation
  cutlass::plus                         // Binary operation 2
>;

// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
  ElementA,
  LayoutA,
  ElementB,
  LayoutB,
  ElementC,
  LayoutC,
  ElementAccumulator,
  ElementAccumulator>;

// Classic data-parallel device GEMM implementation type
using DeviceGemmBasic = cutlass::gemm::device::GemmUniversalWithBroadcast<
  ElementA, LayoutA,
  ElementB, LayoutB,
  ElementC, LayoutC,
  ElementAccumulator,
  OperatorClass,
  ArchTag,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOp,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  NumStages,
  AlignmentA,
  AlignmentB>;

// StreamK device GEMM implementation type
using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalStreamkWithBroadcast<
  ElementA, LayoutA,
  ElementB, LayoutB,
  ElementC, LayoutC,
  ElementAccumulator,
  OperatorClass,
  ArchTag,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOp,
  cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
  NumStages,
  AlignmentA,
  AlignmentB>;


/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Result structure
struct Result
{
  double avg_runtime_ms;
  double gflops;
  cutlass::Status status;
  cudaError_t error;
  bool passed;

  Result(
    double avg_runtime_ms = 0,
    double gflops = 0,
    cutlass::Status status = cutlass::Status::kSuccess,
    cudaError_t error = cudaSuccess)
  :
    avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true)
  {}
};


/// Command line options parsing
struct Options
{
  std::string command_name;
  bool help;
  cutlass::gemm::GemmCoord problem_size;
  float alpha;
  float beta;
  int split_k_factor;
  int avail_sms;
  int iterations;
  bool real;

  cutlass::HostTensor<ElementA, LayoutA> tensor_a;
  cutlass::HostTensor<ElementB, LayoutB> tensor_b;
  cutlass::HostTensor<ElementC, LayoutC> tensor_c1;
  cutlass::HostTensor<ElementC, LayoutC> tensor_c2;
  cutlass::HostTensor<ElementC, LayoutC> tensor_d;
  cutlass::HostTensor<ElementC, LayoutC> tensor_ref_d;
  cutlass::HostTensor<ElementC, LayoutC> tensor_Vector;
  // cutlass::HostTensor<ElementC, LayoutC> tensor_Tensor;

  Options(std::string command_name) :
    command_name(command_name),
    help(false),
    problem_size({2048, 2048, 2048}),
    alpha(1.0f),
    beta(1.0f),
    split_k_factor(1),
    avail_sms(-1),          // Number of device SMs to use is unlimited
    iterations(10000),
    real(false)
  {}

  bool valid() const
  {
    return true;
  }

  void parse(int argc, char const **args)
  {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
    }

    cmd.get_cmd_line_argument("m", problem_size.m());
    cmd.get_cmd_line_argument("n", problem_size.n());
    cmd.get_cmd_line_argument("k", problem_size.k());
    cmd.get_cmd_line_argument("alpha", alpha);
    cmd.get_cmd_line_argument("beta", beta);
    cmd.get_cmd_line_argument("split", split_k_factor);
    cmd.get_cmd_line_argument("iterations", iterations);
    real = cmd.check_cmd_line_flag("real");
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const
  {
    out
      << "Performs a GEMM computation.\n"
      << "\n"
      << "Options:\n"
      << "\n"
      << "  --help                      If specified, displays this usage statement.\n\n"
      << "  --m=<int>                   GEMM M dimension\n"
      << "  --n=<int>                   GEMM N dimension\n"
      << "  --k=<int>                   GEMM K dimension\n"
      << "  --alpha=<f32>               Epilogue scalar alpha\n"
      << "  --beta=<f32>                Epilogue scalar beta\n\n"
      << "  --split=<int>               Split-K factor to emulate\n\n"
      << "  --real                      If specified, initializes with real values instead of whole numbers. Errors are to be expected.\n\n"
      << "  --iterations=<int>          Number of profiling iterations to perform.\n\n";

    out
      << "\n\nExamples:\n\n"
      << "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";

    return out;
  }

  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const
  {
    // Two flops per multiply-add
    return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s;
  }
};


/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options
typename DeviceGemmBasic::Arguments args_from_options(
    const DeviceGemmBasic &device_gemm,
    const Options &options,
    cutlass::HostTensor<ElementA, LayoutA> &tensor_a,
    cutlass::HostTensor<ElementB, LayoutB> &tensor_b,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_c1,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_c2,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_d,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_Vector /*,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_Tensor */
    )
{
  return typename DeviceGemmBasic::Arguments(
    cutlass::gemm::GemmUniversalMode::kGemm,  // universal mode
    options.problem_size,                     // problem_size
    options.split_k_factor,                   // batch count / splitk slices
    {                                         // epilogue parameters
      ElementAccumulator(options.alpha),
      ElementAccumulator(options.beta)
    },
    tensor_a.device_data(),                   // ptr_A
    tensor_b.device_data(),                   // ptr_B
    tensor_c1.device_data(),                  // ptr_C1
    tensor_c2.device_data(),                  // ptr_C2
    tensor_d.device_data(),                   // ptr_D
    tensor_Vector.device_data(),              // ptr_Vector
    /* tensor_Tensor.device_data(), */nullptr,// ptr_Tensor
    options.problem_size.mk().product(),      // batch_stride_A
    options.problem_size.nk().product(),      // batch_stride_B
    options.problem_size.mn().product(),      // batch_stride_C1
    options.problem_size.mn().product(),      // batch_stride_C2
    options.problem_size.mn().product(),      // batch_stride_D
    options.problem_size.mn().product(),      // batch_stride_Vector
    options.problem_size.mn().product(),      // batch_stride_Tensor
    tensor_a.layout().stride(0),              // stride_a
    tensor_b.layout().stride(0),              // stride_b
    tensor_c1.layout().stride(0),             // stride_c1
    tensor_c2.layout().stride(0),             // stride_c2
    tensor_d.layout().stride(0),              // stride_d
    /*tensor_Vector.layout().stride(0)*/0,    // stride_Vector
    /*tensor_Tensor.layout().stride(0)*/0);   // stride_Tensor
}

/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options
typename DeviceGemmStreamK::Arguments args_from_options(
    const DeviceGemmStreamK &device_gemm,
    const Options &options,
    cutlass::HostTensor<ElementA, LayoutA> &tensor_a,
    cutlass::HostTensor<ElementB, LayoutB> &tensor_b,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_c1,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_c2,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_d,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_Vector/*,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_Tensor*/
    )
{
  return typename DeviceGemmStreamK::Arguments(
    cutlass::gemm::GemmUniversalMode::kGemm,  // universal mode
    options.problem_size,                     // problem_size
    options.split_k_factor,                   // batch count / splitk slices
    {                                         // epilogue parameters
      ElementAccumulator(options.alpha),
      ElementAccumulator(options.beta)
    },
    tensor_a.device_data(),                   // ptr_A
    tensor_b.device_data(),                   // ptr_B
    tensor_c1.device_data(),                  // ptr_C1
    tensor_c2.device_data(),                  // ptr_C2
    tensor_d.device_data(),                   // ptr_D
    tensor_Vector.device_data(),              // ptr_Vector
    /* tensor_Tensor.device_data(), */nullptr,// ptr_Tensor // We're not storing Tensor
    options.problem_size.mk().product(),      // batch_stride_A
    options.problem_size.nk().product(),      // batch_stride_B
    options.problem_size.mn().product(),      // batch_stride_C1
    options.problem_size.mn().product(),      // batch_stride_C2
    options.problem_size.mn().product(),      // batch_stride_D
    options.problem_size.mn().product(),      // batch_stride_Vector
    options.problem_size.mn().product(),      // batch_stride_Tensor
    tensor_a.layout().stride(0),              // stride_a
    tensor_b.layout().stride(0),              // stride_b
    tensor_c1.layout().stride(0),             // stride_c1
    tensor_c2.layout().stride(0),             // stride_c2
    tensor_d.layout().stride(0),              // stride_d
    /*tensor_Vector.layout().stride(0)*/0,    // stride_Vector // Vector stride is always 0
    /*tensor_Tensor.layout().stride(0)*/0,    // stride_Tensor // We're not storing Tensor
    options.avail_sms);                       // avail_sms
}

/// Execute a given example GEMM computation
template <typename DeviceGemmT>
Result run(std::string description, Options &options)
{
  // Display test description
  std::cout << std::endl << description << std::endl;

  // Zero-initialize test output matrix D
  cutlass::reference::host::TensorFill(options.tensor_d.host_view());
  options.tensor_d.sync_device();

  // Instantiate CUTLASS kernel depending on templates
  DeviceGemmT device_gemm;

  // Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT
  auto arguments = args_from_options(device_gemm, options,
      options.tensor_a, options.tensor_b, options.tensor_c1, options.tensor_c2, options.tensor_d,
      options.tensor_Vector/*, options.tensor_Tensor*/);

  // Using the arguments, query for extra workspace required for matrix multiplication computation
  size_t workspace_size = DeviceGemmT::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check whether the problem size is supported or not
  CUTLASS_CHECK(device_gemm.can_implement(arguments));

  // Initialize CUTLASS kernel with arguments and workspace pointer
  CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get()));

  // Correctness / Warmup iteration
  CUTLASS_CHECK(device_gemm());

  // Copy output data from CUTLASS and reference kernel to host for comparison
  options.tensor_d.sync_host();

  // Check if output from CUTLASS kernel and reference kernel are equal or not
  Result result;
  result.passed = cutlass::reference::host::TensorEquals(
    options.tensor_d.host_view(),
    options.tensor_ref_d.host_view());

  double err = cutlass::reference::host::TensorRelativeErrorMetric(
    options.tensor_d.host_view(),
    options.tensor_ref_d.host_view());

  std::cout << "  Disposition: " << (result.passed ? "Passed" : "Failed") << " \t Relative error: " << err << std::endl;

  // Run profiling loop
  if (options.iterations > 0)
  {
    GpuTimer timer;
    timer.start();
    for (int iter = 0; iter < options.iterations; ++iter) {
      CUTLASS_CHECK(device_gemm());
    }
    timer.stop();

    // Compute average runtime and GFLOPs.
    float elapsed_ms = timer.elapsed_millis();
    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);

    std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
    std::cout << "  GFLOPs: " << result.gflops << std::endl;
  }

  // TODO: uncomment when results match
  //if (!result.passed) {
  //  exit(-1);
  //}

  return result;
}


/// Program entrypoint
int main(int argc, const char **argv)
{
  // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
  if (!(__CUDACC_VER_MAJOR__ >= 11)) {
    std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;

    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
    return 0;
  }

  // Current device must have compute capability at least 80
  cudaDeviceProp props;
  int current_device_id;
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  if (!((props.major * 10 + props.minor) >= 80))
  {
    std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
              << std::endl;

    // Returning zero so this test passes on older architectures. Its actions are no-op.
    return 0;
  }

  // Parse commandline options
  Options options("ampere_streamk_broadcast_gemm");
  options.parse(argc, argv);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  std::cout <<
    options.iterations << " timing iterations of " <<
    options.problem_size.m() << " x " <<
    options.problem_size.n() << " x " <<
    options.problem_size.k() << " matrix-matrix multiply" << std::endl;

  if (!options.valid()) {
    std::cerr << "Invalid problem." << std::endl;
    return -1;
  }


  //
  // Initialize GEMM datasets
  //

  // Initialize tensors using CUTLASS helper functions
  options.tensor_a.resize(options.problem_size.mk());       // <- Create matrix A with dimensions M x K
  options.tensor_b.resize(options.problem_size.kn());       // <- Create matrix B with dimensions K x N
  options.tensor_c1.resize(options.problem_size.mn());      // <- Create matrix C1 with dimensions M x N
  options.tensor_c2.resize(options.problem_size.mn());      // <- Create matrix C2 with dimensions M x N
  options.tensor_d.resize(options.problem_size.mn());       // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
  options.tensor_ref_d.resize(options.problem_size.mn());   // <- Create matrix D with dimensions M x N used to store output from reference kernel
  options.tensor_Vector.resize({1, options.problem_size.n()});  // <- Create broadcast vector with dimensions 1 x N
  // options.tensor_Tensor.resize(options.problem_size.mn());  // <- Create T matrix with dimensions M x N

  int _init_bits = options.real ? -1 : 0;

  // Fill matrix A on host with uniform-random data [-2, 2]
  cutlass::reference::host::TensorFillRandomUniform(
      options.tensor_a.host_view(),
      1,
      ElementA(2),
      ElementA(-2), _init_bits);

  // Fill matrix B on host with uniform-random data [-2, 2]
  cutlass::reference::host::TensorFillRandomUniform(
      options.tensor_b.host_view(),
      1,
      ElementB(2),
      ElementB(-2), _init_bits);

  // Fill matrix C1 on host with uniform-random data [-2, 2]
  cutlass::reference::host::TensorFillRandomUniform(
      options.tensor_c1.host_view(),
      1,
      ElementC(2),
      ElementC(-2), _init_bits);

  // Fill matrix C2 on host with uniform-random data [-2, 2]
  cutlass::reference::host::TensorFillRandomUniform(
      options.tensor_c2.host_view(),
      1,
      ElementC(2),
      ElementC(-2), _init_bits);

  // Fill broadcast vector on host with uniform-random data [-2, 2]
  cutlass::reference::host::TensorFillRandomUniform(
      options.tensor_Vector.host_view(),
      1,
      ElementC(2),
      ElementC(-2), _init_bits);

  //
  // Compute reference output
  //

  // Copy data from host to GPU
  options.tensor_a.sync_device();
  options.tensor_b.sync_device();
  options.tensor_c1.sync_device();
  options.tensor_c2.sync_device();
  options.tensor_Vector.sync_device();
  // options.tensor_Tensor.sync_device();

  // Zero-initialize reference output matrix D
  cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view());
  options.tensor_ref_d.sync_device();

  // Create instantiation for device reference gemm kernel
  DeviceGemmReference gemm_reference;

  // Launch device reference gemm kernel
  gemm_reference(
    options.problem_size,
    ElementAccumulator(options.alpha),
    options.tensor_a.device_ref(),
    options.tensor_b.device_ref(),
    ElementAccumulator(options.beta),
    options.tensor_c1.device_ref(),
    options.tensor_ref_d.device_ref());

  // Wait for kernels to finish
  CUDA_CHECK(cudaDeviceSynchronize());

  // Copy output data from reference kernel to host for comparison
  options.tensor_ref_d.sync_host();

  // Add broadcast vector (without multiplier)
  // This is only possible because BinaryOp is addition, and UnaryOps are identity.
  // This makes the addition of the broadcast vector commutative.
  /// identity(plus(identity(alpha * (a * b) + v), beta * c)) ==
  /// alpha * a * b + v + beta * c ==
  /// (alpha * a * b + beta * c) + v ==
  /// GEMM(a, b, c) + v
  // Vector broadcast on host
  for (int i = 0; i < options.problem_size.m(); ++i) {
    for (int j = 0; j < options.problem_size.n(); ++j) {
      options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_Vector.host_view().ref().at({0, j});
      options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_c2.host_view().ref().at({i, j});
    }
  }

  // Sync back with device just in case
  options.tensor_ref_d.sync_device();

  //
  // Evaluate CUTLASS kernels
  //

  // Test default operation
  if (options.split_k_factor == 1)
  {
    // Compare basic data-parallel version versus StreamK version using default load-balancing heuristics
    Result basic_dp         = run<DeviceGemmBasic>("Basic data-parallel GEMM", options);
    Result streamk_default  = run<DeviceGemmStreamK>("StreamK GEMM with default load-balancing", options);

    printf("  Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms));

    // Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1
    options.avail_sms = 1;        // Set loadbalancing width to 1 SM (no load balancing)
    Result streamk_dp = run<DeviceGemmStreamK>("StreamK emulating basic data-parallel GEMM", options);
    options.avail_sms = -1;       // Reset loadbalancing width to unspecified SMs (i.e., the number of device SMs)

    printf("  Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms));

    options.split_k_factor++;     // Increment splitting factor for next evaluation
  }

  // Show that StreamK can emulate "Split-K" with a tile-splitting factor
  Result basic_splitk = run<DeviceGemmBasic>(
    std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
    options);

  Result streamk_splitk = run<DeviceGemmStreamK>(
    std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
    options);

  printf("  Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms));

  return 0;
}
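For orientation, this is the Stream-K invocation pattern the example's run<>() helper wraps; a condensed sketch assuming the aliases and helpers defined in the file above (DeviceGemmStreamK, args_from_options, Options, CUTLASS_CHECK):

    // Condensed sketch of run<DeviceGemmStreamK>(); error handling as in the file.
    DeviceGemmStreamK device_gemm;
    auto arguments = args_from_options(device_gemm, options,
        options.tensor_a, options.tensor_b, options.tensor_c1, options.tensor_c2,
        options.tensor_d, options.tensor_Vector);

    // Stream-K needs a workspace for cross-CTA reduction of partial tiles.
    size_t workspace_size = DeviceGemmStreamK::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    CUTLASS_CHECK(device_gemm.can_implement(arguments));
    CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get()));
    CUTLASS_CHECK(device_gemm());  // launch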
@@ -48,6 +48,7 @@
 #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
 #include "cutlass/epilogue/threadblock/epilogue.h"
 #include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h"
+#include "cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h"
 
 #include "cutlass/layout/permute.h"
 
@@ -120,6 +121,67 @@ struct DefaultEpilogueWithBroadcastTensorOp {
 ////////////////////////////////////////////////////////////////////////////////
 
+/// Defines sensible defaults for streamk epilogues for TensorOps.
+template <
+  typename Shape,
+  typename WarpMmaTensorOp,
+  int PartitionsK,
+  typename ElementOutput,
+  typename ElementTensor,
+  typename ElementVector,
+  typename OutputOp,
+  int ElementsPerAccess,
+  bool ScatterD = false,
+  typename PermuteDLayout = layout::NoPermute
+>
+struct DefaultStreamkEpilogueWithBroadcastTensorOp {
+
+  /// Use defaults related to the existing epilogue
+  using Base = DefaultEpilogueTensorOp<
+    Shape,
+    WarpMmaTensorOp,
+    PartitionsK,
+    OutputOp,
+    ElementsPerAccess
+  >;
+
+  //
+  // Stores the result z = (y = GEMM(A, B, C), broadcast)
+  //
+  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
+    typename Base::OutputTileThreadMap,
+    ElementOutput,
+    ScatterD,
+    PermuteDLayout
+  >;
+
+  //
+  // Additional tensor tile iterator - stores t = Elementwise(z)
+  //
+  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
+    typename Base::OutputTileThreadMap,
+    ElementTensor
+  >;
+
+  /// Define the epilogue
+  using Epilogue = EpilogueStreamkWithBroadcast<
+    Shape,
+    WarpMmaTensorOp,
+    PartitionsK,
+    OutputTileIterator,
+    TensorTileIterator,
+    ElementVector,
+    typename Base::AccumulatorFragmentIterator,
+    typename Base::WarpTileIterator,
+    typename Base::SharedLoadIterator,
+    OutputOp,
+    typename Base::Padding,
+    Base::kFragmentsPerIteration
+  >;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
 /// Defines sensible defaults for epilogues for VoltaTensorOps.
 template <
   typename Shape,
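The new default assembles the Stream-K broadcast epilogue out of the pieces that DefaultEpilogueTensorOp already selects. A hypothetical instantiation, reusing the type names from the example above; WarpMmaTensorOp stands in for the warp-level MMA type that a kernel-level default would normally choose:

    // Hypothetical instantiation; WarpMmaTensorOp is a placeholder.
    using StreamkEpilogue = typename cutlass::epilogue::threadblock::
        DefaultStreamkEpilogueWithBroadcastTensorOp<
            ThreadblockShape,                                 // e.g. GemmShape<128, 128, 32>
            WarpMmaTensorOp,                                  // warp-level MMA (placeholder)
            1,                                                // PartitionsK
            ElementOutput,                                    // element type of z / D
            ElementOutput,                                    // element type of t (extra tensor)
            ElementOutput,                                    // element type of the broadcast vector
            EpilogueOp,                                       // LinearCombinationResidualBlock<...>
            128 / cutlass::sizeof_bits<ElementOutput>::value  // elements per access
        >::Epilogue;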
@@ -0,0 +1,443 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file

  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.

*/

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#include <cuda/std/utility>
#else
#include <assert.h>
#include <utility>
#endif

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"

#include "cutlass/numeric_types.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// This base class is meant to define the concept required of the
/// EpilogueStreamkWithBroadcast::OutputOp
template <
  typename ElementC_,
  typename ElementAccumulator_,
  typename ElementCompute_,
  typename ElementZ_,
  typename ElementT_,
  int ElementsPerAccess,
  bool StoreZ = true,
  bool StoreT = true
>
struct EpilogueStreamkWithBroadcastOpBase : EpilogueWithBroadcastOpBase<
  ElementC_,
  ElementAccumulator_,
  ElementCompute_,
  ElementZ_,
  ElementT_,
  ElementsPerAccess,
  StoreZ,
  StoreT
>
{

  /// Parameters structure - required
  struct Params { };

  //
  // Methods
  //

  /// Constructor from Params
  EpilogueStreamkWithBroadcastOpBase(Params const &params_) { }
};

////////////////////////////////////////////////////////////////////////////////

/// Epilogue operator with bias vector broadcast over columns.
///
/// Computes the following:
///
///
///  Z, T = OutputOp(AB, C, Broadcast)
///
///  if (ElementwiseOp::kStoreZ) {
///    store(converted_u);
///  }
///
///  if (ElementwiseOp::kStoreT) {
///    store(v);
///  }
///
template <
  typename Shape_,                       ///< Shape of threadblock tile (concept: GemmShape)
  typename WarpMmaOperator_,             ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
  int PartitionsK,                       ///< Number of partitions of the K dimension
  typename OutputTileIterator_,          ///< Tile iterator reading and writing output tensors (z)
  typename TensorTileIterator_,          ///< Additional tile iterator for tensor-valued operands (t)
  typename ElementVector_,               ///< Pointer to broadcast vector
  typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
  typename WarpTileIterator_,            ///< Warp-scoped tile iterator writing accumulators to SMEM
  typename SharedLoadIterator_,          ///< Threadblock-scoped tile iterator loading from SMEM
  typename OutputOp_,                    ///< Output operator - concept is EpilogueWithBroadcastOp
  typename Padding_,                     ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
  int FragmentsPerPartition = 1,         ///< Used to coarsen the epilogue granularity
  int IterationsUnroll =                 ///< Used to reduce binary size when epilogue op is large
    (!IsEpilogueFunctorHeavy<OutputOp_>::value),
  bool IsSingleSource = OutputOp_::kIsSingleSource
>
class EpilogueStreamkWithBroadcast;


/////////////////////////////////////////////////////////////////////////////////////////////////

/// EpilogueStreamkWithBroadcast: Two sources

template <
  typename Shape_,
  typename WarpMmaOperator_,
  int PartitionsK,
  typename OutputTileIterator_,
  typename TensorTileIterator_,
  typename ElementVector_,
  typename AccumulatorFragmentIterator_,
  typename WarpTileIterator_,
  typename SharedLoadIterator_,
  typename OutputOp_,
  typename Padding_,
  int FragmentsPerPartition,
  int IterationsUnroll
>
class EpilogueStreamkWithBroadcast<
  Shape_,
  WarpMmaOperator_,
  PartitionsK,
  OutputTileIterator_,
  TensorTileIterator_,
  ElementVector_,
  AccumulatorFragmentIterator_,
  WarpTileIterator_,
  SharedLoadIterator_,
  OutputOp_,
  Padding_,
  FragmentsPerPartition,
  IterationsUnroll,
  false
> :
  public EpilogueWithBroadcast<
    Shape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputTileIterator_,
    TensorTileIterator_,
    ElementVector_,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    SharedLoadIterator_,
    OutputOp_,
    Padding_,
    FragmentsPerPartition,
    IterationsUnroll,
    false>,
  public EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>
{

public:

  using Base = EpilogueWithBroadcast<
    Shape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputTileIterator_,
    TensorTileIterator_,
    ElementVector_,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    SharedLoadIterator_,
    OutputOp_,
    Padding_,
    FragmentsPerPartition,
    IterationsUnroll,
    false>;

  using BaseStreamK = EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>;

  using Shape = Shape_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using TensorTileIterator = TensorTileIterator_;
  using ElementVector = ElementVector_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;

  /// Fragment type used by the accumulator tile's fragment iterator
  using AccumulatorFragment = typename Base::AccumulatorFragmentIterator::Fragment;

  /// Shared storage structure (shadows base) with additional SMEM buffer for reduction
  using SharedStorage = typename Base::SharedStorage;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueStreamkWithBroadcast(
    SharedStorage &shared_storage,   ///< Shared storage object
    int thread_idx,                  ///< ID of a thread within the threadblock
    int warp_idx,                    ///< ID of warp within threadblock
    int lane_idx                     ///< ID of thread within warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    BaseStreamK(thread_idx)
  { }


  /// Aggregates the accumulator sets shared by peer blocks in the global workspace,
  /// performing epilogue computations, writing to output
  CUTLASS_DEVICE
  void reduce(
    int peer_idx_begin,
    int peer_idx_end,
    int reduce_fragment_idx,
    void *element_workspace,
    OutputOp const &output_op,                ///< Output operator
    ElementVector const * broadcast_ptr,      ///< Broadcast vector
    OutputTileIterator destination_iterator,  ///< Tile iterator for destination
    OutputTileIterator source_iterator1,      ///< Tile iterator for first source accumulator matrix
    OutputTileIterator source_iterator2,      ///< Tile iterator for second source accumulator matrix
    TensorTileIterator tensor_iterator,       ///< Threadblock tile iterator for additional tensor operand
    MatrixCoord const &problem_size =         ///< Problem size needed to guard against out-of-bounds accesses
        MatrixCoord(Shape::kM, Shape::kN),
    MatrixCoord const &threadblock_offset =   ///< Threadblock's initial offset within the problem size space
        MatrixCoord())
  {
    // Reduce peer accumulator fragments into one fragment
    AccumulatorFragment accum_fragment;
    BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);

    // Store fragment to shared memory
    this->warp_tile_iterator_.store(accum_fragment);

    __syncthreads();

    Base::reduce(reduce_fragment_idx, output_op, broadcast_ptr, destination_iterator, source_iterator1, source_iterator2, tensor_iterator, problem_size, threadblock_offset);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// EpilogueStreamkWithBroadcast: Single source

template <
  typename Shape_,
  typename WarpMmaOperator_,
  int PartitionsK,
  typename OutputTileIterator_,
  typename TensorTileIterator_,
  typename ElementVector_,
  typename AccumulatorFragmentIterator_,
  typename WarpTileIterator_,
  typename SharedLoadIterator_,
  typename OutputOp_,
  typename Padding_,
  int FragmentsPerPartition,
  int IterationsUnroll
>
class EpilogueStreamkWithBroadcast<
  Shape_,
  WarpMmaOperator_,
  PartitionsK,
  OutputTileIterator_,
  TensorTileIterator_,
  ElementVector_,
  AccumulatorFragmentIterator_,
  WarpTileIterator_,
  SharedLoadIterator_,
  OutputOp_,
  Padding_,
  FragmentsPerPartition,
  IterationsUnroll,
  true
> :
  public EpilogueWithBroadcast<
    Shape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputTileIterator_,
    TensorTileIterator_,
    ElementVector_,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    SharedLoadIterator_,
    OutputOp_,
    Padding_,
    FragmentsPerPartition,
    IterationsUnroll,
    true>,
  public EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>
{

public:

  using Base = EpilogueWithBroadcast<
    Shape_,
    WarpMmaOperator_,
    PartitionsK,
    OutputTileIterator_,
    TensorTileIterator_,
    ElementVector_,
    AccumulatorFragmentIterator_,
    WarpTileIterator_,
    SharedLoadIterator_,
    OutputOp_,
    Padding_,
    FragmentsPerPartition,
    IterationsUnroll,
    true>;

  using BaseStreamK = EpilogueBaseStreamK<
    Shape_,
    PartitionsK,
    WarpMmaOperator_,
    AccumulatorFragmentIterator_>;

  using Shape = Shape_;
  static int const kPartitionsK = PartitionsK;
  using OutputTileIterator = OutputTileIterator_;
  using TensorTileIterator = TensorTileIterator_;
  using ElementVector = ElementVector_;
  using SharedLoadIterator = SharedLoadIterator_;
  using OutputOp = OutputOp_;

  /// Fragment type used by the accumulator tile's fragment iterator
  using AccumulatorFragment = typename Base::AccumulatorFragmentIterator::Fragment;

  /// Shared storage structure (shadows base) with additional SMEM buffer for reduction
  using SharedStorage = typename Base::SharedStorage;

public:

  /// Constructor
  CUTLASS_DEVICE
  EpilogueStreamkWithBroadcast(
    SharedStorage &shared_storage,   ///< Shared storage object
    int thread_idx,                  ///< ID of a thread within the threadblock
    int warp_idx,                    ///< ID of warp within threadblock
    int lane_idx                     ///< ID of thread within warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    BaseStreamK(thread_idx)
  { }


  /// Aggregates the accumulator sets shared by peer blocks in the global workspace,
  /// performing epilogue computations, writing to output
  CUTLASS_DEVICE
  void reduce(
    int peer_idx_begin,
    int peer_idx_end,
    int reduce_fragment_idx,
    void *element_workspace,
    OutputOp const &output_op,                ///< Output operator
    ElementVector const * broadcast_ptr,      ///< Broadcast vector
    OutputTileIterator destination_iterator,  ///< Tile iterator for destination
    OutputTileIterator source_iterator,       ///< Tile iterator for source accumulator matrix
    TensorTileIterator tensor_iterator,       ///< Threadblock tile iterator for additional tensor operand
    MatrixCoord const &problem_size =         ///< Problem size needed to guard against out-of-bounds accesses
        MatrixCoord(Shape::kM, Shape::kN),
    MatrixCoord const &threadblock_offset =   ///< Threadblock's initial offset within the problem size space
        MatrixCoord())
  {
    // Reduce peer accumulator fragments into one fragment
    AccumulatorFragment accum_fragment;
    BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);

    // Store fragment to shared memory
    this->warp_tile_iterator_.store(accum_fragment);

    __syncthreads();

    Base::reduce(reduce_fragment_idx, output_op, broadcast_ptr, destination_iterator, source_iterator, tensor_iterator, problem_size, threadblock_offset);
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
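For context on how reduce() is meant to be driven: a sketch of the call a Stream-K kernel would make when a threadblock finishes a tile whose accumulators were produced by several peer CTAs. All locals here are placeholders supplied by the Stream-K tile scheduler and the kernel's epilogue setup:

    // Illustrative call site (two-source specialization); names are placeholders.
    Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);

    epilogue.reduce(
        peer_idx_begin,          // first peer CTA that contributed to this tile
        peer_idx_end,            // one past the last contributing peer CTA
        reduce_fragment_idx,     // which fragment of the tile this CTA reduces
        element_workspace,       // global workspace holding the peers' partials
        output_op,               // fused op (e.g. LinearCombinationResidualBlock)
        broadcast_ptr,           // bias vector, broadcast with stride 0 across rows
        destination_iterator,    // writes Z (the D tensor)
        source_iterator1,        // first residual source (C1)
        source_iterator2,        // second residual source (C2)
        tensor_iterator);        // writes T when OutputOp::kStoreT is set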
@@ -863,6 +863,98 @@ private:
        frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
    }
  }

public:

  /// Stream-K reduce helper
  CUTLASS_DEVICE
  void reduce(
    int reduce_fragment_idx,                   ///< Reduce fragment index
    OutputOp const &output_op,                 ///< Output operator
    ElementVector const * broadcast_ptr,       ///< Broadcast vector
    OutputTileIterator destination_iterator,   ///< Tile iterator for destination
    OutputTileIterator source_iterator1,       ///< Tile iterator for first source accumulator matrix
    OutputTileIterator source_iterator2,       ///< Tile iterator for second source accumulator matrix
    TensorTileIterator tensor_iterator,        ///< Threadblock tile iterator for additional tensor operand
    MatrixCoord const &problem_size =          ///< Problem size needed to guard against out-of-bounds accesses
        MatrixCoord(Shape::kM, Shape::kN),
    MatrixCoord const &threadblock_offset =    ///< Threadblock's initial offset within the problem size space
        MatrixCoord())
  {
    BroadcastFragment broadcast_fragment;
    load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);

    // Initialize/load source-fragment data
    typename OutputTileIterator::Fragment source_fragment1;
    source_fragment1.clear();
    typename OutputTileIterator::Fragment source_fragment2;
    source_fragment2.clear();

    if (output_op.is_source_needed())
    {
      source_iterator1 += reduce_fragment_idx;
      source_iterator1.load(source_fragment1);

      source_iterator2 += reduce_fragment_idx;
      source_iterator2.load(source_fragment2);
    }

    // Load fragment from shared memory
    typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
    shared_load_iterator_.load(aligned_accum_fragment[0]);

    // Add fragments shared by other k partitions
    if (kPartitionsK > 1)
    {
      plus<typename SharedLoadIterator::Fragment> add_fragments;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 1; i < kPartitionsK; ++i) {
        shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
        shared_load_iterator_.load(aligned_accum_fragment[i]);
        aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
      }
    }

    //
    // Apply output operation
    //

    typename OutputTileIterator::Fragment frag_Z;
    typename TensorTileIterator::Fragment frag_T;

    if (!output_op.is_source_needed()) {
      apply_output_operator_source_not_needed_(
        frag_Z,
        frag_T,
        output_op,
        aligned_accum_fragment[0],
        broadcast_fragment);
    } else {
      apply_output_operator_(
        frag_Z,
        frag_T,
        output_op,
        aligned_accum_fragment[0],
        source_fragment1,
        source_fragment2,
        broadcast_fragment);
    }

    //
    // Conditionally store fragments
    //

    if (OutputOp::kStoreZ) {
      destination_iterator.store(frag_Z);
      ++destination_iterator;
    }

    if (OutputOp::kStoreT) {
      tensor_iterator.store(frag_T);
      ++tensor_iterator;
    }
  }
};
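
The kPartitionsK accumulation above is the standard CUTLASS split-K fold: the first partition's fragment is loaded, then the remaining partitions are added in with cutlass::plus. A stand-alone sketch of the same pattern, with Fragment assumed to be any cutlass::Array-like elementwise-addable type:

    #include "cutlass/cutlass.h"      // CUTLASS_DEVICE, CUTLASS_PRAGMA_UNROLL
    #include "cutlass/functional.h"   // cutlass::plus

    // Stand-alone illustration of the partition fold used above: accumulate
    // fragments 1..kPartitionsK-1 into fragment 0, unrolled at compile time.
    template <typename Fragment, int kPartitionsK>
    CUTLASS_DEVICE void fold_partitions(Fragment (&frags)[kPartitionsK]) {
      cutlass::plus<Fragment> add_fragments;
      CUTLASS_PRAGMA_UNROLL
      for (int i = 1; i < kPartitionsK; ++i) {
        frags[0] = add_fragments(frags[0], frags[i]);
      }
    }
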
@@ -1529,6 +1621,92 @@ private:
        frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
    }
  }

public:

  /// Stream-K reduce helper
  CUTLASS_DEVICE
  void reduce(
    int reduce_fragment_idx,                   ///< Reduce fragment index
    OutputOp const &output_op,                 ///< Output operator
    ElementVector const * broadcast_ptr,       ///< Broadcast vector
    OutputTileIterator destination_iterator,   ///< Tile iterator for destination
    OutputTileIterator source_iterator,        ///< Tile iterator for source accumulator matrix
    TensorTileIterator tensor_iterator,        ///< Threadblock tile iterator for additional tensor operand
    MatrixCoord const &problem_size =          ///< Problem size needed to guard against out-of-bounds accesses
        MatrixCoord(Shape::kM, Shape::kN),
    MatrixCoord const &threadblock_offset =    ///< Threadblock's initial offset within the problem size space
        MatrixCoord())
  {
    BroadcastFragment broadcast_fragment;
    load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);

    // Initialize/load source-fragment data
    typename OutputTileIterator::Fragment source_fragment;
    source_fragment.clear();

    if (output_op.is_source_needed())
    {
      source_iterator += reduce_fragment_idx;
      source_iterator.load(source_fragment);
    }

    // Load fragment from shared memory
    typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
    shared_load_iterator_.load(aligned_accum_fragment[0]);

    // Add fragments shared by other k partitions
    if (kPartitionsK > 1)
    {
      plus<typename SharedLoadIterator::Fragment> add_fragments;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 1; i < kPartitionsK; ++i) {
        shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
        shared_load_iterator_.load(aligned_accum_fragment[i]);
        aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
      }
    }

    //
    // Apply output operation
    //

    typename OutputTileIterator::Fragment frag_Z;
    typename TensorTileIterator::Fragment frag_T;

    if (!output_op.is_source_needed()) {
      apply_output_operator_source_not_needed_(
        frag_Z,
        frag_T,
        output_op,
        aligned_accum_fragment[0],
        broadcast_fragment);
    } else {
      apply_output_operator_(
        frag_Z,
        frag_T,
        output_op,
        aligned_accum_fragment[0],
        source_fragment,
        broadcast_fragment);
    }

    //
    // Conditionally store fragments
    //

    if (OutputOp::kStoreZ) {
      destination_iterator.store(frag_Z);
      ++destination_iterator;
    }

    if (OutputOp::kStoreT) {
      tensor_iterator.store(frag_T);
      ++tensor_iterator;
    }
  }
};

////////////////////////////////////////////////////////////////////////////////
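
The two reduce overloads differ only in the number of source (C) operands they blend in; both implement the same fused semantics, with the bias vector broadcast to every row of the output tile and an optional auxiliary T output. A host-side reference of the single-source case, as a sketch assuming a plain linear-combination output op (the real functor may also apply an elementwise activation and emit T):

    // Host reference for the fused broadcast epilogue (sketch, single source):
    //   D[m][n] = alpha * acc[m][n] + beta * C[m][n] + bias[n]
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        D[m][n] = alpha * acc[m][n] + beta * C[m][n] + bias[n];
      }
    }
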
@@ -0,0 +1,386 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Template for a Stream-K GEMM kernel that can broadcast a bias vector in the
      epilogue.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_universal.h"

#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_streamk_with_broadcast.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/device/gemm_universal_base.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

/*!
  The universal Stream-K GEMM with a broadcast epilogue.
*/
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_ = ElementC_,
  /// Operator class tag
  typename OperatorClass_ = arch::OpClassSimt,
  /// Tag indicating architecture to tune for. This is the minimum SM that
  /// supports the intended feature. The device kernel can be built
  /// targeting any SM larger than this number.
  typename ArchTag_ = arch::Sm70,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::InstructionShape,
  /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
  typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
      ElementC_, ElementAccumulator_, ElementAccumulator_,
      ElementC_, ElementC_, 128 / cutlass::sizeof_bits<ElementC_>::value>,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
  /// Number of stages used in the pipelined mainloop
  int Stages =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kStages,
  /// Access granularity of A matrix in units of elements
  int AlignmentA =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kAlignmentA,
  /// Access granularity of B matrix in units of elements
  int AlignmentB =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kAlignmentB,
  /// Operation performed by GEMM
  typename Operator_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::Operator,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA = ComplexTransform::kNone,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB = ComplexTransform::kNone
>
class GemmUniversalStreamkWithBroadcast :
  public GemmUniversalBase<
    typename kernel::DefaultGemmStreamkWithBroadcast<
      ElementA_,
      LayoutA_,
      TransformA,
      AlignmentA,
      ElementB_,
      LayoutB_,
      TransformB,
      AlignmentB,
      ElementC_,
      LayoutC_,
      ElementAccumulator_,
      OperatorClass_,
      ArchTag_,
      ThreadblockShape_,
      WarpShape_,
      InstructionShape_,
      EpilogueOutputOp_,
      ThreadblockSwizzle_,
      Stages,
      Operator_
    >::GemmKernel
  > {

public:

  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp::kCount;
  static ComplexTransform const kTransformA = TransformA;
  static ComplexTransform const kTransformB = TransformB;

  using Base = GemmUniversalBase<
    typename kernel::DefaultGemmStreamkWithBroadcast<
      ElementA_,
      LayoutA_,
      TransformA,
      AlignmentA,
      ElementB_,
      LayoutB_,
      TransformB,
      AlignmentB,
      ElementC_,
      LayoutC_,
      ElementAccumulator_,
      OperatorClass_,
      ArchTag_,
      ThreadblockShape_,
      WarpShape_,
      InstructionShape_,
      EpilogueOutputOp_,
      ThreadblockSwizzle_,
      Stages,
      Operator_
    >::GemmKernel
  >;

  using Arguments = typename Base::Arguments;
  using GemmKernel = typename Base::GemmKernel;
};
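
A hedged sketch of instantiating this device-level operator follows; the element types, tile shapes, and swizzle are illustrative choices, not prescribed by this header (the accompanying example 47_ampere_gemm_universal_streamk_broadcast is the authoritative usage). Note that a Stream-K-capable swizzle such as cutlass::gemm::threadblock::ThreadblockSwizzleStreamK would be passed where the default identity swizzle appears above.

    // Hypothetical instantiation (illustrative parameters only):
    using Gemm = cutlass::gemm::device::GemmUniversalStreamkWithBroadcast<
        cutlass::half_t, cutlass::layout::RowMajor,     // A
        cutlass::half_t, cutlass::layout::ColumnMajor,  // B
        cutlass::half_t, cutlass::layout::RowMajor,     // C / D
        float,                                          // accumulator
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,         // threadblock tile
        cutlass::gemm::GemmShape<64, 64, 32>,           // warp tile
        cutlass::gemm::GemmShape<16, 8, 16>,            // instruction tile
        cutlass::epilogue::thread::LinearCombinationBiasElementwise<
            cutlass::half_t, float, float,
            cutlass::half_t, cutlass::half_t, 8>,
        cutlass::gemm::threadblock::ThreadblockSwizzleStreamK>;

    Gemm gemm_op;
    // Arguments construction elided; the struct mirrors the universal GEMM's.
    // cutlass::Status status = gemm_op(args, workspace_ptr, stream);
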
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for column-major output: exchanges problem size and operands,
/// since a column-major D computed as A * B is exactly a row-major D^T computed as B^T * A^T.
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_,
  /// Operator class tag
  typename OperatorClass_,
  /// Tag indicating architecture to tune for. This is the minimum SM that
  /// supports the intended feature. The device kernel can be built
  /// targeting any SM larger than this number.
  typename ArchTag_,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_,
  /// Epilogue output operator
  typename EpilogueOutputOp_,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Access granularity of A matrix in units of elements
  int AlignmentA,
  /// Access granularity of B matrix in units of elements
  int AlignmentB,
  /// Operation performed by GEMM
  typename Operator_,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB>
class GemmUniversalStreamkWithBroadcast<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
                                        layout::ColumnMajor,  // partially specialized on LayoutC
                                        ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,
                                        WarpShape_, InstructionShape_, EpilogueOutputOp_,
                                        ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB,
                                        Operator_, TransformA, TransformB> {
public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  using ElementC = ElementC_;
  using LayoutC = layout::ColumnMajor;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static ComplexTransform const kTransformA = TransformA;
  static ComplexTransform const kTransformB = TransformB;

  using UnderlyingOperator = typename GemmUniversalStreamkWithBroadcast<
    ElementB,
    typename layout::LayoutTranspose<LayoutB>::type,
    ElementA,
    typename layout::LayoutTranspose<LayoutA>::type,
    ElementC,
    layout::RowMajor,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    kAlignmentB,
    kAlignmentA,
    Operator,
    kTransformB,
    kTransformA
  >::Base;

  using GemmKernel = typename UnderlyingOperator::GemmKernel;
  static int const kAlignmentC = EpilogueOutputOp::kCount;

  /// Argument structure
  using Arguments = typename UnderlyingOperator::Arguments;

private:

  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM.
  GemmUniversalStreamkWithBroadcast() { }

  /// Helper to construct a transposed equivalent for the underlying GEMM operator
  static Arguments to_underlying_arguments(Arguments const &args) {
    return args.transposed_problem();
  }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {
    return UnderlyingOperator::can_implement(to_underlying_arguments(args));
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {
    return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
  }

  /// Computes the grid shape
  static dim3 get_grid_shape(Arguments const &args) {
    return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int smem_capacity = -1) {
    return UnderlyingOperator::maximum_active_blocks(smem_capacity);
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
    return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {
    return underlying_operator_.update(to_underlying_arguments(args), workspace);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {
    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Initializes and then runs the kernel.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,146 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief
      Defines a Stream-K GEMM that can broadcast a bias vector in the epilogue.

      Similar in structure to DefaultGemmWithBroadcast, but uses its own epilogue
      (DefaultStreamkEpilogueWithBroadcastTensorOp) and its own GEMM kernel
      (GemmStreamkWithFusedEpilogue).
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"

#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA,
  /// Access granularity of A matrix in units of elements
  int kAlignmentA,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB,
  /// Access granularity of B matrix in units of elements
  int kAlignmentB,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Element type for internal accumulation
  typename ElementAccumulator,
  /// Operator class tag
  typename OperatorClass,
  /// Tag indicating architecture to tune for
  typename ArchTag,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape,
  /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
  typename EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Operation performed by GEMM
  typename Operator,
  /// SFINAE hook permitting partial specializations
  typename Enable = void
>
struct DefaultGemmStreamkWithBroadcast {

  using GemmBase = typename DefaultGemmUniversal<
    ElementA_, LayoutA_, TransformA, kAlignmentA,
    ElementB_, LayoutB_, TransformB, kAlignmentB,
    ElementC_, LayoutC_, ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    Operator
  >::GemmKernel;

  // Replace the epilogue
  using Epilogue = typename cutlass::epilogue::threadblock::DefaultStreamkEpilogueWithBroadcastTensorOp<
    typename GemmBase::Epilogue::Shape,
    typename GemmBase::Epilogue::WarpMmaOperator,
    GemmBase::Epilogue::kPartitionsK,
    ElementC_,
    typename EpilogueOutputOp::ElementT,
    typename EpilogueOutputOp::ElementVector,
    EpilogueOutputOp,
    GemmBase::Epilogue::kElementsPerAccess
  >::Epilogue;

  // Compose the GEMM kernel
  using GemmKernel = GemmStreamkWithFusedEpilogue<
    typename GemmBase::Mma,
    Epilogue,
    ThreadblockSwizzle
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
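
To illustrate how this "default" type composes into a launchable kernel, here is a hedged sketch; every concrete template argument below is an illustrative assumption (tile shapes, element types, Sm80, the Stream-K swizzle), not something this header prescribes:

    // Hypothetical composition (illustrative parameters only):
    using GemmKernel = cutlass::gemm::kernel::DefaultGemmStreamkWithBroadcast<
        cutlass::half_t, cutlass::layout::RowMajor,
        cutlass::ComplexTransform::kNone, 8,            // A
        cutlass::half_t, cutlass::layout::ColumnMajor,
        cutlass::ComplexTransform::kNone, 8,            // B
        cutlass::half_t, cutlass::layout::RowMajor,     // C / D
        float,                                          // accumulator
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,
        cutlass::gemm::GemmShape<64, 64, 32>,
        cutlass::gemm::GemmShape<16, 8, 16>,
        cutlass::epilogue::thread::LinearCombinationBiasElementwise<
            cutlass::half_t, float, float,
            cutlass::half_t, cutlass::half_t, 8>,
        cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
        4,                                              // stages
        cutlass::arch::OpMultiplyAdd>::GemmKernel;

    // The device-level GemmUniversalStreamkWithBroadcast above wraps this
    // same composition in GemmUniversalBase.
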
include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h (new file, 2405 lines)
File diff suppressed because it is too large