diff --git a/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt b/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt index 910f4844..783cbf84 100644 --- a/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt +++ b/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt @@ -33,3 +33,7 @@ cutlass_example_add_executable( ampere_gemm_universal_streamk.cu ) +cutlass_example_add_executable( + 47_ampere_gemm_universal_streamk_broadcast + ampere_gemm_universal_streamk_broadcast.cu + ) diff --git a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu index bb995f59..f99cbd39 100644 --- a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu +++ b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu @@ -495,7 +495,7 @@ int main(int argc, const char **argv) options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel - // Fill matrix A on host with uniform-random data [2, -2] + // Fill matrix A on host with uniform-random data [-2, 2] cutlass::reference::host::TensorFillRandomUniform( options.tensor_a.host_view(), 1, @@ -503,7 +503,7 @@ int main(int argc, const char **argv) ElementA(-2), 0); - // Fill matrix B on host with uniform-random data [2, -2] + // Fill matrix B on host with uniform-random data [-2, 2] cutlass::reference::host::TensorFillRandomUniform( options.tensor_b.host_view(), 1, @@ -511,7 +511,7 @@ int main(int argc, const char **argv) ElementB(-2), 0); - // Fill matrix C on host with uniform-random data [2, -2] + // Fill matrix C on host with uniform-random data [-2, 2] cutlass::reference::host::TensorFillRandomUniform( options.tensor_c.host_view(), 1, diff --git a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu new file mode 100644 index 00000000..db2eff51 --- /dev/null +++ b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu @@ -0,0 +1,658 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*************************************************************************************************** + Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the + "classic data-parallel" and "Split-K" decompositions + residual add. + + For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition + for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598) + + Requires NVIDIA Ampere or newer device (SM80+). + + - To lock persistence mode, power (400W), clocks (1005MHz) for evaluation (assumes device 0 and A100) + + cutlass$ sudo nvidia-smi -pm 1 -i 0 + + cutlass$ sudo nvidia-smi -i 0 -pl 400 + + cutlass$ sudo nvidia-smi -i 0 -lgc 1005 + + - Build and run: + + cutlass$ mkdir build + + cutlass$ cd build + + cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80 + + cutlass/build$ make 47_ampere_gemm_universal_streamk_broadcast + + cutlass/build$ ./examples/47_ampere_gemm_universal_streamk/47_ampere_gemm_universal_streamk_broadcast + + - Reset clocks when done: + + cutlass$ sudo nvidia-smi -rgc + + **************************************************************************************************/ + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_with_broadcast.h" +#include "cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h" +#include "cutlass/epilogue/thread/linear_combination_residual_block.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_foreach.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "helper.h" + + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations (cutlass_tensorop_h16816gemm_128x128_32x4_nn_align8) +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in 
units of elements (up to 16 bytes) + +// C1/C2/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrices in units of elements (up to 16 bytes) + +// Output matrix configuration +using ElementOutput = cutlass::half_t; // Element type for output matrix operands +using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands +// constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes) + +// Multiply-accumulate blocking/pipelining details +using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation +using ElementCompute = cutlass::half_t; // Element type for compute +using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape) +using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape) +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape) +constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop + +// Residual block configuration + +// Epilogue output operator +/// Using LinearCombinationResidualBlock +/// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2)) +using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< + ElementOutput, // Element type for output matrix + ElementAccumulator, // Element type from internal accumulation + ElementCompute, // Element type from internal accumulation + ElementC, // Element type for C1/C2/D matrix operands + AlignmentC, // Memory access granularity of C and D matrix in units of elements + cutlass::epilogue::thread::Identity, // Activation + cutlass::plus, // Binary operation 1 + cutlass::epilogue::thread::Identity, // Unary operation + cutlass::plus // Binary operation 2 + >; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +// Classic data-parallel device GEMM implementation type +using DeviceGemmBasic = cutlass::gemm::device::GemmUniversalWithBroadcast< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + NumStages, + AlignmentA, + AlignmentB>; + +// StreamK device GEMM implementation type +using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalStreamkWithBroadcast< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + NumStages, + AlignmentA, + AlignmentB>; + + 
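// With the Identity activation/unary ops and plus binary ops configured above, the fused
// epilogue amounts to the scaled GEMM result plus both residuals plus the broadcast vector,
// which is the same quantity the reference check later in this example computes. Below is a
// minimal host-side sketch of that per-element value (plain float stand-ins for half_t,
// row-major layouts; the function name and pointer/stride parameters are illustrative only):
float reference_element(
    int i, int j, int K,                 // output coordinate and inner dimension
    float alpha, float beta,             // epilogue scalars
    float const *A, int lda,             // A (M x K), row-major
    float const *B, int ldb,             // B (K x N), row-major
    float const *C1, float const *C2,    // residuals C1, C2 (M x N), shared leading dimension ldc
    int ldc,
    float const *V)                      // broadcast vector of length N
{
  float acc = 0.0f;
  for (int k = 0; k < K; ++k) {
    acc += A[i * lda + k] * B[k * ldb + j];   // accumulate row i of A dot column j of B
  }
  // D[i][j] = alpha * (A B)[i][j] + beta * C1[i][j] + V[j] + C2[i][j]
  return alpha * acc + beta * C1[i * ldc + j] + V[j] + C2[i * ldc + j];
}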
+///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true) + {} + +}; + + +/// Command line options parsing +struct Options +{ + std::string command_name; + bool help; + cutlass::gemm::GemmCoord problem_size; + float alpha; + float beta; + int split_k_factor; + int avail_sms; + int iterations; + bool real; + + cutlass::HostTensor tensor_a; + cutlass::HostTensor tensor_b; + cutlass::HostTensor tensor_c1; + cutlass::HostTensor tensor_c2; + cutlass::HostTensor tensor_d; + cutlass::HostTensor tensor_ref_d; + cutlass::HostTensor tensor_Vector; + // cutlass::HostTensor tensor_Tensor; + + Options(std::string command_name) : + command_name(command_name), + help(false), + problem_size({2048, 2048, 2048}), + alpha(1.0f), + beta(1.0f), + split_k_factor(1), + avail_sms(-1), // Number of device SMs to use is unlimited + real(false), + iterations(10000) + {} + + bool valid() const + { + return true; + } + + void parse(int argc, char const **args) + { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("split", split_k_factor); + cmd.get_cmd_line_argument("iterations", iterations); + real = cmd.check_cmd_line_flag("real"); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const + { + out + << "Performs a GEMM computation.\n" + << "\n" + << "Options:\n" + << "\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m= GEMM M dimension\n" + << " --n= GEMM N dimension\n" + << " --k= GEMM K dimension\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --split= Split-K factor to emulate\n\n" + << " --real If specified, initializes with real values instead of whole numbers. 
Errors are to be expected.\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options +typename DeviceGemmBasic::Arguments args_from_options( + const DeviceGemmBasic &device_gemm, + const Options &options, + cutlass::HostTensor &tensor_a, + cutlass::HostTensor &tensor_b, + cutlass::HostTensor &tensor_c1, + cutlass::HostTensor &tensor_c2, + cutlass::HostTensor &tensor_d, + cutlass::HostTensor &tensor_Vector /*, + cutlass::HostTensor &tensor_Tensor */ + ) +{ + return typename DeviceGemmBasic::Arguments( + cutlass::gemm::GemmUniversalMode::kGemm, // universal mode + options.problem_size, // problem_size + options.split_k_factor, // batch count / splitk slices + { // epilogue parameters + ElementAccumulator(options.alpha), + ElementAccumulator(options.beta) + }, + tensor_a.device_data(), // ptr_A + tensor_b.device_data(), // ptr_B + tensor_c1.device_data(), // ptr_C1 + tensor_c2.device_data(), // ptr_C2 + tensor_d.device_data(), // ptr_D + tensor_Vector.device_data(), // ptr_Vector + /* tensor_Tensor.device_data(), */nullptr,// ptr_Tensor + options.problem_size.mk().product(), // batch_stride_A + options.problem_size.nk().product(), // batch_stride_B + options.problem_size.mn().product(), // batch_stride_C1 + options.problem_size.mn().product(), // batch_stride_C2 + options.problem_size.mn().product(), // batch_stride_D + options.problem_size.mn().product(), // batch_stride_Vector + options.problem_size.mn().product(), // batch_stride_Tensor + tensor_a.layout().stride(0), // stride_a + tensor_b.layout().stride(0), // stride_b + tensor_c1.layout().stride(0), // stride_c1 + tensor_c2.layout().stride(0), // stride_c2 + tensor_d.layout().stride(0), // stride_d + /*tensor_Vector.layout().stride(0)*/0, // stride_Vector + /*tensor_Tensor.layout().stride(0)*/0); // stride_Tensor +} + +/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options +typename DeviceGemmStreamK::Arguments args_from_options( + const DeviceGemmStreamK &device_gemm, + const Options &options, + cutlass::HostTensor &tensor_a, + cutlass::HostTensor &tensor_b, + cutlass::HostTensor &tensor_c1, + cutlass::HostTensor &tensor_c2, + cutlass::HostTensor &tensor_d, + cutlass::HostTensor &tensor_Vector/*, + cutlass::HostTensor &tensor_Tensor*/ + ) +{ + return typename DeviceGemmStreamK::Arguments( + cutlass::gemm::GemmUniversalMode::kGemm, // universal mode + options.problem_size, // problem_size + options.split_k_factor, // batch count / splitk slices + { // epilogue parameters + ElementAccumulator(options.alpha), + ElementAccumulator(options.beta) + }, + tensor_a.device_data(), // ptr_A + tensor_b.device_data(), // ptr_B + tensor_c1.device_data(), // ptr_C1 + tensor_c2.device_data(), // ptr_C2 + tensor_d.device_data(), // ptr_D + tensor_Vector.device_data(), // ptr_Vector + /* tensor_Tensor.device_data(), */nullptr,// ptr_Tensor // We're not storing 
Tensor + options.problem_size.mk().product(), // batch_stride_A + options.problem_size.nk().product(), // batch_stride_B + options.problem_size.mn().product(), // batch_stride_C1 + options.problem_size.mn().product(), // batch_stride_C2 + options.problem_size.mn().product(), // batch_stride_D + options.problem_size.mn().product(), // batch_stride_Vector + options.problem_size.mn().product(), // batch_stride_Tensor + tensor_a.layout().stride(0), // stride_a + tensor_b.layout().stride(0), // stride_b + tensor_c1.layout().stride(0), // stride_c1 + tensor_c2.layout().stride(0), // stride_c2 + tensor_d.layout().stride(0), // stride_d + /*tensor_Vector.layout().stride(0)*/0, // stride_Vector // Vector stride is always 0 + /*tensor_Tensor.layout().stride(0)*/0, // stride_Tensor // We're not storing Tensor + options.avail_sms); // avail_sms +} + +/// Execute a given example GEMM computation +template +Result run(std::string description, Options &options) +{ + // Display test description + std::cout << std::endl << description << std::endl; + + // Zero-initialize test output matrix D + cutlass::reference::host::TensorFill(options.tensor_d.host_view()); + options.tensor_d.sync_device(); + + // Instantiate CUTLASS kernel depending on templates + DeviceGemmT device_gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT + auto arguments = args_from_options(device_gemm, options, + options.tensor_a, options.tensor_b, options.tensor_c1, options.tensor_c2, options.tensor_d, + options.tensor_Vector/*, options.tensor_Tensor*/); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = DeviceGemmT::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + CUTLASS_CHECK(device_gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(device_gemm()); + + // Copy output data from CUTLASS and reference kernel to host for comparison + options.tensor_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = cutlass::reference::host::TensorEquals( + options.tensor_d.host_view(), + options.tensor_ref_d.host_view()); + + double err = cutlass::reference::host::TensorRelativeErrorMetric( + options.tensor_d.host_view(), + options.tensor_ref_d.host_view()); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << " \t Relative error: " << err << std::endl; + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(device_gemm()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPs: " << result.gflops << std::endl; + } + + // TODO: uncomment when results match + //if (!result.passed) { + // exit(-1); + //} + + return result; +} + + +/// Program entrypoint +int main(int argc, const char **argv) +{ + // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ >= 11)) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + + // Returning zero so this test passes on older Toolkits. Its actions are a no-op. + return 0; + } + + // Current device must have compute capability at least 80 + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(&current_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + if (!((props.major * 10 + props.minor) >= 80)) + { + std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + + // Returning zero so this test passes on older devices. Its actions are a no-op. + return 0; + } + + // Parse commandline options + Options options("ampere_streamk_broadcast_gemm"); + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + std::cout << + options.iterations << " timing iterations of " << + options.problem_size.m() << " x " << + options.problem_size.n() << " x " << + options.problem_size.k() << " matrix-matrix multiply" << std::endl; + + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + + // + // Initialize GEMM datasets + // + + // Initialize tensors using CUTLASS helper functions + options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K + options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N + options.tensor_c1.resize(options.problem_size.mn()); // <- Create matrix C1 with dimensions M x N + options.tensor_c2.resize(options.problem_size.mn()); // <- Create matrix C2 with dimensions M x N + options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel + options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel + options.tensor_Vector.resize({1, options.problem_size.n()}); // <- Create broadcast vector with dimensions 1 x N + // options.tensor_Tensor.resize(options.problem_size.mn()); // <- Create T matrix with dimensions M x N + + int _init_bits = options.real ?
-1 : 0; + + // Fill matrix A on host with uniform-random data [-2, 2] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_a.host_view(), + 1, + ElementA(2), + ElementA(-2), _init_bits); + + // Fill matrix B on host with uniform-random data [-2, 2] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_b.host_view(), + 1, + ElementB(2), + ElementB(-2), _init_bits); + + // Fill matrix C1 on host with uniform-random data [-2, 2] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_c1.host_view(), + 1, + ElementC(2), + ElementC(-2), _init_bits); + + // Fill matrix C2 on host with uniform-random data [-2, 2] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_c2.host_view(), + 1, + ElementC(2), + ElementC(-2), _init_bits); + + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_Vector.host_view(), + 1, + ElementC(2), + ElementC(-2), _init_bits); + + // + // Compute reference output + // + + // Copy data from host to GPU + options.tensor_a.sync_device(); + options.tensor_b.sync_device(); + options.tensor_c1.sync_device(); + options.tensor_c2.sync_device(); + options.tensor_Vector.sync_device(); + // options.tensor_Tensor.sync_device(); + + // Zero-initialize reference output matrix D + cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view()); + options.tensor_ref_d.sync_device(); + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + options.problem_size, + ElementAccumulator(options.alpha), + options.tensor_a.device_ref(), + options.tensor_b.device_ref(), + ElementAccumulator(options.beta), + options.tensor_c1.device_ref(), + options.tensor_ref_d.device_ref()); + + // Wait for kernels to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Copy output data from reference kernel to host for comparison + options.tensor_ref_d.sync_host(); + + // Add broadcast vector (without multiplier) + // This is only possible because BinaryOp is addition, and UnaryOps are identity. + // This makes the addition of broadcast vector commutable. 
+ /// identity(plus(identity(alpha * (a * b) + v), beta * c)) == + /// alpha * a * b + v + beta * c == + /// (alpha * a * b + beta * c) + v == + /// GEMM(a, b, c) + v + // Vector broadcast on host + for (int i=0; i < options.problem_size.m(); ++i) { + for (int j=0; j < options.problem_size.n(); ++j) { + options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_Vector.host_view().ref().at({0, j}); + options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_c2.host_view().ref().at({i, j}); + } + } + + // Sync back with device just in case + options.tensor_ref_d.sync_device(); + + // + // Evaluate CUTLASS kernels + // + + // Test default operation + if (options.split_k_factor == 1) + { + // Compare basic data-parallel version versus StreamK version using default load-balancing heuristics + Result basic_dp = run("Basic data-parallel GEMM", options); + Result streamk_default = run("StreamK GEMM with default load-balancing", options); + + printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms)); + + // Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1 + options.avail_sms = 1; // Set loadbalancing width to 1 SM (no load balancing) + Result streamk_dp = run("StreamK emulating basic data-parallel GEMM", options); + options.avail_sms = -1; // Reset loadbalancing width to unspecified SMs (i.e., the number of device SMs) + + printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms)); + + options.split_k_factor++; // Increment splitting factor for next evaluation + + } + + // Show that StreamK can emulate "Split-K" with a tile-splitting factor + Result basic_splitk = run( + std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor), + options); + + Result streamk_splitk = run( + std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor), + options); + + printf(" Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms)); + + return 0; +} diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h index 381cb30d..9f6a0894 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h @@ -48,6 +48,7 @@ #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" #include "cutlass/epilogue/threadblock/epilogue.h" #include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h" #include "cutlass/layout/permute.h" @@ -120,6 +121,67 @@ struct DefaultEpilogueWithBroadcastTensorOp { //////////////////////////////////////////////////////////////////////////////// +/// Defines sensible defaults for streamk epilogues for TensorOps. 
+template < + typename Shape, + typename WarpMmaTensorOp, + int PartitionsK, + typename ElementOutput, + typename ElementTensor, + typename ElementVector, + typename OutputOp, + int ElementsPerAccess, + bool ScatterD = false, + typename PermuteDLayout = layout::NoPermute +> +struct DefaultStreamkEpilogueWithBroadcastTensorOp { + + /// Use defaults related to the existing epilogue + using Base = DefaultEpilogueTensorOp< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputOp, + ElementsPerAccess + >; + + // + // Stores the result z = (y = GEMM(A, B, C), broadcast) + // + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + typename Base::OutputTileThreadMap, + ElementOutput, + ScatterD, + PermuteDLayout + >; + + // + // Additional tensor tile iterator - stores t = Elementwise(z) + // + using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + typename Base::OutputTileThreadMap, + ElementTensor + >; + + /// Define the epilogue + using Epilogue = EpilogueStreamkWithBroadcast< + Shape, + WarpMmaTensorOp, + PartitionsK, + OutputTileIterator, + TensorTileIterator, + ElementVector, + typename Base::AccumulatorFragmentIterator, + typename Base::WarpTileIterator, + typename Base::SharedLoadIterator, + OutputOp, + typename Base::Padding, + Base::kFragmentsPerIteration + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Defines sensible defaults for epilogues for VoltaTensorOps. template < typename Shape, diff --git a/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h b/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h new file mode 100644 index 00000000..54f822fe --- /dev/null +++ b/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h @@ -0,0 +1,443 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#include +#else +#include +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" + +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This base class is meant to define the concept required of the +/// EpilogueStreamkWithBroadcast::OutputOp +template < + typename ElementC_, + typename ElementAccumulator_, + typename ElementCompute_, + typename ElementZ_, + typename ElementT_, + int ElementsPerAccess, + bool StoreZ = true, + bool StoreT = true +> +struct EpilogueStreamkWithBroadcastOpBase : EpilogueWithBroadcastOpBase< + ElementC_, + ElementAccumulator_, + ElementCompute_, + ElementZ_, + ElementT_, + ElementsPerAccess, + StoreZ, + StoreT + > +{ + + /// Parameters structure - required + struct Params { }; + + // + // Methods + // + + /// Constructor from Params + EpilogueStreamkWithBroadcastOpBase(Params const ¶ms_) { } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator with bias vector broadcast over columns. 
+/// +/// Computes the following: +/// +/// +/// Z, T = OutputOp(AB, C, Broadcast) +/// +/// if (ElementwiseOp::kStoreZ) { +/// store(converted_u); +/// } +/// +/// if (ElementwiseOp::kStoreT) { +/// store(v); +/// } +/// +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z) + typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t) + typename ElementVector_, ///< Pointer to broadcast vector + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp + typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) + int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value), + bool IsSingleSource = OutputOp_::kIsSingleSource +> +class EpilogueStreamkWithBroadcast; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// EpilogueStreamkWithBroadcast: Two sources + +template < + typename Shape_, + typename WarpMmaOperator_, + int PartitionsK, + typename OutputTileIterator_, + typename TensorTileIterator_, + typename ElementVector_, + typename AccumulatorFragmentIterator_, + typename WarpTileIterator_, + typename SharedLoadIterator_, + typename OutputOp_, + typename Padding_, + int FragmentsPerPartition, + int IterationsUnroll +> +class EpilogueStreamkWithBroadcast< + Shape_, + WarpMmaOperator_, + PartitionsK, + OutputTileIterator_, + TensorTileIterator_, + ElementVector_, + AccumulatorFragmentIterator_, + WarpTileIterator_, + SharedLoadIterator_, + OutputOp_, + Padding_, + FragmentsPerPartition, + IterationsUnroll, + false +> : + public EpilogueWithBroadcast< + Shape_, + WarpMmaOperator_, + PartitionsK, + OutputTileIterator_, + TensorTileIterator_, + ElementVector_, + AccumulatorFragmentIterator_, + WarpTileIterator_, + SharedLoadIterator_, + OutputOp_, + Padding_, + FragmentsPerPartition, + IterationsUnroll, + false>, + public EpilogueBaseStreamK< + Shape_, + PartitionsK, + WarpMmaOperator_, + AccumulatorFragmentIterator_> +{ + +public: + + using Base = EpilogueWithBroadcast< + Shape_, + WarpMmaOperator_, + PartitionsK, + OutputTileIterator_, + TensorTileIterator_, + ElementVector_, + AccumulatorFragmentIterator_, + WarpTileIterator_, + SharedLoadIterator_, + OutputOp_, + Padding_, + FragmentsPerPartition, + IterationsUnroll, + false>; + + using BaseStreamK = EpilogueBaseStreamK< + Shape_, + PartitionsK, + WarpMmaOperator_, + AccumulatorFragmentIterator_>; + + using Shape = Shape_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using TensorTileIterator = TensorTileIterator_; + using ElementVector = ElementVector_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + + /// Fragment type used by the accumulator tile's fragment iterator + using AccumulatorFragment = typename 
Base::AccumulatorFragmentIterator::Fragment; + + /// Shared storage structure (shadows base) with additional SMEM buffer for reduction + using SharedStorage = typename Base::SharedStorage; + +public: + + /// Constructor + CUTLASS_DEVICE + EpilogueStreamkWithBroadcast( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + BaseStreamK(thread_idx) + { } + + + /// Aggregates the accumulator sets shared by peer blocks in the global workspace, + /// performing epilogue computations, writing to output + CUTLASS_DEVICE + void reduce( + int peer_idx_begin, + int peer_idx_end, + int reduce_fragment_idx, + void *element_workspace, + OutputOp const &output_op, ///< Output operator + ElementVector const * broadcast_ptr, ///< Broadcast vector + OutputTileIterator destination_iterator, ///< Tile iterator for destination + OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix + OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix + TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand + MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord(Shape::kM, Shape::kN), + MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space + MatrixCoord()) + { + // Reduce peer accumulator fragments into one fragment + AccumulatorFragment accum_fragment; + BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace); + + // Store fragment to shared memory + this->warp_tile_iterator_.store(accum_fragment); + + __syncthreads(); + + Base::reduce(reduce_fragment_idx, output_op, broadcast_ptr, destination_iterator, source_iterator1, source_iterator2, tensor_iterator, problem_size, threadblock_offset); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// EpilogueStreamkWithBroadcast: Single source + +template < + typename Shape_, + typename WarpMmaOperator_, + int PartitionsK, + typename OutputTileIterator_, + typename TensorTileIterator_, + typename ElementVector_, + typename AccumulatorFragmentIterator_, + typename WarpTileIterator_, + typename SharedLoadIterator_, + typename OutputOp_, + typename Padding_, + int FragmentsPerPartition, + int IterationsUnroll +> +class EpilogueStreamkWithBroadcast< + Shape_, + WarpMmaOperator_, + PartitionsK, + OutputTileIterator_, + TensorTileIterator_, + ElementVector_, + AccumulatorFragmentIterator_, + WarpTileIterator_, + SharedLoadIterator_, + OutputOp_, + Padding_, + FragmentsPerPartition, + IterationsUnroll, + true +> : + public EpilogueWithBroadcast< + Shape_, + WarpMmaOperator_, + PartitionsK, + OutputTileIterator_, + TensorTileIterator_, + ElementVector_, + AccumulatorFragmentIterator_, + WarpTileIterator_, + SharedLoadIterator_, + OutputOp_, + Padding_, + FragmentsPerPartition, + IterationsUnroll, + true>, + public EpilogueBaseStreamK< + Shape_, + PartitionsK, + WarpMmaOperator_, + AccumulatorFragmentIterator_> +{ + +public: + + using Base = EpilogueWithBroadcast< + Shape_, + WarpMmaOperator_, + PartitionsK, + OutputTileIterator_, + TensorTileIterator_, + ElementVector_, + AccumulatorFragmentIterator_, + WarpTileIterator_, + 
SharedLoadIterator_, + OutputOp_, + Padding_, + FragmentsPerPartition, + IterationsUnroll, + true>; + + using BaseStreamK = EpilogueBaseStreamK< + Shape_, + PartitionsK, + WarpMmaOperator_, + AccumulatorFragmentIterator_>; + + using Shape = Shape_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using TensorTileIterator = TensorTileIterator_; + using ElementVector = ElementVector_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + + /// Fragment type used by the accumulator tile's fragment iterator + using AccumulatorFragment = typename Base::AccumulatorFragmentIterator::Fragment; + + /// Shared storage structure (shadows base) with additional SMEM buffer for reduction + using SharedStorage = typename Base::SharedStorage; + +public: + + /// Constructor + CUTLASS_DEVICE + EpilogueStreamkWithBroadcast( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + BaseStreamK(thread_idx) + { } + + + /// Aggregates the accumulator sets shared by peer blocks in the global workspace, + /// performing epilogue computations, writing to output + CUTLASS_DEVICE + void reduce( + int peer_idx_begin, + int peer_idx_end, + int reduce_fragment_idx, + void *element_workspace, + OutputOp const &output_op, ///< Output operator + ElementVector const * broadcast_ptr, ///< Broadcast vector + OutputTileIterator destination_iterator, ///< Tile iterator for destination + OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand + MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord(Shape::kM, Shape::kN), + MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space + MatrixCoord()) + { + // Reduce peer accumulator fragments into one fragment + AccumulatorFragment accum_fragment; + BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace); + + // Store fragment to shared memory + this->warp_tile_iterator_.store(accum_fragment); + + __syncthreads(); + + Base::reduce(reduce_fragment_idx, output_op, broadcast_ptr, destination_iterator, source_iterator, tensor_iterator, problem_size, threadblock_offset); + + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h index 9c9f7165..7a97e0cb 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h @@ -863,6 +863,98 @@ private: frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); } } + + public: + /// Stream-K reduce helper + CUTLASS_DEVICE + void reduce( + int reduce_fragment_idx, ///< Reduce fragment index + OutputOp const &output_op, ///< Output operator + ElementVector const * broadcast_ptr, ///< Broadcast vector + OutputTileIterator destination_iterator, 
///< Tile iterator for destination + OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix + OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix + TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand + MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord(Shape::kM, Shape::kN), + MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space + MatrixCoord()) + { + + BroadcastFragment broadcast_fragment; + load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); + + // Initialize/load source-fragment data + typename OutputTileIterator::Fragment source_fragment1; + source_fragment1.clear(); + typename OutputTileIterator::Fragment source_fragment2; + source_fragment2.clear(); + + if (output_op.is_source_needed()) + { + source_iterator1 += reduce_fragment_idx; + source_iterator1.load(source_fragment1); + + source_iterator2 += reduce_fragment_idx; + source_iterator2.load(source_fragment2); + } + + // Load fragment from shared memory + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // Add fragments shared by other k partitions + if (kPartitionsK > 1) + { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + } + + // + // Apply output operation + // + + typename OutputTileIterator::Fragment frag_Z; + typename TensorTileIterator::Fragment frag_T; + + if (!output_op.is_source_needed()) { + apply_output_operator_source_not_needed_( + frag_Z, + frag_T, + output_op, + aligned_accum_fragment[0], + broadcast_fragment); + } else { + apply_output_operator_( + frag_Z, + frag_T, + output_op, + aligned_accum_fragment[0], + source_fragment1, + source_fragment2, + broadcast_fragment); + } + + // + // Conditionally store fragments + // + + if (OutputOp::kStoreZ) { + destination_iterator.store(frag_Z); + ++destination_iterator; + } + + if (OutputOp::kStoreT) { + tensor_iterator.store(frag_T); + ++tensor_iterator; + } + } }; @@ -1529,6 +1621,92 @@ private: frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); } } + + + public: + /// Stream-K reduce helper + CUTLASS_DEVICE + void reduce( + int reduce_fragment_idx, ///< Reduce fragment index + OutputOp const &output_op, ///< Output operator + ElementVector const * broadcast_ptr, ///< Broadcast vector + OutputTileIterator destination_iterator, ///< Tile iterator for destination + OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand + MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord(Shape::kM, Shape::kN), + MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space + MatrixCoord()) + { + + BroadcastFragment broadcast_fragment; + load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); + + // Initialize/load source-fragment data + typename 
OutputTileIterator::Fragment source_fragment; + source_fragment.clear(); + + if (output_op.is_source_needed()) + { + source_iterator += reduce_fragment_idx; + source_iterator.load(source_fragment); + } + + // Load fragment from shared memory + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // Add fragments shared by other k partitions + if (kPartitionsK > 1) + { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + } + + // + // Apply output operation + // + + typename OutputTileIterator::Fragment frag_Z; + typename TensorTileIterator::Fragment frag_T; + + if (!output_op.is_source_needed()) { + apply_output_operator_source_not_needed_( + frag_Z, + frag_T, + output_op, + aligned_accum_fragment[0], + broadcast_fragment); + } else { + apply_output_operator_( + frag_Z, + frag_T, + output_op, + aligned_accum_fragment[0], + source_fragment, + broadcast_fragment); + } + + // + // Conditionally store fragments + // + + if (OutputOp::kStoreZ) { + destination_iterator.store(frag_Z); + ++destination_iterator; + } + + if (OutputOp::kStoreT) { + tensor_iterator.store(frag_T); + ++tensor_iterator; + } + } }; //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h b/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h new file mode 100644 index 00000000..3e4ff3e5 --- /dev/null +++ b/include/cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h @@ -0,0 +1,386 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a Stream-K GEMM kernel that can broadcast bias vector in the + epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm_streamk_with_broadcast.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! + The universal GEMM with a broadcast epilogue. + Supports +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + ElementC_, ElementAccumulator_, ElementAccumulator_, + ElementC_, ElementC_, 128 / cutlass::sizeof_bits::value>, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB = ComplexTransform::kNone +> +class GemmUniversalStreamkWithBroadcast : + public GemmUniversalBase< + typename kernel::DefaultGemmStreamkWithBroadcast< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + > { + + public: + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmStreamkWithBroadcast< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_ + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; +}; + 
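// For reference, the example added by this change (ampere_gemm_universal_streamk_broadcast.cu)
// instantiates this template with Ampere tensor-op parameters roughly as follows. The alias
// name is illustrative, and EpilogueOp stands for the LinearCombinationResidualBlock
// configured in that example:
using ExampleStreamkBroadcastGemm = cutlass::gemm::device::GemmUniversalStreamkWithBroadcast<
    cutlass::half_t, cutlass::layout::RowMajor,              // A
    cutlass::half_t, cutlass::layout::RowMajor,              // B
    cutlass::half_t, cutlass::layout::RowMajor,              // C1/C2/D
    cutlass::half_t,                                         // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,                  // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,                    // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,                     // instruction tile
    EpilogueOp,                                              // LinearCombinationResidualBlock
    cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,   // Stream-K swizzle
    4,                                                       // pipeline stages
    8, 8>;                                                   // A/B alignment (128-bit accesses of half_t)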
+//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and operand. +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// Operation performed by GEMM + typename Operator_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB> +class GemmUniversalStreamkWithBroadcast { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using UnderlyingOperator = typename GemmUniversalStreamkWithBroadcast< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + Operator, + kTransformB, + kTransformA + >::Base; + + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. 
+ GemmUniversalStreamkWithBroadcast() { } + + /// Helper to construct a transposed equivalent for the underying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_streamk_with_broadcast.h b/include/cutlass/gemm/kernel/default_gemm_streamk_with_broadcast.h new file mode 100644 index 00000000..9c33039c --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_streamk_with_broadcast.h @@ -0,0 +1,146 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a Stream-K GEMM that can broadcast a bias vector in the epilogue. + Similar structure to DefaultGemmWithBroadcast, but uses its own epilogue + (DefaultStreamkEpilogueWithBroadcastTensorOp) and its own GEMM kernel + (GemmStreamkWithFusedEpilogue). + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// + typename Enable = void +> +struct DefaultGemmStreamkWithBroadcast { + + using GemmBase = typename DefaultGemmUniversal< + ElementA_, LayoutA_, TransformA, kAlignmentA, + ElementB_, LayoutB_, TransformB, kAlignmentB, + ElementC_, LayoutC_, ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + 
InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + Operator + >::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultStreamkEpilogueWithBroadcastTensorOp< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementC_, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + GemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = GemmStreamkWithFusedEpilogue< + typename GemmBase::Mma, + Epilogue, + ThreadblockSwizzle + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h b/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h new file mode 100644 index 00000000..6d6714d8 --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h @@ -0,0 +1,2405 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Stream-K Gemm kernel compatible with fused epilogues + that broadcast a bias vector over the MMA output. 
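+
+    The kernel is specialized on Epilogue_::kIsSingleSource: the two-source
+    variant consumes residual tensors C1 and C2, while the single-source
+    variant consumes a single residual tensor C.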
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/fast_math.h"
+#include "cutlass/layout/layout.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/matrix_coord.h"
+#include "cutlass/complex.h"
+#include "cutlass/barrier.h"
+#include "cutlass/block_striped.h"
+#include "cutlass/semaphore.h"
+
+#include "cutlass/trace.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace kernel {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
+  typename Epilogue_,             ///! Epilogue
+  typename ThreadblockSwizzle_,   ///! Threadblock swizzling function
+  bool IsSingleSource = Epilogue_::kIsSingleSource
+>
+struct GemmStreamkWithFusedEpilogue;
+
+// GemmStreamkWithFusedEpilogue with two sources
+template <
+  typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
+  typename Epilogue_,             ///! Epilogue
+  typename ThreadblockSwizzle_    ///! Threadblock swizzling function
+>
+struct GemmStreamkWithFusedEpilogue<Mma_, Epilogue_, ThreadblockSwizzle_, false> {
+
+  using Mma = Mma_;
+  using Epilogue = Epilogue_;
+  using EpilogueOutputOp = typename Epilogue::OutputOp;
+  using ThreadblockSwizzle = ThreadblockSwizzle_;
+
+  using ElementA = typename Mma::IteratorA::Element;
+  using LayoutA = typename Mma::IteratorA::Layout;
+  using ElementB = typename Mma::IteratorB::Element;
+  using LayoutB = typename Mma::IteratorB::Layout;
+  using ElementC = typename Epilogue::OutputTileIterator::Element;
+  using LayoutC = typename Epilogue::OutputTileIterator::Layout;
+
+  /// The per-thread tile of raw accumulators
+  using AccumulatorTile = typename Mma::FragmentC;
+
+  static ComplexTransform const kTransformA = Mma::kTransformA;
+  static ComplexTransform const kTransformB = Mma::kTransformB;
+  using Operator = typename Mma::Operator;
+
+  using OperatorClass = typename Mma::Operator::OperatorClass;
+  using ThreadblockShape = typename Mma::Shape;
+  using WarpShape = typename Mma::Operator::Shape;
+  using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
+  using ArchTag = typename Mma::ArchTag;
+
+  static int const kStages = Mma::kStages;
+  static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
+  static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
+  static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
+
+  /// Warp count (concept: GemmShape)
+  using WarpCount = typename Mma::WarpCount;
+  static int const kThreadCount = 32 * WarpCount::kCount;
+
+  /// Workspace bytes per thread block
+  static size_t const kWorkspaceBytesPerBlock =
+    __NV_STD_MAX(
+      kThreadCount * sizeof(AccumulatorTile),
+      Epilogue::kWorkspaceBytesPerBlock);
+
+  /// Block-striped reduction utility
+  using BlockStripedReduceT = BlockStripedReduce<AccumulatorTile, kThreadCount>;
+
+
+  //
+  // Structures
+  //
+
+  /// Argument structure
+  struct Arguments {
+
+    //
+    // Data members
+    //
+
+    GemmUniversalMode mode;
+    GemmCoord problem_size;
+    int batch_count;    // Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor
+
+    typename EpilogueOutputOp::Params epilogue;
+
+    void const * ptr_A;
+    void const * ptr_B;
+    void const * ptr_C1;
+    void const * ptr_C2;
+    void * ptr_D;
+
+    void * ptr_Vector;
+    void * ptr_Tensor;
+
+    int64_t batch_stride_A;
+    int64_t batch_stride_B;
+    int64_t batch_stride_C1;
+    int64_t batch_stride_C2;
+    int64_t batch_stride_D;
+
int64_t batch_stride_Vector; + int64_t batch_stride_Tensor; + + typename LayoutA::Stride::Index lda; + typename LayoutB::Stride::Index ldb; + typename LayoutC::Stride::Index ldc1; + typename LayoutC::Stride::Index ldc2; + typename LayoutC::Stride::Index ldd; + typename LayoutC::Stride::Index ldr; + typename LayoutC::Stride::Index ldt; + + int avail_sms; /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) + + + // + // Methods + // + + /// Default Constructor + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C1(nullptr), + ptr_C2(nullptr), + ptr_D(nullptr), + avail_sms(-1) + {} + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_split, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K) + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C1, + void const * ptr_C2, + void * ptr_D, + void * ptr_Vector, + void * ptr_Tensor, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C1, + int64_t batch_stride_C2, + int64_t batch_stride_D, + int64_t batch_stride_Vector, + int64_t batch_stride_Tensor, + typename LayoutA::Stride::Index lda, + typename LayoutB::Stride::Index ldb, + typename LayoutC::Stride::Index ldc1, + typename LayoutC::Stride::Index ldc2, + typename LayoutC::Stride::Index ldd, + typename LayoutC::Stride::Index ldr, + typename LayoutC::Stride::Index ldt, + int avail_sms = -1) /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) + : + mode(mode), + problem_size(problem_size), + batch_count(batch_split), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C1(ptr_C1), ptr_C2(ptr_C2), ptr_D(ptr_D), + ptr_Vector(ptr_Vector), + ptr_Tensor(ptr_Tensor), + batch_stride_A(batch_stride_A), + batch_stride_B(batch_stride_B), + batch_stride_C1(batch_stride_C1), + batch_stride_C2(batch_stride_C2), + batch_stride_Vector(batch_stride_Vector), + batch_stride_Tensor(batch_stride_Tensor), + lda(lda), ldb(ldb), ldc1(ldc1), ldc2(ldc2), ldd(ldd), ldr(ldr), ldt(ldt), avail_sms(avail_sms) + { + CUTLASS_TRACE_HOST("GemmStreamkWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + CUTLASS_TRACE_HOST(" ldt: " << this->ldt); + CUTLASS_TRACE_HOST(" avail_sms: " << this->avail_sms); + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A, args.ptr_B); + std::swap(args.lda, args.ldb); + std::swap(args.batch_stride_A, args.batch_stride_B); + + return args; + } + }; + + + /// Parameters structure + struct Params + { + public: + + // + // Data members + // + + void * ptr_A; + void * ptr_B; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + GemmUniversalMode mode; + + ThreadblockSwizzle block_mapping; + + void *barrier_workspace; + 
void *partials_workspace; + + typename EpilogueOutputOp::Params output_op; + + void * ptr_C1; + void * ptr_C2; + void * ptr_D; + void * ptr_Tensor; + void * ptr_Vector; + + typename Epilogue::OutputTileIterator::Params params_C1; + typename Epilogue::OutputTileIterator::Params params_C2; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::TensorTileIterator::Params params_Tensor; + + int64_t batch_stride_C1; + int64_t batch_stride_C2; + int64_t batch_stride_D; + int64_t batch_stride_Vector; + int64_t batch_stride_Tensor; + + typename LayoutC::Stride::Index ldr; + + protected: + + // + // Host-only dispatch-utilities + // + + /// Pad the given allocation size up to the nearest cache line + static size_t cacheline_align_up(size_t size) + { + static const int CACHELINE_SIZE = 128; + return (size + CACHELINE_SIZE - 1) / CACHELINE_SIZE * CACHELINE_SIZE; + } + + /// Get the workspace size needed for barrier + size_t get_barrier_workspace_size() const + { + // For atomic reduction, each SK-block needs a synchronization flag. For parallel reduction, + // each reduction block needs its own synchronization flag. + int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region(); + int num_flags = fast_max(sk_blocks, block_mapping.reduction_blocks); + + return cacheline_align_up(sizeof(typename Barrier::T) * num_flags); + } + + /// Get the workspace size needed for intermediate partial sums + size_t get_partials_workspace_size() const + { + int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region(); + return cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks); + } + + + public: + + // + // Host dispatch API + // + + /// Default constructor + Params() = default; + + /// Constructor + Params( + Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + params_A(args.lda), + params_B(args.ldb), + params_C1(args.ldc1), + params_C2(args.ldc2), + params_D(args.ldd), + params_Tensor(args.ldt), + output_op(args.epilogue), + mode(args.mode), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C1(const_cast(args.ptr_C1)), + ptr_C2(const_cast(args.ptr_C2)), + ptr_D(args.ptr_D), + ptr_Vector(args.ptr_Vector), + ldr(args.ldr), + ptr_Tensor(args.ptr_Tensor), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C1(args.batch_stride_C1), + batch_stride_C2(args.batch_stride_C2), + batch_stride_D(args.batch_stride_D), + batch_stride_Vector(args.batch_stride_Vector), + batch_stride_Tensor(args.batch_stride_Tensor), + barrier_workspace(nullptr), + partials_workspace(nullptr) + { + CUTLASS_TRACE_HOST("GemmStreamkWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + CUTLASS_TRACE_HOST(" ldt: " << args.ldt); + CUTLASS_TRACE_HOST(" avail_sms: " << avail_sms); + + // Number of SMs to make available for StreamK decomposition + int avail_sms = (args.avail_sms == -1) ? 
+ device_sms : + fast_min(args.avail_sms, device_sms); + + // Initialize the block mapping structure + block_mapping = ThreadblockSwizzle( + typename ThreadblockSwizzle::template KernelTraits(), + args.mode, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count, + sm_occupancy, + device_sms, + avail_sms); + } + + /// Returns the workspace size (in bytes) needed for these parameters + size_t get_workspace_size() const + { + return + get_barrier_workspace_size() + + get_partials_workspace_size(); + } + + /// Assign and initialize the specified workspace buffer. Assumes + /// the memory allocated to workspace is at least as large as get_workspace_size(). + Status init_workspace( + void *workspace, + cudaStream_t stream = nullptr) + { + uint8_t *ptr = static_cast(workspace); + + + // Establish partials workspace + partials_workspace = nullptr; + size_t partials_workspace_bytes = get_partials_workspace_size(); + if (partials_workspace_bytes > 0) + { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + partials_workspace = ptr; + ptr += partials_workspace_bytes; + } + + // Establish barrier workspace + barrier_workspace = nullptr; + size_t barrier_workspace_bytes = get_barrier_workspace_size(); + if (barrier_workspace_bytes > 0) + { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + barrier_workspace = ptr; + ptr += barrier_workspace_bytes; + } + + // Zero-initialize barrier workspace + if (barrier_workspace) + { + size_t barrier_workspace_bytes = get_barrier_workspace_size(); + + CUTLASS_TRACE_HOST(" Initialize " << barrier_workspace_bytes << " barrier bytes"); + + cudaError_t result = cudaMemsetAsync( + barrier_workspace, + 0, + barrier_workspace_bytes, + stream); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + + /// Returns the GEMM volume in thread block tiles + cutlass::gemm::GemmCoord get_tiled_shape() const + { + return block_mapping.tiled_shape(); + } + + /// Returns the total number of thread blocks to launch + int get_grid_blocks() const + { + dim3 grid_dims = get_grid_dims(); + return grid_dims.x * grid_dims.y * grid_dims.z; + } + + /// Returns the grid extents in thread blocks to launch + dim3 get_grid_dims() const + { + return block_mapping.get_grid_dims(); + } + + /// Lightweight update given a subset of arguments. Problem geometry is assumed + /// to remain the same. 
+ CUTLASS_HOST_DEVICE + void update(Arguments const &args) + { + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C1 = const_cast(args.ptr_C1); + ptr_C2 = const_cast(args.ptr_C2); + ptr_D = args.ptr_D; + + ptr_Vector = args.ptr_Vector; + ldr = args.ldr; + ptr_Tensor = args.ptr_Tensor; + + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C1 = args.batch_stride_C1; + batch_stride_C2 = args.batch_stride_C2; + batch_stride_D = args.batch_stride_D; + batch_stride_Vector = args.batch_stride_Vector; + batch_stride_Tensor = args.batch_stride_Tensor; + + output_op = args.epilogue; + + CUTLASS_TRACE_HOST("GemmStreamkWithFusedEpilogue::Params::update()"); + CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + } + }; + + /// Tile work descriptor + struct TileWorkDesc + { + /// The linear tile index + int tile_idx; + + /// The location of this tile (in threadblock-tile coordinates) in the output matrix + cutlass::gemm::GemmCoord tiled_coord; + + // The first global-scoped MAC-iteration this threadblock will perform for this tile + int iter_begin; + + // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile + int k_begin; + + // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile + int k_end; + + /// The number of remaining MAC-iterations this threadblock will perform for this tile + int k_iters_remaining; + + // Whether this block will perform the first iteration of this tile + CUTLASS_DEVICE + bool tile_started() + { + return (k_begin == 0); + } + + // Whether this block will perform the last iteration of this tile + CUTLASS_DEVICE + bool tile_finished(Params const ¶ms) + { + return (k_end == params.block_mapping.problem_size.k()); + } + }; + + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + +protected: + + // + // Data members + // + + /// GEMM problem parameters + Params const ¶ms; + + /// Shared storage reference + SharedStorage &shared_storage; + + /// ID within the threadblock + int thread_idx; + + /// ID of warp + int warp_idx; + + /// ID of each thread within a warp + int lane_idx; + + /// Threadblock scoped epilogue + Epilogue epilogue; + + +public: + + // + // Host dispatch API + // + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("GemmStreamkWithFusedEpilogue::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if 
(platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + +protected: + + // + // Device-only utility methods + // + + /// Iterator for fetching tile fragments from A + CUTLASS_DEVICE + typename Mma::IteratorA init_iterator_A( + TileWorkDesc &tile_work, + GemmUniversalMode mode) + { + // The input A matrix + ElementA *ptr_A = static_cast(params.ptr_A); + + // Update input pointers based on batched/array mode + if (mode == GemmUniversalMode::kBatched) { + ptr_A += tile_work.tiled_coord.k() * params.batch_stride_A; + } + if (mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[tile_work.tiled_coord.k()]; + } + + int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM; + int m_end = params.block_mapping.problem_size.m(); + return Mma::IteratorA( + params.params_A, + ptr_A, + { m_end, tile_work.k_end }, + threadIdx.x, + { m_begin, tile_work.k_begin }); + + } + + + /// Iterator for fetching tile fragments from B + CUTLASS_DEVICE + typename Mma::IteratorB init_iterator_B( + TileWorkDesc &tile_work, + GemmUniversalMode mode) + { + // The input B matrix + ElementB *ptr_B = static_cast(params.ptr_B); + + // Update input pointers based on batched/array mode + if (mode == GemmUniversalMode::kBatched) { + ptr_B += tile_work.tiled_coord.k() * params.batch_stride_B; + } + if (mode == GemmUniversalMode::kArray) { + ptr_B = static_cast(params.ptr_B)[tile_work.tiled_coord.k()]; + } + + int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN; + int n_end = params.block_mapping.problem_size.n(); + return Mma::IteratorB( + params.params_B, + ptr_B, + { tile_work.k_end, n_end }, + threadIdx.x, + { tile_work.k_begin, n_begin }); + } + + + CUTLASS_DEVICE + void init_dp_tile_work( + TileWorkDesc &tile_work, + int tile_idx) + { + // The linear tile index + tile_work.tile_idx = tile_idx; + + // The first global-scoped MAC-iteration this threadblock will perform for this tile + tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile(); + + // The number of MAC-iterations this threadblock will perform for this tile + tile_work.k_iters_remaining = params.block_mapping.iters_per_tile(); + + // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile + tile_work.k_begin = 0; + + // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile + tile_work.k_end = params.block_mapping.problem_size.k(); + + // The location of this tile (in threadblock-tile coordinates) in the output matrix + tile_work.tiled_coord = 
params.block_mapping.get_tile_offset(tile_work.tile_idx); + } + + + CUTLASS_DEVICE + void init_sk_tile_work( + TileWorkDesc &tile_work, + int tile_idx, + int block_iter_begin, + int block_iter_end) + { + // The linear tile index + tile_work.tile_idx = tile_idx; + + // The first global-scoped MAC-iteration for this tile + int tile_iter_begin = tile_idx * params.block_mapping.iters_per_tile(); + + // The first global-scoped MAC-iteration this threadblock will perform for this tile + tile_work.iter_begin = max(block_iter_begin, tile_iter_begin); + + // The first tile-scoped MAC-iteration this threadblock will perform for this tile + int k_iter_begin = tile_work.iter_begin - tile_iter_begin; + + // The last (one past) tile-scoped MAC-iteration this threadblock will perform for this tile + int k_iter_end = block_iter_end - tile_iter_begin; + + // The number of MAC-iterations this threadblock will perform for this tile + tile_work.k_iters_remaining = k_iter_end - k_iter_begin; + + // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile + tile_work.k_begin = k_iter_begin * Mma::Shape::kK; + + // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile + tile_work.k_end = min( + params.block_mapping.problem_size.k(), // extent of k domain + (k_iter_end * Mma::Shape::kK)); // extent of the threadblock's global iteration assignment + + // The location of this tile (in threadblock-tile coordinates) in the output matrix + tile_work.tiled_coord = params.block_mapping.get_tile_offset(tile_work.tile_idx); + } + + + /// Share accumulators with peers + CUTLASS_DEVICE + void share_accumulators( + AccumulatorTile const &accumulator_tile, + int block_idx, + int first_block_idx) + { + AccumulatorTile *accum_tile_workspace = reinterpret_cast(params.partials_workspace); + + int accum_tile_offset = first_block_idx * kThreadCount; + + if (block_idx == first_block_idx) + { + // First peer initializes the workspace partials + BlockStripedReduceT::store(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx); + } + else + { + // Subsequent peers atomically accumulate into the workspace partials + if (ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) + { + // Non-deterministic reduction order: wait for the first peer to have initialized the partials before we add to them + Barrier::wait_lt(params.barrier_workspace, thread_idx, first_block_idx, 1); + } + else + { + // Turnstile reduction order: wait until the previous peer has written + int wait_count = block_idx - first_block_idx; + Barrier::wait_eq(params.barrier_workspace, thread_idx, first_block_idx, wait_count); + } + + // Perform reduction in workspace + BlockStripedReduceT::reduce(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx); + } + + // Signal our arrival + Barrier::arrive_inc(params.barrier_workspace, thread_idx, first_block_idx); + } + + + /// Acquire accumulators from peers + CUTLASS_DEVICE + void acquire_accumulators( + AccumulatorTile &accumulator_tile, + int block_idx, + int first_block_idx) + { + AccumulatorTile *accum_tile_workspace = reinterpret_cast(params.partials_workspace); + + // Wait for arrival + int num_carry_in = block_idx - first_block_idx; + Barrier::wait_eq_reset(params.barrier_workspace, thread_idx, first_block_idx, num_carry_in); + + // Load and add peer-partials accumulator tile to local accumulator tile + int accum_tile_offset = first_block_idx * kThreadCount; + 
BlockStripedReduceT::load_add(accumulator_tile, accum_tile_workspace + accum_tile_offset, thread_idx); + } + + + /// Perform epilogue computations and output + CUTLASS_DEVICE + void do_epilogue( + TileWorkDesc &tile_work, + AccumulatorTile &accumulator_tile) + { + ElementC *ptr_C1 = static_cast(params.ptr_C1); + ElementC *ptr_C2 = static_cast(params.ptr_C2); + ElementC *ptr_D = static_cast(params.ptr_D); + typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector *ptr_Vector = + static_cast(params.ptr_Vector); + + // Update pointers for batched/array mode(s) + if (params.mode == GemmUniversalMode::kBatched) { + ptr_C1 += tile_work.tiled_coord.k() * params.batch_stride_C1; + if (ptr_C2) { + ptr_C2 += tile_work.tiled_coord.k() * params.batch_stride_C2; + } + ptr_D += tile_work.tiled_coord.k() * params.batch_stride_D; + if (ptr_Tensor) { + ptr_Tensor += tile_work.tiled_coord.k() * params.batch_stride_Tensor; + } + if (ptr_Vector) { + ptr_Vector += tile_work.tiled_coord.k() * params.batch_stride_Vector; + } + } + if (params.mode == GemmUniversalMode::kArray) { + ptr_C1 = static_cast(params.ptr_C1)[tile_work.tiled_coord.k()]; + if (ptr_C2) { + ptr_C2 = static_cast(params.ptr_C2)[tile_work.tiled_coord.k()]; + } + ptr_D = static_cast(params.ptr_D)[tile_work.tiled_coord.k()]; + if (ptr_Tensor) { + ptr_Tensor = static_cast(params.ptr_Tensor)[tile_work.tiled_coord.k()]; + } + if (ptr_Vector) { + ptr_Vector = static_cast(params.ptr_Vector)[tile_work.tiled_coord.k()]; + } + } + + // Location of this tile in item-coords + MatrixCoord threadblock_item_begin( + tile_work.tiled_coord.m() * Mma::Shape::kM, + tile_work.tiled_coord.n() * Mma::Shape::kN + ); + + // Tile iterator loading from residual1. + typename Epilogue::OutputTileIterator iterator_C1( + params.params_C1, + ptr_C1, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Tile iterator loading from residual2. + typename Epilogue::OutputTileIterator iterator_C2( + params.params_C2, + ptr_C2, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + ptr_Tensor, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_item_begin.column() + tile_work.tiled_coord.m() * params.ldr; + } + + // Execute the epilogue operator to update the destination tensor. 
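+    // Operands, in order: broadcast vector, destination D, accumulator tile,
+    // residual sources C1 and C2, and the auxiliary tensor iterator.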
+ epilogue( + EpilogueOutputOp(params.output_op), + ptr_Vector, + iterator_D, + accumulator_tile, + iterator_C1, + iterator_C2, + tensor_iterator, + params.block_mapping.problem_size.mn(), + threadblock_item_begin); + } + + + CUTLASS_DEVICE + void separate_reduction(int reduce_idx) + { + int peer_idx_begin, peer_idx_last, reduce_tile_idx, reduce_fragment_idx; + + // Reduce by sk-tile (every tile contributed to by one or more blocks) + reduce_tile_idx = reduce_idx / Epilogue::kAccumulatorFragments; + reduce_fragment_idx = reduce_idx % Epilogue::kAccumulatorFragments; + + int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile(); + int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile() - 1; + + peer_idx_begin = params.block_mapping.get_sk_block_idx(iter_tile_first); + peer_idx_last = params.block_mapping.get_sk_block_idx(iter_tile_last); + + // Wait for peers to complete + int peer_idx_end = peer_idx_last + 1; + int num_peers = peer_idx_end - peer_idx_begin; + Barrier::wait_eq_reset( + params.barrier_workspace, + thread_idx, + (reduce_tile_idx * Epilogue::kAccumulatorFragments) + reduce_fragment_idx, + num_peers); + + /// The location of this tile (in threadblock-tile coordinates) in the output matrix + GemmCoord tiled_coord = params.block_mapping.get_tile_offset(reduce_tile_idx); + + // Location of this tile in item-coords + MatrixCoord threadblock_item_begin( + tiled_coord.m() * Mma::Shape::kM, + tiled_coord.n() * Mma::Shape::kN + ); + + ElementC *ptr_C1 = static_cast(params.ptr_C1); + ElementC *ptr_C2 = static_cast(params.ptr_C2); + ElementC *ptr_D = static_cast(params.ptr_D); + typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector *ptr_Vector = + static_cast(params.ptr_Vector); + + // Tile iterator loading from residual1. + typename Epilogue::OutputTileIterator iterator_C1( + params.params_C1, + ptr_C1, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Tile iterator loading from residual2. + typename Epilogue::OutputTileIterator iterator_C2( + params.params_C2, + ptr_C2, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + ptr_Tensor, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_item_begin.column() + tiled_coord.m() * params.ldr; + } + + // Execute the epilogue operator to update the destination tensor. 
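+    // The reduce() entry point first combines the peer partials staged in
+    // partials_workspace, then applies the same fused epilogue to the result.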
+ epilogue.reduce( + peer_idx_begin, + peer_idx_end, + reduce_fragment_idx, + params.partials_workspace, + EpilogueOutputOp(params.output_op), + ptr_Vector, + iterator_D, + iterator_C1, + iterator_C2, + tensor_iterator, + params.block_mapping.problem_size.mn(), + threadblock_item_begin); + } + + + CUTLASS_DEVICE + void process_tile( + TileWorkDesc tile_work, + int block_idx, + int dp_start_block_idx, + int block_iter_begin) + { + // Initialize input iterators + typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode); + typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode); + + // Initialize accumulators + AccumulatorTile accumulator_tile; + accumulator_tile.clear(); + + // Initialize MMA abstraction + Mma mma( + shared_storage.main_loop, + thread_idx, + warp_idx, + lane_idx); + + // Perform this tile's range of multiply-accumulate (MAC) iterations + mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile); + + if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) || + (params.block_mapping.reduction_blocks == 0) || + (block_idx >= dp_start_block_idx)) + { + // + // Cooperative SK peer reduction or DP block + // + + int first_block_idx = params.block_mapping.get_first_block_idx(tile_work.tile_idx, block_idx); + + if (!tile_work.tile_finished(params)) { + // Non "finishing" SK blocks must share their partial accumulator sums through global scratch workspace + share_accumulators(accumulator_tile, block_idx, first_block_idx); + } + else + { + // DP blocks and "finishing" SK blocks must perform epilogue operations and write the output tile + if (!tile_work.tile_started()) + { + // A "finishing" SK block must first aggregate its accumulator partial sums with those shared by peer threadblocks + acquire_accumulators(accumulator_tile, block_idx, first_block_idx); + } + + do_epilogue(tile_work, accumulator_tile); + } + } + else + { + // + // Separate peer reduction + // + + // Share accumulator partial sums with peer threadblock(s) through scratch workspace + epilogue.share(block_idx, params.partials_workspace, accumulator_tile, tile_work.tile_started()); + + // Signal arrival + Barrier::arrive_range_inc( + params.barrier_workspace, + thread_idx, + tile_work.tile_idx * Epilogue::kAccumulatorFragments, + Epilogue::kAccumulatorFragments); + } + } + + + /// Executes one GEMM + CUTLASS_DEVICE + void gemm() + { + // Initialize block's iteration range + int tile_idx = 0; + int block_iter_begin = 0; + int block_iters_remaining = 0; + + int block_idx = params.block_mapping.get_block_idx(); + + int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region(); + int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms; + int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks; + int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks; + + // Initialize tile work descriptor + TileWorkDesc tile_work; + + bool dp_block = (block_idx >= dp_start_block_idx) && (block_idx < reduce_start_block_idx); + bool sk_block = (block_idx < sk_padding_start_block_idx); + bool reduce_block = (block_idx >= reduce_start_block_idx) && + (block_idx < grid_padding_start_block_idx) && + (ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed); + + if (dp_block) + { + // This is a DP block + int dp_block_idx = block_idx - dp_start_block_idx; + int first_dp_tile = 
(params.block_mapping.cohort_raster) ? 0 : params.block_mapping.sk_tiles; + + // Blocks in first DP wave get configured number of tiles + tile_idx = first_dp_tile + dp_block_idx; + int tile_allottment = params.block_mapping.dp_first_wave_tiles; + + // Blocks in subsequent DP waves get 1 tile + if (dp_block_idx >= params.block_mapping.avail_sms) { + tile_allottment = 1; + tile_idx += (params.block_mapping.dp_first_wave_tiles - 1) * params.block_mapping.avail_sms; + } + + block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment; + + init_dp_tile_work(tile_work, tile_idx); + + // DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1) + if ((tile_idx < params.block_mapping.sk_tiles) || + (tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) || + (tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n())) + { + return; + } + } + else if (sk_block) + { + // This is a SK block + int block_iter_end; + params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end); + block_iters_remaining = block_iter_end - block_iter_begin; + + tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1); + init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining); + } + else + { + if (reduce_block) + { + // This is a reduction threadblock + int reduce_block_idx = block_idx - reduce_start_block_idx; + separate_reduction(reduce_block_idx); + } + + return; + } + + // Iteration-processing loop body + CUTLASS_PRAGMA_NO_UNROLL + while (true) + { + // Perform this block's share of work for this tile + process_tile( + tile_work, + block_idx, + dp_start_block_idx, + block_iter_begin); + + block_iters_remaining -= tile_work.k_iters_remaining; + + if (block_iters_remaining == 0) + { + break; + } + + // Continue to next tile + __syncthreads(); + + if (block_idx >= dp_start_block_idx) + { + // DP block consume their tiles at stride + tile_idx += params.block_mapping.avail_sms; + init_dp_tile_work(tile_work, tile_idx); + } + else + { + // SK blocks consume their tiles in backwards order + tile_idx--; + init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining); + } + } + + } + + +public: + + // + // Device-only API + // + + // Factory invocation + CUTLASS_DEVICE + static void invoke( + Params const ¶ms, + SharedStorage &shared_storage) + { + GemmStreamkWithFusedEpilogue op(params, shared_storage); + op(); + } + + + // Constructor + CUTLASS_DEVICE + GemmStreamkWithFusedEpilogue( + Params const ¶ms, + SharedStorage &shared_storage) + : + params(params), + shared_storage(shared_storage), + thread_idx(threadIdx.x), + warp_idx(__shfl_sync(0xffffffff, threadIdx.x / 32, 0)), // broadcast the warp_id computed by lane 0 to ensure dependent code + lane_idx(threadIdx.x % 32), + epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx) + {} + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()() { + // Generic SK code path + gemm(); + + } +}; + + +// GemmStreamkWithFusedEpilogue with one source +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function +> +struct GemmStreamkWithFusedEpilogue { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + /// The per-thread tile of raw accumulators + using AccumulatorTile = typename Mma::FragmentC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Workspace bytes per thread block + static size_t const kWorkspaceBytesPerBlock = + __NV_STD_MAX( + kThreadCount * sizeof(AccumulatorTile), + Epilogue::kWorkspaceBytesPerBlock); + + /// Block-striped reduction utility + using BlockStripedReduceT = BlockStripedReduce; + + + + // + // Structures + // + + /// Argument structure + struct Arguments + { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; // Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor + + typename EpilogueOutputOp::Params epilogue; + + void const * ptr_A; + void const * ptr_B; + void const * ptr_C; + void * ptr_D; + + void * ptr_Vector; + void * ptr_Tensor; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + int64_t batch_stride_Vector; + int64_t batch_stride_Tensor; + + typename LayoutA::Stride::Index lda; + typename LayoutB::Stride::Index ldb; + typename LayoutC::Stride::Index ldc; + typename LayoutC::Stride::Index ldd; + typename LayoutC::Stride::Index ldr; + typename LayoutC::Stride::Index ldt; + + int avail_sms; /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) + + + // + // Methods + // + + /// Default Constructor + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + avail_sms(-1) + {} + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_split, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K) + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + void * 
ptr_Vector, + void * ptr_Tensor, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + int64_t batch_stride_Vector, + int64_t batch_stride_Tensor, + typename LayoutA::Stride::Index lda, + typename LayoutB::Stride::Index ldb, + typename LayoutC::Stride::Index ldc, + typename LayoutC::Stride::Index ldd, + typename LayoutC::Stride::Index ldr, + typename LayoutC::Stride::Index ldt, + int avail_sms = -1) /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) + : + mode(mode), + problem_size(problem_size), + batch_count(batch_split), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + ptr_Vector(ptr_Vector), + ptr_Tensor(ptr_Tensor), + batch_stride_A(batch_stride_A), + batch_stride_B(batch_stride_B), + batch_stride_C(batch_stride_C), + batch_stride_Vector(batch_stride_Vector), + batch_stride_Tensor(batch_stride_Tensor), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt), avail_sms(avail_sms) + { + CUTLASS_TRACE_HOST("GemmStreamkWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + CUTLASS_TRACE_HOST(" ldt: " << this->ldt); + CUTLASS_TRACE_HOST(" avail_sms: " << this->avail_sms); + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A, args.ptr_B); + std::swap(args.lda, args.ldb); + std::swap(args.batch_stride_A, args.batch_stride_B); + + return args; + } + }; + + + /// Parameters structure + struct Params + { + + public: + + // + // Data members + // + + void * ptr_A; + void * ptr_B; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + GemmUniversalMode mode; + + ThreadblockSwizzle block_mapping; + + void *barrier_workspace; + void *partials_workspace; + + typename EpilogueOutputOp::Params output_op; + + void * ptr_C; + void * ptr_D; + void * ptr_Tensor; + void * ptr_Vector; + + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::TensorTileIterator::Params params_Tensor; + + int64_t batch_stride_C; + int64_t batch_stride_D; + int64_t batch_stride_Vector; + int64_t batch_stride_Tensor; + + + typename LayoutC::Stride::Index ldr; + + protected: + + // + // Host-only dispatch-utilities + // + + /// Pad the given allocation size up to the nearest cache line + static size_t cacheline_align_up(size_t size) + { + static const int CACHELINE_SIZE = 128; + return (size + CACHELINE_SIZE - 1) / CACHELINE_SIZE * CACHELINE_SIZE; + } + + /// Get the workspace size needed for barrier + size_t get_barrier_workspace_size() const + { + // For atomic reduction, each SK-block needs a synchronization flag. For parallel reduction, + // each reduction block needs its own synchronization flag. 
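+      // Hence the flag count below is the larger of the SK-block count and the
+      // reduction-block count, padded up to a whole cache line.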
+ int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region(); + int num_flags = fast_max(sk_blocks, block_mapping.reduction_blocks); + + return cacheline_align_up(sizeof(typename Barrier::T) * num_flags); + } + + /// Get the workspace size needed for intermediate partial sums + size_t get_partials_workspace_size() const + { + int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region(); + return cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks); + } + + + public: + // + // Host dispatch API + // + + /// Default constructor + Params() = default; + + /// Constructor + Params( + Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + params_A(args.lda), + params_B(args.ldb), + params_C(args.ldc), + params_D(args.ldd), + params_Tensor(args.ldt), + output_op(args.epilogue), + mode(args.mode), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + ptr_Vector(args.ptr_Vector), + ldr(args.ldr), + ptr_Tensor(args.ptr_Tensor), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D), + batch_stride_Vector(args.batch_stride_Vector), + batch_stride_Tensor(args.batch_stride_Tensor), + barrier_workspace(nullptr), + partials_workspace(nullptr) + { + CUTLASS_TRACE_HOST("GemmStreamkWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + CUTLASS_TRACE_HOST(" ldt: " << args.ldt); + CUTLASS_TRACE_HOST(" avail_sms: " << avail_sms); + + // Number of SMs to make available for StreamK decomposition + int avail_sms = (args.avail_sms == -1) ? + device_sms : + fast_min(args.avail_sms, device_sms); + + // Initialize the block mapping structure + block_mapping = ThreadblockSwizzle( + typename ThreadblockSwizzle::template KernelTraits(), + args.mode, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count, + sm_occupancy, + device_sms, + avail_sms); + } + + /// Returns the workspace size (in bytes) needed for these parameters + size_t get_workspace_size() const + { + return + get_barrier_workspace_size() + + get_partials_workspace_size(); + } + + + /// Assign and initialize the specified workspace buffer. Assumes + /// the memory allocated to workspace is at least as large as get_workspace_size(). 
+ Status init_workspace( + void *workspace, + cudaStream_t stream = nullptr) + { + uint8_t *ptr = static_cast(workspace); + + // Establish partials workspace + partials_workspace = nullptr; + size_t partials_workspace_bytes = get_partials_workspace_size(); + if (partials_workspace_bytes > 0) + { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + partials_workspace = ptr; + ptr += partials_workspace_bytes; + } + + // Establish barrier workspace + barrier_workspace = nullptr; + size_t barrier_workspace_bytes = get_barrier_workspace_size(); + if (barrier_workspace_bytes > 0) + { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + barrier_workspace = ptr; + ptr += barrier_workspace_bytes; + } + + // Zero-initialize barrier workspace + if (barrier_workspace) + { + size_t barrier_workspace_bytes = get_barrier_workspace_size(); + + CUTLASS_TRACE_HOST(" Initialize " << barrier_workspace_bytes << " barrier bytes"); + + cudaError_t result = cudaMemsetAsync( + barrier_workspace, + 0, + barrier_workspace_bytes, + stream); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + + /// Returns the GEMM volume in thread block tiles + cutlass::gemm::GemmCoord get_tiled_shape() const + { + return block_mapping.tiled_shape(); + } + + + /// Returns the total number of thread blocks to launch + int get_grid_blocks() const + { + dim3 grid_dims = get_grid_dims(); + return grid_dims.x * grid_dims.y * grid_dims.z; + } + + + /// Returns the grid extents in thread blocks to launch + dim3 get_grid_dims() const + { + return block_mapping.get_grid_dims(); + } + + /// Lightweight update given a subset of arguments. Problem geometry is assumed + /// to remain the same. 
+ CUTLASS_HOST_DEVICE + void update(Arguments const &args) + { + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); + ptr_D = args.ptr_D; + + ptr_Vector = args.ptr_Vector; + ldr = args.ldr; + ptr_Tensor = args.ptr_Tensor; + + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + batch_stride_D = args.batch_stride_D; + batch_stride_Vector = args.batch_stride_Vector; + batch_stride_Tensor = args.batch_stride_Tensor; + + output_op = args.epilogue; + + CUTLASS_TRACE_HOST("GemmStreamkWithFusedEpilogue::Params::update()"); + CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + } + }; + + /// Tile work descriptor + struct TileWorkDesc + { + /// The linear tile index + int tile_idx; + + /// The location of this tile (in threadblock-tile coordinates) in the output matrix + cutlass::gemm::GemmCoord tiled_coord; + + // The first global-scoped MAC-iteration this threadblock will perform for this tile + int iter_begin; + + // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile + int k_begin; + + // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile + int k_end; + + /// The number of remaining MAC-iterations this threadblock will perform for this tile + int k_iters_remaining; + + // Whether this block will perform the first iteration of this tile + CUTLASS_DEVICE + bool tile_started() + { + return (k_begin == 0); + } + + // Whether this block will perform the last iteration of this tile + CUTLASS_DEVICE + bool tile_finished(Params const ¶ms) + { + return (k_end == params.block_mapping.problem_size.k()); + } + }; + + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + +protected: + + // + // Data members + // + + /// GEMM problem parameters + Params const ¶ms; + + /// Shared storage reference + SharedStorage &shared_storage; + + /// ID within the threadblock + int thread_idx; + + /// ID of warp + int warp_idx; + + /// ID of each thread within a warp + int lane_idx; + + /// Threadblock scoped epilogue + Epilogue epilogue; + + +public: + + // + // Host dispatch API + // + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("GemmStreamkWithFusedEpilogue::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() 
% kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + +protected: + + // + // Device-only utility methods + // + + /// Iterator for fetching tile fragments from A + CUTLASS_DEVICE + typename Mma::IteratorA init_iterator_A( + TileWorkDesc &tile_work, + GemmUniversalMode mode) + { + // The input A matrix + ElementA *ptr_A = static_cast(params.ptr_A); + + // Update input pointers based on batched/array mode + if (mode == GemmUniversalMode::kBatched) { + ptr_A += tile_work.tiled_coord.k() * params.batch_stride_A; + } + if (mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[tile_work.tiled_coord.k()]; + } + + int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM; + int m_end = params.block_mapping.problem_size.m(); + return Mma::IteratorA( + params.params_A, + ptr_A, + { m_end, tile_work.k_end }, + threadIdx.x, + { m_begin, tile_work.k_begin }); + + } + + + /// Iterator for fetching tile fragments from B + CUTLASS_DEVICE + typename Mma::IteratorB init_iterator_B( + TileWorkDesc &tile_work, + GemmUniversalMode mode) + { + // The input B matrix + ElementB *ptr_B = static_cast(params.ptr_B); + + // Update input pointers based on batched/array mode + if (mode == GemmUniversalMode::kBatched) { + ptr_B += tile_work.tiled_coord.k() * params.batch_stride_B; + } + if (mode == GemmUniversalMode::kArray) { + ptr_B = static_cast(params.ptr_B)[tile_work.tiled_coord.k()]; + } + + int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN; + int n_end = params.block_mapping.problem_size.n(); + return Mma::IteratorB( + params.params_B, + ptr_B, + { tile_work.k_end, n_end }, + threadIdx.x, + { tile_work.k_begin, n_begin }); + } + + + CUTLASS_DEVICE + void init_dp_tile_work( + TileWorkDesc &tile_work, + int tile_idx) + { + // The linear tile index + tile_work.tile_idx = tile_idx; + + // The first global-scoped MAC-iteration this threadblock will perform for this tile + tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile(); + + // The number of MAC-iterations this threadblock will perform for this tile + tile_work.k_iters_remaining = params.block_mapping.iters_per_tile(); + + // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile + tile_work.k_begin = 0; + + // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile + tile_work.k_end = params.block_mapping.problem_size.k(); + + // The location of this tile (in threadblock-tile coordinates) in the output matrix + tile_work.tiled_coord = params.block_mapping.get_tile_offset(tile_work.tile_idx); + } + + + CUTLASS_DEVICE + void init_sk_tile_work( + 
TileWorkDesc &tile_work, + int tile_idx, + int block_iter_begin, + int block_iter_end) + { + // The linear tile index + tile_work.tile_idx = tile_idx; + + // The first global-scoped MAC-iteration for this tile + int tile_iter_begin = tile_idx * params.block_mapping.iters_per_tile(); + + // The first global-scoped MAC-iteration this threadblock will perform for this tile + tile_work.iter_begin = max(block_iter_begin, tile_iter_begin); + + // The first tile-scoped MAC-iteration this threadblock will perform for this tile + int k_iter_begin = tile_work.iter_begin - tile_iter_begin; + + // The last (one past) tile-scoped MAC-iteration this threadblock will perform for this tile + int k_iter_end = block_iter_end - tile_iter_begin; + + // The number of MAC-iterations this threadblock will perform for this tile + tile_work.k_iters_remaining = k_iter_end - k_iter_begin; + + // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile + tile_work.k_begin = k_iter_begin * Mma::Shape::kK; + + // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile + tile_work.k_end = min( + params.block_mapping.problem_size.k(), // extent of k domain + (k_iter_end * Mma::Shape::kK)); // extent of the threadblock's global iteration assignment + + // The location of this tile (in threadblock-tile coordinates) in the output matrix + tile_work.tiled_coord = params.block_mapping.get_tile_offset(tile_work.tile_idx); + } + + + /// Share accumulators with peers + CUTLASS_DEVICE + void share_accumulators( + AccumulatorTile const &accumulator_tile, + int block_idx, + int first_block_idx) + { + AccumulatorTile *accum_tile_workspace = reinterpret_cast(params.partials_workspace); + + int accum_tile_offset = first_block_idx * kThreadCount; + + if (block_idx == first_block_idx) + { + // First peer initializes the workspace partials + BlockStripedReduceT::store(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx); + } + else + { + // Subsequent peers atomically accumulate into the workspace partials + if (ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) + { + // Non-deterministic reduction order: wait for the first peer to have initialized the partials before we add to them + Barrier::wait_lt(params.barrier_workspace, thread_idx, first_block_idx, 1); + } + else + { + // Turnstile reduction order: wait until the previous peer has written + int wait_count = block_idx - first_block_idx; + Barrier::wait_eq(params.barrier_workspace, thread_idx, first_block_idx, wait_count); + } + + // Perform reduction in workspace + BlockStripedReduceT::reduce(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx); + } + + // Signal our arrival + Barrier::arrive_inc(params.barrier_workspace, thread_idx, first_block_idx); + } + + + /// Acquire accumulators from peers + CUTLASS_DEVICE + void acquire_accumulators( + AccumulatorTile &accumulator_tile, + int block_idx, + int first_block_idx) + { + AccumulatorTile *accum_tile_workspace = reinterpret_cast(params.partials_workspace); + + // Wait for arrival + int num_carry_in = block_idx - first_block_idx; + Barrier::wait_eq_reset(params.barrier_workspace, thread_idx, first_block_idx, num_carry_in); + + // Load and add peer-partials accumulator tile to local accumulator tile + int accum_tile_offset = first_block_idx * kThreadCount; + BlockStripedReduceT::load_add(accumulator_tile, accum_tile_workspace + accum_tile_offset, thread_idx); + } + + + /// Perform 
epilogue computations and output + CUTLASS_DEVICE + void do_epilogue( + TileWorkDesc &tile_work, + AccumulatorTile &accumulator_tile) + { + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector *ptr_Vector = + static_cast(params.ptr_Vector); + + // Update pointers for batched/array mode(s) + if (params.mode == GemmUniversalMode::kBatched) { + ptr_C += tile_work.tiled_coord.k() * params.batch_stride_C; + ptr_D += tile_work.tiled_coord.k() * params.batch_stride_D; + if (ptr_Tensor) { + ptr_Tensor += tile_work.tiled_coord.k() * params.batch_stride_Tensor; + } + if (ptr_Vector) { + ptr_Vector += tile_work.tiled_coord.k() * params.batch_stride_Vector; + } + } + if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast(params.ptr_C)[tile_work.tiled_coord.k()]; + ptr_D = static_cast(params.ptr_D)[tile_work.tiled_coord.k()]; + if (ptr_Tensor) { + ptr_Tensor = static_cast(params.ptr_Tensor)[tile_work.tiled_coord.k()]; + } + if (ptr_Vector) { + ptr_Vector = static_cast(params.ptr_Vector)[tile_work.tiled_coord.k()]; + } + } + + // Location of this tile in item-coords + MatrixCoord threadblock_item_begin( + tile_work.tiled_coord.m() * Mma::Shape::kM, + tile_work.tiled_coord.n() * Mma::Shape::kN + ); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + ptr_Tensor, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_item_begin.column() + tile_work.tiled_coord.m() * params.ldr; + } + + // Execute the epilogue operator to update the destination tensor. 
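+    // The epilogue consumes this block's fully-reduced accumulator tile: it reads the
+    // source tile through iterator_C, applies the broadcast vector at ptr_Vector
+    // (already offset to this output tile above), touches the auxiliary tensor through
+    // tensor_iterator (read as a residual or written as an intermediate, depending on
+    // the epilogue functor), and writes the result through iterator_D.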
+ epilogue( + EpilogueOutputOp(params.output_op), + ptr_Vector, + iterator_D, + accumulator_tile, + iterator_C, + tensor_iterator, + params.block_mapping.problem_size.mn(), + threadblock_item_begin); + } + + + CUTLASS_DEVICE + void separate_reduction(int reduce_idx) + { + int peer_idx_begin, peer_idx_last, reduce_tile_idx, reduce_fragment_idx; + + // Reduce by sk-tile (every tile contributed to by one or more blocks) + reduce_tile_idx = reduce_idx / Epilogue::kAccumulatorFragments; + reduce_fragment_idx = reduce_idx % Epilogue::kAccumulatorFragments; + + int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile(); + int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile() - 1; + + peer_idx_begin = params.block_mapping.get_sk_block_idx(iter_tile_first); + peer_idx_last = params.block_mapping.get_sk_block_idx(iter_tile_last); + + // Wait for peers to complete + int peer_idx_end = peer_idx_last + 1; + int num_peers = peer_idx_end - peer_idx_begin; + Barrier::wait_eq_reset( + params.barrier_workspace, + thread_idx, + (reduce_tile_idx * Epilogue::kAccumulatorFragments) + reduce_fragment_idx, + num_peers); + + /// The location of this tile (in threadblock-tile coordinates) in the output matrix + GemmCoord tiled_coord = params.block_mapping.get_tile_offset(reduce_tile_idx); + + // Location of this tile in item-coords + MatrixCoord threadblock_item_begin( + tiled_coord.m() * Mma::Shape::kM, + tiled_coord.n() * Mma::Shape::kN + ); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector *ptr_Vector = + static_cast(params.ptr_Vector); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + ptr_Tensor, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_item_begin.column() + tiled_coord.m() * params.ldr; + } + + // Execute the epilogue operator to update the destination tensor. 
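+    // Unlike do_epilogue() above, epilogue.reduce() first sums the partial accumulators
+    // that peer SK blocks [peer_idx_begin, peer_idx_end) deposited in partials_workspace
+    // for this accumulator fragment, then applies the same fused epilogue (broadcast
+    // vector + auxiliary tensor) and writes this fragment of the output tile.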
+ epilogue.reduce( + peer_idx_begin, + peer_idx_end, + reduce_fragment_idx, + params.partials_workspace, + EpilogueOutputOp(params.output_op), + ptr_Vector, + iterator_D, + iterator_C, + tensor_iterator, + params.block_mapping.problem_size.mn(), + threadblock_item_begin); + } + + + CUTLASS_DEVICE + void process_tile( + TileWorkDesc tile_work, + int block_idx, + int dp_start_block_idx, + int block_iter_begin) + { + // Initialize input iterators + typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode); + typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode); + + // Initialize accumulators + AccumulatorTile accumulator_tile; + accumulator_tile.clear(); + + // Initialize MMA abstraction + Mma mma( + shared_storage.main_loop, + thread_idx, + warp_idx, + lane_idx); + + // Perform this tile's range of multiply-accumulate (MAC) iterations + mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile); + + if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) || + (params.block_mapping.reduction_blocks == 0) || + (block_idx >= dp_start_block_idx)) + { + // + // Cooperative SK peer reduction or DP block + // + + int first_block_idx = params.block_mapping.get_first_block_idx(tile_work.tile_idx, block_idx); + + if (!tile_work.tile_finished(params)) { + // Non "finishing" SK blocks must share their partial accumulator sums through global scratch workspace + share_accumulators(accumulator_tile, block_idx, first_block_idx); + } + else + { + // DP blocks and "finishing" SK blocks must perform epilogue operations and write the output tile + if (!tile_work.tile_started()) + { + // A "finishing" SK block must first aggregate its accumulator partial sums with those shared by peer threadblocks + acquire_accumulators(accumulator_tile, block_idx, first_block_idx); + } + + do_epilogue(tile_work, accumulator_tile); + } + } + else + { + // + // Separate peer reduction + // + + // Share accumulator partial sums with peer threadblock(s) through scratch workspace + epilogue.share(block_idx, params.partials_workspace, accumulator_tile, tile_work.tile_started()); + + // Signal arrival + Barrier::arrive_range_inc( + params.barrier_workspace, + thread_idx, + tile_work.tile_idx * Epilogue::kAccumulatorFragments, + Epilogue::kAccumulatorFragments); + } + } + + + /// Executes one GEMM + CUTLASS_DEVICE + void gemm() + { + // Initialize block's iteration range + int tile_idx = 0; + int block_iter_begin = 0; + int block_iters_remaining = 0; + + int block_idx = params.block_mapping.get_block_idx(); + + int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region(); + int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms; + int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks; + int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks; + + // Initialize tile work descriptor + TileWorkDesc tile_work; + + bool dp_block = (block_idx >= dp_start_block_idx) && (block_idx < reduce_start_block_idx); + bool sk_block = (block_idx < sk_padding_start_block_idx); + bool reduce_block = (block_idx >= reduce_start_block_idx) && + (block_idx < grid_padding_start_block_idx) && + (ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed); + + if (dp_block) + { + // This is a DP block + int dp_block_idx = block_idx - dp_start_block_idx; + int first_dp_tile = 
(params.block_mapping.cohort_raster) ? 0 : params.block_mapping.sk_tiles; + + // Blocks in first DP wave get configured number of tiles + tile_idx = first_dp_tile + dp_block_idx; + int tile_allottment = params.block_mapping.dp_first_wave_tiles; + + // Blocks in subsequent DP waves get 1 tile + if (dp_block_idx >= params.block_mapping.avail_sms) { + tile_allottment = 1; + tile_idx += (params.block_mapping.dp_first_wave_tiles - 1) * params.block_mapping.avail_sms; + } + + block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment; + + init_dp_tile_work(tile_work, tile_idx); + + // DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1) + if ((tile_idx < params.block_mapping.sk_tiles) || + (tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) || + (tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n())) + { + return; + } + } + else if (sk_block) + { + // This is a SK block + int block_iter_end; + params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end); + block_iters_remaining = block_iter_end - block_iter_begin; + + tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1); + init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining); + } + else + { + if (reduce_block) + { + // This is a reduction threadblock + int reduce_block_idx = block_idx - reduce_start_block_idx; + separate_reduction(reduce_block_idx); + } + + return; + } + + // Iteration-processing loop body + CUTLASS_PRAGMA_NO_UNROLL + while (true) + { + // Perform this block's share of work for this tile + process_tile( + tile_work, + block_idx, + dp_start_block_idx, + block_iter_begin); + + block_iters_remaining -= tile_work.k_iters_remaining; + + if (block_iters_remaining == 0) + { + break; + } + + // Continue to next tile + __syncthreads(); + + if (block_idx >= dp_start_block_idx) + { + // DP block consume their tiles at stride + tile_idx += params.block_mapping.avail_sms; + init_dp_tile_work(tile_work, tile_idx); + } + else + { + // SK blocks consume their tiles in backwards order + tile_idx--; + init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining); + } + } + + } + + +public: + + // + // Device-only API + // + + // Factory invocation + CUTLASS_DEVICE + static void invoke( + Params const ¶ms, + SharedStorage &shared_storage) + { + GemmStreamkWithFusedEpilogue op(params, shared_storage); + op(); + } + + + // Constructor + CUTLASS_DEVICE + GemmStreamkWithFusedEpilogue( + Params const ¶ms, + SharedStorage &shared_storage) + : + params(params), + shared_storage(shared_storage), + thread_idx(threadIdx.x), + warp_idx(__shfl_sync(0xffffffff, threadIdx.x / 32, 0)), // broadcast the warp_id computed by lane 0 to ensure dependent code + lane_idx(threadIdx.x % 32), + epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx) + {} + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()() { + // Generic SK code path + gemm(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////
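+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// Host-side usage (sketch): this kernel is normally driven through a device-level adapter such
+// as cutlass::gemm::device::GemmUniversalStreamkWithBroadcast, whose workspace handling maps
+// onto Params::get_workspace_size() / Params::init_workspace() above. Assuming such an adapter
+// is aliased as Gemm and `args` is a populated Gemm::Arguments, a launch might look like:
+//
+//   Gemm gemm_op;
+//   size_t workspace_bytes = Gemm::get_workspace_size(args);
+//   cutlass::device_memory::allocation<uint8_t> workspace(workspace_bytes);
+//
+//   CUTLASS_CHECK(gemm_op.can_implement(args));
+//   CUTLASS_CHECK(gemm_op.initialize(args, workspace.get()));
+//   CUTLASS_CHECK(gemm_op());   // or gemm_op.run(stream)
+//
+/////////////////////////////////////////////////////////////////////////////////////////////////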