Stream-K with broadcast (#892)

* [WIP] GEMM StreamK w/ Fused Epilogue

* Adds the GEMM Stream-K with fused epilogue kernel-level struct.
  * Mostly based on GEMM with fused epilogue
  * Requires a new epilogue
  * Work in progress

* [WIP] StreamK support for GemmUniversalWithBroadcast

* Based on how Stream-K is enabled in GemmUniversal
  * Untested and a work in progress

* Minor fixes

* [WIP] It compiles!

It is almost certainly incorrect, but we're past getting the templates
to match, so checkpointing.

* Correction to reference kernel

* Fix typo

* Added MSE measurement

* Switch back to reference kernel + host for loop

Still WIP. We're now seeing an even larger MSE, but it appears on both
basic Split-K and Stream-K.

* Fix typos

* Fix broadcast vector + requested changes

* Comment typo

* Small int option and more

* Fix incorrect condition on source needed

* Requested changes

* I think I got it?

* Bias vector should be stride 0

* Two-source support added!

* Typos

* Merge examples

* Bring back vector row offset

Just to ensure consistency with universal gemm with fused epilogue

* Base arguments and params structs for StreamK

* StreamK epilogue with broadcast now inherits the original

* undo params_streamk_base.h

---------

Co-authored-by: Ali Hassani <ahassanijr@gmail.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Ali Hassani, 2023-05-22 16:05:06 -07:00, committed by GitHub
commit 13f413493a (parent 6fbc0d3380)
9 changed files with 4285 additions and 3 deletions

examples/47_ampere_gemm_universal_streamk/CMakeLists.txt

@@ -33,3 +33,7 @@ cutlass_example_add_executable(
 ampere_gemm_universal_streamk.cu
 )
+cutlass_example_add_executable(
+  47_ampere_gemm_universal_streamk_broadcast
+  ampere_gemm_universal_streamk_broadcast.cu
+)

examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu

@@ -495,7 +495,7 @@ int main(int argc, const char **argv)
 options.tensor_d.resize(options.problem_size.mn());     // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
 options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel
- // Fill matrix A on host with uniform-random data [2, -2]
+ // Fill matrix A on host with uniform-random data [-2, 2]
 cutlass::reference::host::TensorFillRandomUniform(
     options.tensor_a.host_view(),
     1,
@@ -503,7 +503,7 @@ int main(int argc, const char **argv)
     ElementA(-2),
     0);
- // Fill matrix B on host with uniform-random data [2, -2]
+ // Fill matrix B on host with uniform-random data [-2, 2]
 cutlass::reference::host::TensorFillRandomUniform(
     options.tensor_b.host_view(),
     1,
@@ -511,7 +511,7 @@ int main(int argc, const char **argv)
     ElementB(-2),
     0);
- // Fill matrix C on host with uniform-random data [2, -2]
+ // Fill matrix C on host with uniform-random data [-2, 2]
 cutlass::reference::host::TensorFillRandomUniform(
     options.tensor_c.host_view(),
     1,

examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu (new file)

@@ -0,0 +1,658 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/***************************************************************************************************
Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the
"classic data-parallel" and "Split-K" decompositions + residual add.
For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition
for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598)
Requires NVIDIA Ampere or newer device (SM80+).
- To lock persistence mode, power (400W), clocks (1005MHz) for evaluation (assumes device 0 and A100)
cutlass$ sudo nvidia-smi -pm 1 -i 0
cutlass$ sudo nvidia-smi -i 0 -pl 400
cutlass$ sudo nvidia-smi -i 0 -lgc 1005
- Build and run:
cutlass$ mkdir build
cutlass$ cd build
cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80
cutlass/build$ make 47_ampere_gemm_universal_streamk_broadcast
cutlass/build$ ./examples/47_ampere_gemm_universal_streamk/47_ampere_gemm_universal_streamk_broadcast
- Reset clocks when done:
cutlass$ sudo nvidia-smi -rgc
**************************************************************************************************/
#include <iostream>
#include <string>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/error_metrics.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_foreach.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations (cutlass_tensorop_h16816gemm_128x128_32x4_nn_align8)
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::half_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C1/C2/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrices in units of elements (up to 16 bytes)
// Output matrix configuration
using ElementOutput = cutlass::half_t; // Element type for output matrix operands
using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands
// constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value; // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes)
// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation
using ElementCompute = cutlass::half_t; // Element type for compute
using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape)
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape)
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop
// Residual block configuration
// Epilogue output operator
/// Using LinearCombinationResidualBlock
/// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2))
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock<
ElementOutput, // Element type for output matrix
ElementAccumulator, // Element type from internal accumulation
ElementCompute, // Element type for computation in the epilogue
ElementC, // Element type for C1/C2/D matrix operands
AlignmentC, // Memory access granularity of C and D matrix in units of elements
cutlass::epilogue::thread::Identity, // Activation
cutlass::plus, // Binary operation 1
cutlass::epilogue::thread::Identity, // Unary operation
cutlass::plus // Binary operation 2
>;
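// With the Identity activations and plus BinaryOps selected above, the residual
// block collapses to a plain two-source residual add, mirroring the host
// reference check performed later in main():
//
//   D = (alpha * (A @ B) + V) + beta * C1 + C2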
// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
ElementAccumulator>;
// Classic data-parallel device GEMM implementation type
using DeviceGemmBasic = cutlass::gemm::device::GemmUniversalWithBroadcast<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
NumStages,
AlignmentA,
AlignmentB>;
// StreamK device GEMM implementation type
using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalStreamkWithBroadcast<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
NumStages,
AlignmentA,
AlignmentB>;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true)
{}
};
/// Command line options parsing
struct Options
{
std::string command_name;
bool help;
cutlass::gemm::GemmCoord problem_size;
float alpha;
float beta;
int split_k_factor;
int avail_sms;
int iterations;
bool real;
cutlass::HostTensor<ElementA, LayoutA> tensor_a;
cutlass::HostTensor<ElementB, LayoutB> tensor_b;
cutlass::HostTensor<ElementC, LayoutC> tensor_c1;
cutlass::HostTensor<ElementC, LayoutC> tensor_c2;
cutlass::HostTensor<ElementC, LayoutC> tensor_d;
cutlass::HostTensor<ElementC, LayoutC> tensor_ref_d;
cutlass::HostTensor<ElementC, LayoutC> tensor_Vector;
// cutlass::HostTensor<ElementC, LayoutC> tensor_Tensor;
Options(std::string command_name) :
command_name(command_name),
help(false),
problem_size({2048, 2048, 2048}),
alpha(1.0f),
beta(1.0f),
split_k_factor(1),
avail_sms(-1), // Number of device SMs to use is unlimited
iterations(10000),
real(false)
{}
bool valid() const
{
return true;
}
void parse(int argc, char const **args)
{
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("split", split_k_factor);
cmd.get_cmd_line_argument("iterations", iterations);
real = cmd.check_cmd_line_flag("real");
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const
{
out
<< "Performs a GEMM computation.\n"
<< "\n"
<< "Options:\n"
<< "\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --split=<int> Split-K factor to emulate\n\n"
<< " --real If specified, initializes with real values instead of whole numbers. Errors are to be expected.\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options
typename DeviceGemmBasic::Arguments args_from_options(
const DeviceGemmBasic &device_gemm,
const Options &options,
cutlass::HostTensor<ElementA, LayoutA> &tensor_a,
cutlass::HostTensor<ElementB, LayoutB> &tensor_b,
cutlass::HostTensor<ElementC, LayoutC> &tensor_c1,
cutlass::HostTensor<ElementC, LayoutC> &tensor_c2,
cutlass::HostTensor<ElementC, LayoutC> &tensor_d,
cutlass::HostTensor<ElementC, LayoutC> &tensor_Vector /*,
cutlass::HostTensor<ElementC, LayoutC> &tensor_Tensor */
)
{
return typename DeviceGemmBasic::Arguments(
cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
options.problem_size, // problem_size
options.split_k_factor, // batch count / splitk slices
{ // epilogue parameters
ElementAccumulator(options.alpha),
ElementAccumulator(options.beta)
},
tensor_a.device_data(), // ptr_A
tensor_b.device_data(), // ptr_B
tensor_c1.device_data(), // ptr_C1
tensor_c2.device_data(), // ptr_C2
tensor_d.device_data(), // ptr_D
tensor_Vector.device_data(), // ptr_Vector
/* tensor_Tensor.device_data(), */ nullptr, // ptr_Tensor
options.problem_size.mk().product(), // batch_stride_A
options.problem_size.nk().product(), // batch_stride_B
options.problem_size.mn().product(), // batch_stride_C1
options.problem_size.mn().product(), // batch_stride_C2
options.problem_size.mn().product(), // batch_stride_D
options.problem_size.mn().product(), // batch_stride_Vector
options.problem_size.mn().product(), // batch_stride_Tensor
tensor_a.layout().stride(0), // stride_a
tensor_b.layout().stride(0), // stride_b
tensor_c1.layout().stride(0), // stride_c1
tensor_c2.layout().stride(0), // stride_c2
tensor_d.layout().stride(0), // stride_d
/*tensor_Vector.layout().stride(0)*/0, // stride_Vector
/*tensor_Tensor.layout().stride(0)*/0); // stride_Tensor
}
/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options
typename DeviceGemmStreamK::Arguments args_from_options(
const DeviceGemmStreamK &device_gemm,
const Options &options,
cutlass::HostTensor<ElementA, LayoutA> &tensor_a,
cutlass::HostTensor<ElementB, LayoutB> &tensor_b,
cutlass::HostTensor<ElementC, LayoutC> &tensor_c1,
cutlass::HostTensor<ElementC, LayoutC> &tensor_c2,
cutlass::HostTensor<ElementC, LayoutC> &tensor_d,
cutlass::HostTensor<ElementC, LayoutC> &tensor_Vector/*,
cutlass::HostTensor<ElementC, LayoutC> &tensor_Tensor*/
)
{
return typename DeviceGemmStreamK::Arguments(
cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
options.problem_size, // problem_size
options.split_k_factor, // batch count / splitk slices
{ // epilogue parameters
ElementAccumulator(options.alpha),
ElementAccumulator(options.beta)
},
tensor_a.device_data(), // ptr_A
tensor_b.device_data(), // ptr_B
tensor_c1.device_data(), // ptr_C1
tensor_c2.device_data(), // ptr_C2
tensor_d.device_data(), // ptr_D
tensor_Vector.device_data(), // ptr_Vector
/* tensor_Tensor.device_data(), */ nullptr, // ptr_Tensor (we're not storing Tensor)
options.problem_size.mk().product(), // batch_stride_A
options.problem_size.nk().product(), // batch_stride_B
options.problem_size.mn().product(), // batch_stride_C1
options.problem_size.mn().product(), // batch_stride_C2
options.problem_size.mn().product(), // batch_stride_D
options.problem_size.mn().product(), // batch_stride_Vector
options.problem_size.mn().product(), // batch_stride_Tensor
tensor_a.layout().stride(0), // stride_a
tensor_b.layout().stride(0), // stride_b
tensor_c1.layout().stride(0), // stride_c1
tensor_c2.layout().stride(0), // stride_c2
tensor_d.layout().stride(0), // stride_d
/*tensor_Vector.layout().stride(0)*/0, // stride_Vector // Vector stride is always 0
/*tensor_Tensor.layout().stride(0)*/0, // stride_Tensor // We're not storing Tensor
options.avail_sms); // avail_sms
}
/// Execute a given example GEMM computation
template <typename DeviceGemmT>
Result run(std::string description, Options &options)
{
// Display test description
std::cout << std::endl << description << std::endl;
// Zero-initialize test output matrix D
cutlass::reference::host::TensorFill(options.tensor_d.host_view());
options.tensor_d.sync_device();
// Instantiate CUTLASS kernel depending on templates
DeviceGemmT device_gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT
auto arguments = args_from_options(device_gemm, options,
options.tensor_a, options.tensor_b, options.tensor_c1, options.tensor_c2, options.tensor_d,
options.tensor_Vector/*, options.tensor_Tensor*/);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = DeviceGemmT::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check the problem size is supported or not
CUTLASS_CHECK(device_gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(device_gemm());
// Copy output data from CUTLASS and reference kernel to host for comparison
options.tensor_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = cutlass::reference::host::TensorEquals(
options.tensor_d.host_view(),
options.tensor_ref_d.host_view());
double err = cutlass::reference::host::TensorRelativeErrorMetric(
options.tensor_d.host_view(),
options.tensor_ref_d.host_view());
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << " \t Relative error: " << err << std::endl;
// Run profiling loop
if (options.iterations > 0)
{
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(device_gemm());
}
timer.stop();
// Compute average runtime and GFLOPs.
float elapsed_ms = timer.elapsed_millis();
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPs: " << result.gflops << std::endl;
}
// TODO: uncomment when results match
//if (!result.passed) {
// exit(-1);
//}
return result;
}
/// Program entrypoint
int main(int argc, const char **argv)
{
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
// Current device must have compute capability at least 80
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!((props.major * 10 + props.minor) >= 80))
{
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
<< std::endl;
// Returning zero so this test passes on devices with insufficient compute capability. Its actions are a no-op.
return 0;
}
// Parse commandline options
Options options("ampere_streamk_broadcast_gemm");
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
std::cout <<
options.iterations << " timing iterations of " <<
options.problem_size.m() << " x " <<
options.problem_size.n() << " x " <<
options.problem_size.k() << " matrix-matrix multiply" << std::endl;
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
//
// Initialize GEMM datasets
//
// Initialize tensors using CUTLASS helper functions
options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K
options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N
options.tensor_c1.resize(options.problem_size.mn()); // <- Create matrix C1 with dimensions M x N
options.tensor_c2.resize(options.problem_size.mn()); // <- Create matrix C2 with dimensions M x N
options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel
options.tensor_Vector.resize({1, options.problem_size.n()}); // <- Create broadcast vector with dimensions 1 x N
// options.tensor_Tensor.resize(options.problem_size.mn()); // <- Create T matrix with dimensions M x N
int _init_bits = options.real ? -1 : 0;
// Fill matrix A on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_a.host_view(),
1,
ElementA(2),
ElementA(-2), _init_bits);
// Fill matrix B on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_b.host_view(),
1,
ElementB(2),
ElementB(-2), _init_bits);
// Fill matrix C1 on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_c1.host_view(),
1,
ElementC(2),
ElementC(-2), _init_bits);
// Fill matrix C2 on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_c2.host_view(),
1,
ElementC(2),
ElementC(-2), _init_bits);
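// Fill vector V on host with uniform-random data [-2, 2]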
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_Vector.host_view(),
1,
ElementC(2),
ElementC(-2), _init_bits);
//
// Compute reference output
//
// Copy data from host to GPU
options.tensor_a.sync_device();
options.tensor_b.sync_device();
options.tensor_c1.sync_device();
options.tensor_c2.sync_device();
options.tensor_Vector.sync_device();
// options.tensor_Tensor.sync_device();
// Zero-initialize reference output matrix D
cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view());
options.tensor_ref_d.sync_device();
// Create instantiation for device reference gemm kernel
DeviceGemmReference gemm_reference;
// Launch device reference gemm kernel
gemm_reference(
options.problem_size,
ElementAccumulator(options.alpha),
options.tensor_a.device_ref(),
options.tensor_b.device_ref(),
ElementAccumulator(options.beta),
options.tensor_c1.device_ref(),
options.tensor_ref_d.device_ref());
// Wait for kernels to finish
CUDA_CHECK(cudaDeviceSynchronize());
// Copy output data from reference kernel to host for comparison
options.tensor_ref_d.sync_host();
// Add broadcast vector (without multiplier)
//
// This is only possible because both BinaryOps are addition and the UnaryOps
// are identity, which makes the addition of the broadcast vector (and of the
// second residual) commute with the epilogue:
//
//   identity(plus(plus(identity(alpha * (a * b) + v), beta * c1), c2)) ==
//   alpha * a * b + v + beta * c1 + c2 ==
//   (alpha * a * b + beta * c1) + v + c2 ==
//   GEMM(a, b, c1) + v + c2
//
// Vector broadcast on host
for (int i=0; i < options.problem_size.m(); ++i) {
for (int j=0; j < options.problem_size.n(); ++j) {
options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_Vector.host_view().ref().at({0, j});
options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_c2.host_view().ref().at({i, j});
}
}
// Sync back with device just in case
options.tensor_ref_d.sync_device();
//
// Evaluate CUTLASS kernels
//
// Test default operation
if (options.split_k_factor == 1)
{
// Compare basic data-parallel version versus StreamK version using default load-balancing heuristics
Result basic_dp = run<DeviceGemmBasic>("Basic data-parallel GEMM", options);
Result streamk_default = run<DeviceGemmStreamK>("StreamK GEMM with default load-balancing", options);
printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms));
// Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1
options.avail_sms = 1; // Set load-balancing width to 1 SM (no load balancing)
Result streamk_dp = run<DeviceGemmStreamK>("StreamK emulating basic data-parallel GEMM", options);
options.avail_sms = -1; // Reset load-balancing width to unspecified SMs (i.e., the number of device SMs)
printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms));
options.split_k_factor++; // Increment splitting factor for next evaluation
}
// Show that StreamK can emulate "Split-K" with a tile-splitting factor
Result basic_splitk = run<DeviceGemmBasic>(
std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
options);
Result streamk_splitk = run<DeviceGemmStreamK>(
std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
options);
printf(" Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms));
return 0;
}

include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h

@@ -48,6 +48,7 @@
 #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
 #include "cutlass/epilogue/threadblock/epilogue.h"
 #include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h"
+#include "cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h"
 #include "cutlass/layout/permute.h"

@@ -120,6 +121,67 @@ struct DefaultEpilogueWithBroadcastTensorOp {
////////////////////////////////////////////////////////////////////////////////

/// Defines sensible defaults for Stream-K epilogues for TensorOps.
template <
typename Shape,
typename WarpMmaTensorOp,
int PartitionsK,
typename ElementOutput,
typename ElementTensor,
typename ElementVector,
typename OutputOp,
int ElementsPerAccess,
bool ScatterD = false,
typename PermuteDLayout = layout::NoPermute
>
struct DefaultStreamkEpilogueWithBroadcastTensorOp {
/// Use defaults related to the existing epilogue
using Base = DefaultEpilogueTensorOp<
Shape,
WarpMmaTensorOp,
PartitionsK,
OutputOp,
ElementsPerAccess
>;
//
// Stores the result z = (y = GEMM(A, B, C), broadcast)
//
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
typename Base::OutputTileThreadMap,
ElementOutput,
ScatterD,
PermuteDLayout
>;
//
// Additional tensor tile iterator - stores t = Elementwise(z)
//
using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
typename Base::OutputTileThreadMap,
ElementTensor
>;
/// Define the epilogue
using Epilogue = EpilogueStreamkWithBroadcast<
Shape,
WarpMmaTensorOp,
PartitionsK,
OutputTileIterator,
TensorTileIterator,
ElementVector,
typename Base::AccumulatorFragmentIterator,
typename Base::WarpTileIterator,
typename Base::SharedLoadIterator,
OutputOp,
typename Base::Padding,
Base::kFragmentsPerIteration
>;
};
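// A minimal instantiation sketch (the names below are illustrative, not part of
// this header): the trait is consumed the same way as the non-Stream-K
// DefaultEpilogueWithBroadcastTensorOp, by pulling out its nested Epilogue type:
//
//   using StreamkEpilogue = typename cutlass::epilogue::threadblock::
//       DefaultStreamkEpilogueWithBroadcastTensorOp<
//           ThreadblockShape, WarpMmaTensorOp, /*PartitionsK=*/1,
//           ElementOutput, ElementTensor, ElementVector,
//           OutputOp, /*ElementsPerAccess=*/8>::Epilogue;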
////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues for VoltaTensorOps.
template <
typename Shape,

include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h (new file)

@@ -0,0 +1,443 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#include <cuda/std/utility>
#else
#include <assert.h>
#include <utility>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/functional.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// This base class is meant to define the concept required of the
/// EpilogueStreamkWithBroadcast::OutputOp
template <
typename ElementC_,
typename ElementAccumulator_,
typename ElementCompute_,
typename ElementZ_,
typename ElementT_,
int ElementsPerAccess,
bool StoreZ = true,
bool StoreT = true
>
struct EpilogueStreamkWithBroadcastOpBase : EpilogueWithBroadcastOpBase<
ElementC_,
ElementAccumulator_,
ElementCompute_,
ElementZ_,
ElementT_,
ElementsPerAccess,
StoreZ,
StoreT
>
{
/// Parameters structure - required
struct Params { };
//
// Methods
//
/// Constructor from Params
EpilogueStreamkWithBroadcastOpBase(Params const &params_) { }
};
////////////////////////////////////////////////////////////////////////////////
/// Epilogue operator with bias vector broadcast over columns.
///
/// Computes the following:
///
///
/// Z, T = OutputOp(AB, C, Broadcast)
///
/// if (OutputOp::kStoreZ) {
///   store(Z);
/// }
///
/// if (OutputOp::kStoreT) {
///   store(T);
/// }
///
template <
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
int PartitionsK, ///< Number of partitions of the K dimension
typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z)
typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t)
typename ElementVector_, ///< Pointer to broadcast vector
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM
typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp
typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape)
int FragmentsPerPartition = 1, ///< Used to coarsen the epilogue granularity
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large
(!IsEpilogueFunctorHeavy<OutputOp_>::value),
bool IsSingleSource = OutputOp_::kIsSingleSource
>
class EpilogueStreamkWithBroadcast;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EpilogueStreamkWithBroadcast: Two sources
template <
typename Shape_,
typename WarpMmaOperator_,
int PartitionsK,
typename OutputTileIterator_,
typename TensorTileIterator_,
typename ElementVector_,
typename AccumulatorFragmentIterator_,
typename WarpTileIterator_,
typename SharedLoadIterator_,
typename OutputOp_,
typename Padding_,
int FragmentsPerPartition,
int IterationsUnroll
>
class EpilogueStreamkWithBroadcast<
Shape_,
WarpMmaOperator_,
PartitionsK,
OutputTileIterator_,
TensorTileIterator_,
ElementVector_,
AccumulatorFragmentIterator_,
WarpTileIterator_,
SharedLoadIterator_,
OutputOp_,
Padding_,
FragmentsPerPartition,
IterationsUnroll,
false
> :
public EpilogueWithBroadcast<
Shape_,
WarpMmaOperator_,
PartitionsK,
OutputTileIterator_,
TensorTileIterator_,
ElementVector_,
AccumulatorFragmentIterator_,
WarpTileIterator_,
SharedLoadIterator_,
OutputOp_,
Padding_,
FragmentsPerPartition,
IterationsUnroll,
false>,
public EpilogueBaseStreamK<
Shape_,
PartitionsK,
WarpMmaOperator_,
AccumulatorFragmentIterator_>
{
public:
using Base = EpilogueWithBroadcast<
Shape_,
WarpMmaOperator_,
PartitionsK,
OutputTileIterator_,
TensorTileIterator_,
ElementVector_,
AccumulatorFragmentIterator_,
WarpTileIterator_,
SharedLoadIterator_,
OutputOp_,
Padding_,
FragmentsPerPartition,
IterationsUnroll,
false>;
using BaseStreamK = EpilogueBaseStreamK<
Shape_,
PartitionsK,
WarpMmaOperator_,
AccumulatorFragmentIterator_>;
using Shape = Shape_;
static int const kPartitionsK = PartitionsK;
using OutputTileIterator = OutputTileIterator_;
using TensorTileIterator = TensorTileIterator_;
using ElementVector = ElementVector_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
/// Fragment type used by the accumulator tile's fragment iterator
using AccumulatorFragment = typename Base::AccumulatorFragmentIterator::Fragment;
/// Shared storage structure (shadows base) with additional SMEM buffer for reduction
using SharedStorage = typename Base::SharedStorage;
public:
/// Constructor
CUTLASS_DEVICE
EpilogueStreamkWithBroadcast(
SharedStorage &shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< ID of thread within warp
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
BaseStreamK(thread_idx)
{ }
/// Aggregates the accumulator sets shared by peer blocks in the global workspace,
/// performing epilogue computations, writing to output
CUTLASS_DEVICE
void reduce(
int peer_idx_begin,
int peer_idx_end,
int reduce_fragment_idx,
void *element_workspace,
OutputOp const &output_op, ///< Output operator
ElementVector const * broadcast_ptr, ///< Broadcast vector
OutputTileIterator destination_iterator, ///< Tile iterator for destination
OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix
OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix
TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
MatrixCoord(Shape::kM, Shape::kN),
MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
MatrixCoord())
{
// Reduce peer accumulator fragments into one fragment
AccumulatorFragment accum_fragment;
BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);
// Store fragment to shared memory
this->warp_tile_iterator_.store(accum_fragment);
__syncthreads();
Base::reduce(reduce_fragment_idx, output_op, broadcast_ptr, destination_iterator, source_iterator1, source_iterator2, tensor_iterator, problem_size, threadblock_offset);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EpilogueStreamkWithBroadcast: Single source
template <
typename Shape_,
typename WarpMmaOperator_,
int PartitionsK,
typename OutputTileIterator_,
typename TensorTileIterator_,
typename ElementVector_,
typename AccumulatorFragmentIterator_,
typename WarpTileIterator_,
typename SharedLoadIterator_,
typename OutputOp_,
typename Padding_,
int FragmentsPerPartition,
int IterationsUnroll
>
class EpilogueStreamkWithBroadcast<
Shape_,
WarpMmaOperator_,
PartitionsK,
OutputTileIterator_,
TensorTileIterator_,
ElementVector_,
AccumulatorFragmentIterator_,
WarpTileIterator_,
SharedLoadIterator_,
OutputOp_,
Padding_,
FragmentsPerPartition,
IterationsUnroll,
true
> :
public EpilogueWithBroadcast<
Shape_,
WarpMmaOperator_,
PartitionsK,
OutputTileIterator_,
TensorTileIterator_,
ElementVector_,
AccumulatorFragmentIterator_,
WarpTileIterator_,
SharedLoadIterator_,
OutputOp_,
Padding_,
FragmentsPerPartition,
IterationsUnroll,
true>,
public EpilogueBaseStreamK<
Shape_,
PartitionsK,
WarpMmaOperator_,
AccumulatorFragmentIterator_>
{
public:
using Base = EpilogueWithBroadcast<
Shape_,
WarpMmaOperator_,
PartitionsK,
OutputTileIterator_,
TensorTileIterator_,
ElementVector_,
AccumulatorFragmentIterator_,
WarpTileIterator_,
SharedLoadIterator_,
OutputOp_,
Padding_,
FragmentsPerPartition,
IterationsUnroll,
true>;
using BaseStreamK = EpilogueBaseStreamK<
Shape_,
PartitionsK,
WarpMmaOperator_,
AccumulatorFragmentIterator_>;
using Shape = Shape_;
static int const kPartitionsK = PartitionsK;
using OutputTileIterator = OutputTileIterator_;
using TensorTileIterator = TensorTileIterator_;
using ElementVector = ElementVector_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
/// Fragment type used by the accumulator tile's fragment iterator
using AccumulatorFragment = typename Base::AccumulatorFragmentIterator::Fragment;
/// Shared storage structure (shadows base) with additional SMEM buffer for reduction
using SharedStorage = typename Base::SharedStorage;
public:
/// Constructor
CUTLASS_DEVICE
EpilogueStreamkWithBroadcast(
SharedStorage &shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< ID of thread within warp
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
BaseStreamK(thread_idx)
{ }
/// Aggregates the accumulator sets shared by peer blocks in the global workspace,
/// performing epilogue computations, writing to output
CUTLASS_DEVICE
void reduce(
int peer_idx_begin,
int peer_idx_end,
int reduce_fragment_idx,
void *element_workspace,
OutputOp const &output_op, ///< Output operator
ElementVector const * broadcast_ptr, ///< Broadcast vector
OutputTileIterator destination_iterator, ///< Tile iterator for destination
OutputTileIterator source_iterator, ///< Tile iterator for the source accumulator matrix
TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
MatrixCoord(Shape::kM, Shape::kN),
MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
MatrixCoord())
{
// Reduce peer accumulator fragments into one fragment
AccumulatorFragment accum_fragment;
BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace);
// Store fragment to shared memory
this->warp_tile_iterator_.store(accum_fragment);
__syncthreads();
Base::reduce(reduce_fragment_idx, output_op, broadcast_ptr, destination_iterator, source_iterator, tensor_iterator, problem_size, threadblock_offset);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h

@@ -863,6 +863,98 @@ private:
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
}
}
public:
/// Stream-K reduce helper
CUTLASS_DEVICE
void reduce(
int reduce_fragment_idx, ///< Reduce fragment index
OutputOp const &output_op, ///< Output operator
ElementVector const * broadcast_ptr, ///< Broadcast vector
OutputTileIterator destination_iterator, ///< Tile iterator for destination
OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix
OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix
TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
MatrixCoord(Shape::kM, Shape::kN),
MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
MatrixCoord())
{
BroadcastFragment broadcast_fragment;
load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
// Initialize/load source-fragment data
typename OutputTileIterator::Fragment source_fragment1;
source_fragment1.clear();
typename OutputTileIterator::Fragment source_fragment2;
source_fragment2.clear();
if (output_op.is_source_needed())
{
source_iterator1 += reduce_fragment_idx;
source_iterator1.load(source_fragment1);
source_iterator2 += reduce_fragment_idx;
source_iterator2.load(source_fragment2);
}
// Load fragment from shared memory
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
// Add fragments shared by other k partitions
if (kPartitionsK > 1)
{
plus<typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
}
//
// Apply output operation
//
typename OutputTileIterator::Fragment frag_Z;
typename TensorTileIterator::Fragment frag_T;
if (!output_op.is_source_needed()) {
apply_output_operator_source_not_needed_(
frag_Z,
frag_T,
output_op,
aligned_accum_fragment[0],
broadcast_fragment);
} else {
apply_output_operator_(
frag_Z,
frag_T,
output_op,
aligned_accum_fragment[0],
source_fragment1,
source_fragment2,
broadcast_fragment);
}
//
// Conditionally store fragments
//
if (OutputOp::kStoreZ) {
destination_iterator.store(frag_Z);
++destination_iterator;
}
if (OutputOp::kStoreT) {
tensor_iterator.store(frag_T);
++tensor_iterator;
}
}
};

@@ -1529,6 +1621,92 @@ private:
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]);
}
}
public:
/// Stream-K reduce helper
CUTLASS_DEVICE
void reduce(
int reduce_fragment_idx, ///< Reduce fragment index
OutputOp const &output_op, ///< Output operator
ElementVector const * broadcast_ptr, ///< Broadcast vector
OutputTileIterator destination_iterator, ///< Tile iterator for destination
OutputTileIterator source_iterator, ///< Tile iterator for the source accumulator matrix
TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand
MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses
MatrixCoord(Shape::kM, Shape::kN),
MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space
MatrixCoord())
{
BroadcastFragment broadcast_fragment;
load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset);
// Initialize/load source-fragment data
typename OutputTileIterator::Fragment source_fragment;
source_fragment.clear();
if (output_op.is_source_needed())
{
source_iterator += reduce_fragment_idx;
source_iterator.load(source_fragment);
}
// Load fragment from shared memory
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
// Add fragments shared by other k partitions
if (kPartitionsK > 1)
{
plus<typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
}
//
// Apply output operation
//
typename OutputTileIterator::Fragment frag_Z;
typename TensorTileIterator::Fragment frag_T;
if (!output_op.is_source_needed()) {
apply_output_operator_source_not_needed_(
frag_Z,
frag_T,
output_op,
aligned_accum_fragment[0],
broadcast_fragment);
} else {
apply_output_operator_(
frag_Z,
frag_T,
output_op,
aligned_accum_fragment[0],
source_fragment,
broadcast_fragment);
}
//
// Conditionally store fragments
//
if (OutputOp::kStoreZ) {
destination_iterator.store(frag_Z);
++destination_iterator;
}
if (OutputOp::kStoreT) {
tensor_iterator.store(frag_T);
++tensor_iterator;
}
}
};
////////////////////////////////////////////////////////////////////////////////

include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h (new file)

@@ -0,0 +1,386 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a Stream-K GEMM kernel that can broadcast bias vector in the
epilogue.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/gemm/kernel/default_gemm_streamk_with_broadcast.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/device/gemm_universal_base.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/*!
  The universal GEMM with Stream-K decomposition and a broadcast epilogue.
*/
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_ = ElementC_,
/// Operator class tag
typename OperatorClass_ = arch::OpClassSimt,
/// Tag indicating architecture to tune for. This is the minimum SM that
/// supports the intended feature. The device kernel can be built
/// targeting any SM larger than this number.
typename ArchTag_ = arch::Sm70,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::InstructionShape,
/// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
ElementC_, ElementAccumulator_, ElementAccumulator_,
ElementC_, ElementC_, 128 / cutlass::sizeof_bits<ElementC_>::value>,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA = ComplexTransform::kNone,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB = ComplexTransform::kNone
>
class GemmUniversalStreamkWithBroadcast :
public GemmUniversalBase<
typename kernel::DefaultGemmStreamkWithBroadcast<
ElementA_,
LayoutA_,
TransformA,
AlignmentA,
ElementB_,
LayoutB_,
TransformB,
AlignmentB,
ElementC_,
LayoutC_,
ElementAccumulator_,
OperatorClass_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ThreadblockSwizzle_,
Stages,
Operator_
>::GemmKernel
> {
public:
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp::kCount;
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using Base = GemmUniversalBase<
typename kernel::DefaultGemmStreamkWithBroadcast<
ElementA_,
LayoutA_,
TransformA,
AlignmentA,
ElementB_,
LayoutB_,
TransformB,
AlignmentB,
ElementC_,
LayoutC_,
ElementAccumulator_,
OperatorClass_,
ArchTag_,
ThreadblockShape_,
WarpShape_,
InstructionShape_,
EpilogueOutputOp_,
ThreadblockSwizzle_,
Stages,
Operator_
>::GemmKernel
>;
using Arguments = typename Base::Arguments;
using GemmKernel = typename Base::GemmKernel;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for column-major output exchanges problem size and operand.
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for. This is the minimum SM that
/// supports the intended feature. The device kernel can be built
/// targeting any SM larger than this number.
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Epilogue output operator
typename EpilogueOutputOp_,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Access granularity of A matrix in units of elements
int AlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB,
/// Operation performed by GEMM
typename Operator_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB>
class GemmUniversalStreamkWithBroadcast<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
layout::ColumnMajor, // partially specialized on LayoutC
ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,
WarpShape_, InstructionShape_, EpilogueOutputOp_,
ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB,
Operator_, TransformA, TransformB> {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = layout::ColumnMajor;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp = EpilogueOutputOp_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static ComplexTransform const kTransformA = TransformA;
static ComplexTransform const kTransformB = TransformB;
using UnderlyingOperator = typename GemmUniversalStreamkWithBroadcast<
ElementB,
typename layout::LayoutTranspose<LayoutB>::type,
ElementA,
typename layout::LayoutTranspose<LayoutA>::type,
ElementC,
layout::RowMajor,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
kAlignmentB,
kAlignmentA,
Operator,
kTransformB,
kTransformA
>::Base;
using GemmKernel = typename UnderlyingOperator::GemmKernel;
static int const kAlignmentC = EpilogueOutputOp::kCount;
/// Argument structure
using Arguments = typename UnderlyingOperator::Arguments;
private:
UnderlyingOperator underlying_operator_;
public:
/// Constructs the GEMM.
GemmUniversalStreamkWithBroadcast() { }
  /// Helper to construct a transposed equivalent for the underlying GEMM operator
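  /// (Relies on the identity C = A * B  <=>  C^T = B^T * A^T: a problem with
  /// column-major C is mapped onto the row-major-C primary template by
  /// swapping the operands and transposing every layout.)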
static Arguments to_underlying_arguments(Arguments const &args) {
return args.transposed_problem();
}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
return UnderlyingOperator::can_implement(to_underlying_arguments(args));
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const &args) {
return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args));
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
return UnderlyingOperator::maximum_active_blocks(smem_capacity);
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
return underlying_operator_.update(to_underlying_arguments(args), workspace);
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
return underlying_operator_.run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
  /// Initializes the GEMM state from arguments, then runs the kernel.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
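For orientation, here is a minimal host-side sketch of instantiating and running the device-level operator defined above. It is illustrative only and not part of this commit: the header path, the SM80 tensor-op configuration, the tile shapes, the `EpilogueOp` functor (and the exact order of its template arguments), and the `run_gemm` helper are all assumptions chosen for the example.

#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
#include "cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h"  // assumed path

// One possible epilogue functor satisfying the 'EpilogueWithBroadcastOp'
// concept (argument order sketched here; verify against the header):
// Z = alpha * accumulator + beta * C + broadcast vector, T = side output.
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
    cutlass::half_t,  // ElementC
    float,            // ElementAccumulator
    float,            // ElementCompute
    cutlass::half_t,  // ElementZ
    cutlass::half_t,  // ElementT
    8>;               // elements per vectorized access

using Gemm = cutlass::gemm::device::GemmUniversalStreamkWithBroadcast<
    cutlass::half_t, cutlass::layout::RowMajor,     // A
    cutlass::half_t, cutlass::layout::ColumnMajor,  // B
    cutlass::half_t, cutlass::layout::RowMajor,     // C / D
    float,                                          // internal accumulation
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,         // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,           // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,            // instruction tile
    EpilogueOp,
    cutlass::gemm::threadblock::ThreadblockSwizzleStreamK>;

// Hypothetical helper: `workspace` must point to at least
// Gemm::get_workspace_size(args) bytes of device memory, since Stream-K
// reduces partial tiles through workspace.
cutlass::Status run_gemm(Gemm::Arguments const &args, void *workspace,
                         cudaStream_t stream = nullptr) {
  Gemm gemm_op;
  cutlass::Status status = Gemm::can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  status = gemm_op.initialize(args, workspace, stream);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  return gemm_op.run(stream);
}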

@@ -0,0 +1,146 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Defines a Stream-K GEMM that can broadcast a bias vector in the epilogue.
    Similar in structure to DefaultGemmWithBroadcast, but uses its own epilogue
(DefaultStreamkEpilogueWithBroadcastTensorOp) and its own GEMM kernel
(GemmStreamkWithFusedEpilogue).
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h"
#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp'
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
    /// SFINAE hook for conditional partial specializations
typename Enable = void
>
struct DefaultGemmStreamkWithBroadcast {
using GemmBase = typename DefaultGemmUniversal<
ElementA_, LayoutA_, TransformA, kAlignmentA,
ElementB_, LayoutB_, TransformB, kAlignmentB,
ElementC_, LayoutC_, ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
Stages,
Operator
>::GemmKernel;
  // Replace the epilogue from DefaultGemmUniversal with the Stream-K epilogue with broadcast
using Epilogue = typename cutlass::epilogue::threadblock::DefaultStreamkEpilogueWithBroadcastTensorOp<
typename GemmBase::Epilogue::Shape,
typename GemmBase::Epilogue::WarpMmaOperator,
GemmBase::Epilogue::kPartitionsK,
ElementC_,
typename EpilogueOutputOp::ElementT,
typename EpilogueOutputOp::ElementVector,
EpilogueOutputOp,
GemmBase::Epilogue::kElementsPerAccess
>::Epilogue;
// Compose the GEMM kernel
using GemmKernel = GemmStreamkWithFusedEpilogue<
typename GemmBase::Mma,
Epilogue,
ThreadblockSwizzle
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
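And a brief sketch, under the same placeholder assumptions as the device-level example earlier (including the hypothetical `EpilogueOp`), of how this trait composes: the nested `GemmKernel` is the `GemmStreamkWithFusedEpilogue` instantiation that the device layer passes to `GemmUniversalBase`.

using StreamkTraits = cutlass::gemm::kernel::DefaultGemmStreamkWithBroadcast<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::ComplexTransform::kNone, 8,            // A: transform, alignment
    cutlass::half_t, cutlass::layout::ColumnMajor,
    cutlass::ComplexTransform::kNone, 8,            // B: transform, alignment
    cutlass::half_t, cutlass::layout::RowMajor,     // C / D
    float,                                          // internal accumulation
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    EpilogueOp,                                     // hypothetical functor from the earlier sketch
    cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
    4,                                              // mainloop stages (placeholder)
    cutlass::arch::OpMultiplyAdd>;

// Mainloop comes from DefaultGemmUniversal; the epilogue has been swapped
// for the Stream-K broadcast variant.
using StreamkKernel = StreamkTraits::GemmKernel;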

File diff suppressed because it is too large