From 764b840d6fbf2840cc3ca5f36025eac2593616ae Mon Sep 17 00:00:00 2001 From: Haicheng Wu <57973641+hwu36@users.noreply.github.com> Date: Tue, 10 Jan 2023 16:10:02 -0500 Subject: [PATCH] streamk example and performance tuning (#760) * streamk example and performance tuning * one missing file Co-authored-by: Haicheng Wu --- CHANGELOG.md | 2 +- README.md | 4 +- .../CMakeLists.txt | 35 ++ .../ampere_gemm_universal_streamk.cu | 592 ++++++++++++++++++ examples/CMakeLists.txt | 1 + examples/common/helper.h | 54 ++ include/cutlass/barrier.h | 46 +- .../gemm/kernel/gemm_universal_streamk.h | 215 +++++-- .../threadblock/threadblock_swizzle_streamk.h | 383 ++++++----- tools/library/scripts/generator.py | 5 +- 10 files changed, 1071 insertions(+), 266 deletions(-) create mode 100644 examples/47_ampere_gemm_universal_streamk/CMakeLists.txt create mode 100644 examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu diff --git a/CHANGELOG.md b/CHANGELOG.md index ba03351a..cca05679 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,7 @@ # NVIDIA CUTLASS Changelog ## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19) -* Stream-K, which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one. +* [Stream-K](/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one. * [Fused multi-head attention Kernel](/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel. * [Dual GEMM](/examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. Two GEMMs has no producer-consumer dependency. * Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8. diff --git a/README.md b/README.md index 7765d40d..2a7802b8 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ supported at each level of the execution model hierarchy. # What's New in CUTLASS 2.11 CUTLASS 2.11 is an update to CUTLASS adding: -- Stream-K, which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one. +- [Stream-K](/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one. - [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one for fixed sequence lengths, and another for variable sequence lengths. - [Dual GEMM](/examples/45_dual_gemm). It can run two GEMMs that share the same left input matrix in one kernel. - Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8. @@ -115,7 +115,7 @@ any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU. 
|NVIDIA A100|8.0|11.0|11.0| |NVIDIA A10 |8.6|11.1|11.1| |NVIDIA GeForce 3090|8.6|11.1|11.1| -|NVIDIA H100 PCIe|9.0|11.8|Double-precision: 11.8| +|NVIDIA H100 PCIe|9.0|11.8|Double-precision: 11.8; Mixed precision: 12.0| # Documentation diff --git a/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt b/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt new file mode 100644 index 00000000..055f65a1 --- /dev/null +++ b/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt @@ -0,0 +1,35 @@ + +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +cutlass_example_add_executable( + 47_ampere_gemm_universal_streamk + ampere_gemm_universal_streamk.cu + ) + diff --git a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu new file mode 100644 index 00000000..717ae346 --- /dev/null +++ b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu @@ -0,0 +1,592 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*************************************************************************************************** + Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the + "classic data-parallel" and "Split-K" decompositions. + + For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition + for Dense Matrix-Matrix Multiplication on the GPU" + + Requires NVIDIA Ampere or newer device (SM80+). + + - To lock persistence mode, power (400W), clocks (1005MHz) for evaluation (assumes device 0 and A100) + + cutlass$ sudo nvidia-smi -pm 1 -i 0 + + cutlass$ sudo nvidia-smi -i 0 -pl 400 + + cutlass$ sudo nvidia-smi -i 0 -lgc 1005 + + - Build and run: + + cutlass$ mkdir build + + cutlass$ cd build + + cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80 + + cutlass/build$ make 47_ampere_gemm_universal_streamk + + cutlass/build$ ./examples/47_ampere_gemm_universal_streamk/47_ampere_gemm_universal_streamk + + 10000 timing iterations of 2048 x 2048 x 2048 matrix-matrix multiply + + Basic data-parallel GEMM + Disposition: Passed + Avg runtime: 0.112633 ms + GFLOPs: 152530 + + StreamK GEMM with default load-balancing + Disposition: Passed + Avg runtime: 0.0941929 ms + GFLOPs: 182390 + Speedup vs Basic-DP: 1.196 + + StreamK emulating basic data-parallel GEMM + Disposition: Passed + Avg runtime: 0.113119 ms + GFLOPs: 151875 + Speedup vs Basic-DP: 0.996 + + Basic split-K GEMM with tile-splitting factor 2 + Disposition: Passed + Avg runtime: 0.104772 ms + GFLOPs: 163973 + + StreamK emulating Split-K GEMM with tile-splitting factor 2 + Disposition: Passed + Avg runtime: 0.105379 ms + GFLOPs: 163029 + Speedup vs Basic-SplitK: 0.994 + + **************************************************************************************************/ + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "helper.h" + + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations (cutlass_tensorop_h16816gemm_128x128_32x4_nn_align8) +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix 
operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C/D matrices in units of elements (up to 16 bytes) + +// Multiply-accumulate blocking/pipelining details +using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape) +using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape) +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape) +constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop + +// Epilogue output operator +using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementC, // Element type for C and D matrix operands + AlignmentC, // Memory access granularity of C and D matrix in units of elements + ElementAccumulator, // Element type from internal accumaccumulation + ElementAccumulator>; // Data type used to compute linear combination + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +// Classic data-parallel device GEMM implementation type +using DeviceGemmBasic = cutlass::gemm::device::GemmUniversal< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + NumStages, + AlignmentA, + AlignmentB>; + +// StreamK device GEMM implementation type +using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversal< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, // <-- Only difference + NumStages, + AlignmentA, + AlignmentB>; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + 
cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true) + {} + +}; + + +/// Command line options parsing +struct Options +{ + std::string command_name; + bool help; + cutlass::gemm::GemmCoord problem_size; + float alpha; + float beta; + int split_k_factor; + int avail_sms; + bool reference_check; + int iterations; + + cutlass::HostTensor tensor_a; + cutlass::HostTensor tensor_b; + cutlass::HostTensor tensor_c; + cutlass::HostTensor tensor_d; + cutlass::HostTensor tensor_ref_d; + + Options(std::string command_name) : + command_name(command_name), + help(false), + problem_size({2048, 2048, 2048}), + alpha(1.0f), + beta(0.0f), + split_k_factor(1), + avail_sms(-1), // Number of device SMs to use is unlimited + reference_check(true), + iterations(10000) + {} + + bool valid() const + { + return true; + } + + void parse(int argc, char const **args) + { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("split", split_k_factor); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const + { + out + << "Performs a GEMM computation.\n" + << "\n" + << "Options:\n" + << "\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m= GEMM M dimension\n" + << " --n= GEMM N dimension\n" + << " --k= GEMM K dimension\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --split= Split-K factor to emulate\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options +typename DeviceGemmBasic::Arguments args_from_options( + const DeviceGemmBasic &device_gemm, + const Options &options, + cutlass::HostTensor &tensor_a, + cutlass::HostTensor &tensor_b, + cutlass::HostTensor &tensor_c, + cutlass::HostTensor &tensor_d) +{ + return typename DeviceGemmBasic::Arguments( + cutlass::gemm::GemmUniversalMode::kGemm, // universal mode + options.problem_size, // problem_size + options.split_k_factor, // batch count / splitk slices + { // epilogue parameters + ElementAccumulator(options.alpha), + ElementAccumulator(options.beta) + }, + tensor_a.device_data(), // ptr_A + tensor_b.device_data(), // ptr_B + tensor_c.device_data(), // ptr_C + tensor_d.device_data(), // ptr_D + options.problem_size.mk().product(), // batch_stride_A + options.problem_size.nk().product(), // batch_stride_B + options.problem_size.mn().product(), // batch_stride_C + options.problem_size.mn().product(), // 
batch_stride_D + tensor_a.layout().stride(0), // stride_a + tensor_b.layout().stride(0), // stride_b + tensor_c.layout().stride(0), // stride_c + tensor_d.layout().stride(0)); // stride_d +} + +/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options +typename DeviceGemmStreamK::Arguments args_from_options( + const DeviceGemmStreamK &device_gemm, + const Options &options, + cutlass::HostTensor &tensor_a, + cutlass::HostTensor &tensor_b, + cutlass::HostTensor &tensor_c, + cutlass::HostTensor &tensor_d) +{ + return typename DeviceGemmStreamK::Arguments( + cutlass::gemm::GemmUniversalMode::kGemm, // universal mode + options.problem_size, // problem_size + options.split_k_factor, // batch count / splitk slices + { // epilogue parameters + ElementAccumulator(options.alpha), + ElementAccumulator(options.beta) + }, + tensor_a.device_data(), // ptr_A + tensor_b.device_data(), // ptr_B + tensor_c.device_data(), // ptr_C + tensor_d.device_data(), // ptr_D + options.problem_size.mk().product(), // batch_stride_A + options.problem_size.nk().product(), // batch_stride_B + options.problem_size.mn().product(), // batch_stride_C + options.problem_size.mn().product(), // batch_stride_D + tensor_a.layout().stride(0), // stride_a + tensor_b.layout().stride(0), // stride_b + tensor_c.layout().stride(0), // stride_c + tensor_d.layout().stride(0), // stride_d + options.avail_sms); // avail_sms +} + + +/// Execute a given example GEMM computation +template +Result run(std::string description, Options &options) +{ + // Display test description + std::cout << std::endl << description << std::endl; + + // Zero-initialize test output matrix D + cutlass::reference::host::TensorFill(options.tensor_d.host_view()); + options.tensor_d.sync_device(); + + // Instantiate CUTLASS kernel depending on templates + DeviceGemmT device_gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT + auto arguments = args_from_options(device_gemm, options, options.tensor_a, options.tensor_b, options.tensor_c, options.tensor_d); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = DeviceGemmT::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + CUTLASS_CHECK(device_gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(device_gemm()); + + // Copy output data from CUTLASS and reference kernel to host for comparison + options.tensor_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = cutlass::reference::host::TensorEquals( + options.tensor_d.host_view(), + options.tensor_ref_d.host_view()); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(device_gemm()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
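+    // (Average = elapsed time / iteration count. GFLOP/s is derived from the 2*M*N*K flops
+    //  of one GEMM, i.e. one multiply and one add per MAC, via Options::gflops() above.)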
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPs: " << result.gflops << std::endl; + } + + if (!result.passed) { + exit(-1); + } + + return result; +} + + +/// Program entrypoint +int main(int argc, const char **argv) +{ + // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ >= 11)) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + // Current device must must have compute capability at least 80 + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + if (!((props.major * 10 + props.minor) >= 80)) + { + std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + // Parse commandline options + Options options("ampere_streamk_gemm"); + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + std::cout << + options.iterations << " timing iterations of " << + options.problem_size.m() << " x " << + options.problem_size.n() << " x " << + options.problem_size.k() << " matrix-matrix multiply" << std::endl; + + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + + // + // Initialize GEMM datasets + // + + // Initialize tensors using CUTLASS helper functions + options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K + options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N + options.tensor_c.resize(options.problem_size.mn()); // <- Create matrix C with dimensions M x N + options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel + options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel + + // Fill matrix A on host with uniform-random data [4, -4] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_a.host_view(), + 1, + ElementA(2), + ElementA(-2), + 0); + + // Fill matrix B on host with uniform-random data [4, -4] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_b.host_view(), + 1, + ElementB(2), + ElementB(-2), + 0); + + // Fill matrix C on host with uniform-random data [4, -4] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_c.host_view(), + 1, + ElementC(2), + ElementC(-2), + 0); + + + // + // Compute reference output + // + + // Copy data from host to GPU + options.tensor_a.sync_device(); + options.tensor_b.sync_device(); + options.tensor_c.sync_device(); + + // Zero-initialize reference output matrix D + cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view()); + options.tensor_ref_d.sync_device(); + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + 
options.problem_size, + ElementAccumulator(options.alpha), + options.tensor_a.device_ref(), + options.tensor_b.device_ref(), + ElementAccumulator(options.beta), + options.tensor_c.device_ref(), + options.tensor_ref_d.device_ref()); + + // Wait for kernels to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Copy output data from reference kernel to host for comparison + options.tensor_ref_d.sync_host(); + + + // + // Evaluate CUTLASS kernels + // + + // Test default operation + if (options.split_k_factor == 1) + { + // Compare basic data-parallel version versus StreamK version using default load-balancing heuristics + Result basic_dp = run("Basic data-parallel GEMM", options); + Result streamk_default = run("StreamK GEMM with default load-balancing", options); + + printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms)); + + // Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1 + options.avail_sms = 1; // Set loadbalancing width to 1 SM (no load balancing) + Result streamk_dp = run("StreamK emulating basic data-parallel GEMM", options); + options.avail_sms = -1; // Reset loadbalancing width to unspecified SMs (i.e., the number of device SMs) + + printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms)); + + options.split_k_factor++; // Increment splitting factor for next evaluation + + } + + // Show that StreamK can emulate "Split-K" with a tile-splitting factor + Result basic_splitk = run( + std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor), + options); + + Result streamk_splitk = run( + std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor), + options); + + printf(" Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms)); + + return 0; +} diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index a4c132c1..246b3299 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -124,6 +124,7 @@ foreach(EXAMPLE 43_ell_block_sparse_gemm 45_dual_gemm 46_depthwise_simt_conv2dfprop + 47_ampere_gemm_universal_streamk ) add_subdirectory(${EXAMPLE}) diff --git a/examples/common/helper.h b/examples/common/helper.h index 2affd96c..ba04113c 100644 --- a/examples/common/helper.h +++ b/examples/common/helper.h @@ -2,6 +2,9 @@ #include "cuda_runtime.h" +/** + * Panic wrapper for unwinding CUTLASS errors + */ #define CUTLASS_CHECK(status) \ { \ cutlass::Status error = status; \ @@ -12,6 +15,10 @@ } \ } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ #define CUDA_CHECK(status) \ { \ cudaError_t error = status; \ @@ -21,3 +28,50 @@ exit(EXIT_FAILURE); \ } \ } + + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU stream + */ +struct GpuTimer +{ + cudaStream_t _stream_id; + cudaEvent_t _start; + cudaEvent_t _stop; + + /// Constructor + GpuTimer() : _stream_id(0) + { + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); + } + + /// Destructor + ~GpuTimer() + { + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) + { + _stream_id = stream_id; + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); + } + + /// Stop the timer + void stop() + { + CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); 
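+    // (The measured interval covers all work enqueued in _stream_id between the start and
+    //  stop event records; elapsed_millis() synchronizes on _stop before querying the time.)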
+ } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() + { + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; + } +}; diff --git a/include/cutlass/barrier.h b/include/cutlass/barrier.h index dbdb9cbc..9e1a27fd 100644 --- a/include/cutlass/barrier.h +++ b/include/cutlass/barrier.h @@ -57,34 +57,25 @@ public: protected: - /// Load flag, as a strong operation (int specialization) + /// Load flag, as a strong acquire operation (int specialization) CUTLASS_DEVICE - static int ld_strong(int *ptr) + static int ld_acquire(int *ptr) { int state = 0; #if (__CUDA_ARCH__ >= 700) - /// SM70 and newer use memory consistency qualifiers - asm volatile ("ld.global.relaxed.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); + /// SM70 and newer use memory consistency qualifiers + + // Acquire pattern using acquire modifier + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); + #else - asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); + asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); #endif // (__CUDA_ARCH__ >= 700) return state; } - /// Store flag, as a strong operation (int specialization) - CUTLASS_DEVICE - static void st_strong(int *ptr, int val) - { -#if (__CUDA_ARCH__ >= 700) - /// SM70 and newer use memory consistency qualifiers - asm volatile ("st.global.relaxed.gpu.b32 [%0], %1;\n" : : "l"(ptr), "r"(val)); -#else - asm volatile ("st.cg.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val)); -#endif // (__CUDA_ARCH__ >= 700) - } - /// Reduce into flag, with release pattern (int specialization) CUTLASS_DEVICE @@ -92,11 +83,16 @@ protected: { #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) #if (__CUDA_ARCH__ >= 700) - /// SM70 and newer use memory consistency qualifiers - asm volatile ("red.release.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val)); + /// SM70 and newer use memory consistency qualifiers + + // Release pattern using acq_rel fence + relaxed modifier. 
(The fence also releases data + // that was weakly-written by other threads prior to the last syncthreads) + asm volatile ("fence.acq_rel.gpu;\n"); + asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val)); + #else - __threadfence(); - atomicAdd(ptr, val); + __threadfence(); + atomicAdd(ptr, val); #endif // (__CUDA_ARCH__ >= 700) #endif } @@ -115,7 +111,7 @@ public: { // Spin-loop #pragma unroll 1 - while(ld_strong(flag_ptr) < count) {} + while(ld_acquire(flag_ptr) < count) {} } __syncthreads(); @@ -133,9 +129,8 @@ public: { // Spin-loop #pragma unroll 1 - while(ld_strong(flag_ptr) != val) {} + while(ld_acquire(flag_ptr) != val) {} } - __syncthreads(); #endif } @@ -166,7 +161,8 @@ public: __syncthreads(); - if (thread_idx == 0) { + if (thread_idx == 0) + { red_release(flag_ptr, 1); } #endif diff --git a/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_streamk.h index a354ee01..f1934f7e 100644 --- a/include/cutlass/gemm/kernel/gemm_universal_streamk.h +++ b/include/cutlass/gemm/kernel/gemm_universal_streamk.h @@ -124,7 +124,7 @@ public: GemmUniversalMode mode; GemmCoord problem_size; - int batch_count; + int batch_count; // Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor typename EpilogueOutputOp::Params epilogue; @@ -148,7 +148,7 @@ public: typename LayoutC::Stride::LongIndex ldc; typename LayoutC::Stride::LongIndex ldd; - int sm_limit; /// Carvout override: when the above are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance + int avail_sms; /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) // @@ -159,15 +159,18 @@ public: Arguments(): mode(GemmUniversalMode::kGemm), batch_count(1), - ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), - sm_limit(-1) + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + avail_sms(-1) {} /// Constructor Arguments( GemmUniversalMode mode, GemmCoord problem_size, - int batch_count, + int batch_split, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K) typename EpilogueOutputOp::Params epilogue, void const * ptr_A, void const * ptr_B, @@ -181,15 +184,15 @@ public: typename LayoutB::Stride stride_b, typename LayoutC::Stride stride_c, typename LayoutC::Stride stride_d, - int sm_limit = -1 /// Carvout override: when the above are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance + int avail_sms = -1 /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) ): mode(mode), problem_size(problem_size), - batch_count(batch_count), + batch_count(batch_split), epilogue(epilogue), ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), - stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), sm_limit(sm_limit) + stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), avail_sms(avail_sms) { CUTLASS_TRACE_HOST("GemmUniversalStreamk::Arguments::Arguments() - problem_size: " << problem_size); } @@ -198,7 +201,7 @@ 
public: Arguments( GemmUniversalMode mode, GemmCoord problem_size, - int batch_count, + int batch_split, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K) typename EpilogueOutputOp::Params epilogue, void const * ptr_A, void const * ptr_B, @@ -212,15 +215,15 @@ public: typename LayoutB::Stride::LongIndex ldb, typename LayoutC::Stride::LongIndex ldc, typename LayoutC::Stride::LongIndex ldd, - int sm_limit = -1 /// Carvout override: when the above are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance + int avail_sms = -1 /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) ): mode(mode), problem_size(problem_size), - batch_count(batch_count), + batch_count(batch_split), epilogue(epilogue), ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), - lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), sm_limit(sm_limit) + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), avail_sms(avail_sms) { stride_a = make_Coord(lda); stride_b = make_Coord(ldb); @@ -254,29 +257,36 @@ public: // Data members // - ThreadblockSwizzle block_mapping; + void * ptr_A; + void * ptr_B; typename Mma::IteratorA::Params params_A; typename Mma::IteratorB::Params params_B; - typename Epilogue::OutputTileIterator::Params params_C; - typename Epilogue::OutputTileIterator::Params params_D; - typename EpilogueOutputOp::Params output_op; - - GemmUniversalMode mode; - - void * ptr_A; - void * ptr_B; - void * ptr_C; - void * ptr_D; int64_t batch_stride_A; int64_t batch_stride_B; - int64_t batch_stride_C; - int64_t batch_stride_D; + + GemmUniversalMode mode; + + ThreadblockSwizzle block_mapping; + + bool quick_dp; void *barrier_workspace; void *partials_workspace; + typename EpilogueOutputOp::Params output_op; + + void * ptr_D; + void * ptr_C; + + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::Params params_C; + + int64_t batch_stride_D; + int64_t batch_stride_C; + + protected: // @@ -295,7 +305,7 @@ public: { // For atomic reduction, each SK-block needs a synchronization flag. For parallel reduction, // each reduction block needs its own synchronization flag. - int sk_blocks = block_mapping.sk_regions * block_mapping.sk_blocks_per_region; + int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region(); int num_flags = fast_max(sk_blocks, block_mapping.reduction_blocks); return cacheline_align_up(sizeof(typename Barrier::T) * num_flags); @@ -304,7 +314,7 @@ public: /// Get the workspace size needed for intermediate partial sums size_t get_partials_workspace_size() const { - int sk_blocks = block_mapping.sk_regions * block_mapping.sk_blocks_per_region; + int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region(); return cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks); } @@ -343,9 +353,9 @@ public: partials_workspace(nullptr) { // Number of SMs to make available for StreamK decomposition - int avail_sms = (args.sm_limit == -1) ? + int avail_sms = (args.avail_sms == -1) ? 
device_sms : - fast_min(args.sm_limit, device_sms); + fast_min(args.avail_sms, device_sms); // Initialize the block mapping structure block_mapping = ThreadblockSwizzle( @@ -355,7 +365,15 @@ public: {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count, sm_occupancy, + device_sms, avail_sms); + + quick_dp = + (block_mapping.sk_waves == 0) && + (mode == GemmUniversalMode::kGemm) && + !block_mapping.cohort_raster && + !EpilogueOutputOp(output_op).is_source_needed(); + } @@ -426,7 +444,7 @@ public: /// Returns the GEMM volume in thread block tiles cutlass::gemm::GemmCoord get_tiled_shape() const { - return block_mapping.tiled_shape; + return block_mapping.tiled_shape(); } @@ -533,9 +551,6 @@ protected: /// ID of each thread within a warp int lane_idx; - /// Block index - int block_idx; - /// Threadblock scoped epilogue Epilogue epilogue; @@ -640,16 +655,18 @@ protected: /// Iterator for fetching tile fragments from A CUTLASS_DEVICE - typename Mma::IteratorA init_iterator_A(TileWorkDesc &tile_work) + typename Mma::IteratorA init_iterator_A( + TileWorkDesc &tile_work, + GemmUniversalMode mode) { // The input A matrix ElementA *ptr_A = static_cast(params.ptr_A); // Update input pointers based on batched/array mode - if (params.mode == GemmUniversalMode::kBatched) { + if (mode == GemmUniversalMode::kBatched) { ptr_A += tile_work.tiled_coord.k() * params.batch_stride_A; } - if (params.mode == GemmUniversalMode::kArray) { + if (mode == GemmUniversalMode::kArray) { ptr_A = static_cast(params.ptr_A)[tile_work.tiled_coord.k()]; } @@ -667,16 +684,18 @@ protected: /// Iterator for fetching tile fragments from B CUTLASS_DEVICE - typename Mma::IteratorB init_iterator_B(TileWorkDesc &tile_work) + typename Mma::IteratorB init_iterator_B( + TileWorkDesc &tile_work, + GemmUniversalMode mode) { // The input B matrix ElementB *ptr_B = static_cast(params.ptr_B); // Update input pointers based on batched/array mode - if (params.mode == GemmUniversalMode::kBatched) { + if (mode == GemmUniversalMode::kBatched) { ptr_B += tile_work.tiled_coord.k() * params.batch_stride_B; } - if (params.mode == GemmUniversalMode::kArray) { + if (mode == GemmUniversalMode::kArray) { ptr_B = static_cast(params.ptr_B)[tile_work.tiled_coord.k()]; } @@ -700,10 +719,10 @@ protected: tile_work.tile_idx = tile_idx; // The first global-scoped MAC-iteration this threadblock will perform for this tile - tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile; + tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile(); // The number of MAC-iterations this threadblock will perform for this tile - tile_work.k_iters_remaining = params.block_mapping.iters_per_tile; + tile_work.k_iters_remaining = params.block_mapping.iters_per_tile(); // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile tile_work.k_begin = 0; @@ -727,7 +746,7 @@ protected: tile_work.tile_idx = tile_idx; // The first global-scoped MAC-iteration for this tile - int tile_iter_begin = tile_idx * params.block_mapping.iters_per_tile; + int tile_iter_begin = tile_idx * params.block_mapping.iters_per_tile(); // The first global-scoped MAC-iteration this threadblock will perform for this tile tile_work.iter_begin = max(block_iter_begin, tile_iter_begin); @@ -756,7 +775,10 @@ protected: /// Share accumulators with peers CUTLASS_DEVICE - void share_accumulators(AccumulatorTile const &accumulator_tile, int first_block_idx) + void share_accumulators( + AccumulatorTile const 
&accumulator_tile, + int block_idx, + int first_block_idx) { AccumulatorTile *accum_tile_workspace = reinterpret_cast(params.partials_workspace); @@ -795,6 +817,7 @@ protected: CUTLASS_DEVICE void acquire_accumulators( AccumulatorTile &accumulator_tile, + int block_idx, int first_block_idx) { AccumulatorTile *accum_tile_workspace = reinterpret_cast(params.partials_workspace); @@ -868,8 +891,8 @@ protected: reduce_tile_idx = reduce_idx / Epilogue::kAccumulatorFragments; reduce_fragment_idx = reduce_idx % Epilogue::kAccumulatorFragments; - int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile; - int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile - 1; + int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile(); + int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile() - 1; peer_idx_begin = params.block_mapping.get_sk_block_idx(iter_tile_first); peer_idx_last = params.block_mapping.get_sk_block_idx(iter_tile_last); @@ -895,16 +918,6 @@ protected: ElementC *ptr_C = static_cast(params.ptr_C); ElementC *ptr_D = static_cast(params.ptr_D); - // Update pointers for batched/array mode(s) - if (params.mode == GemmUniversalMode::kBatched) { - ptr_C += tiled_coord.k() * params.batch_stride_C; - ptr_D += tiled_coord.k() * params.batch_stride_D; - } - if (params.mode == GemmUniversalMode::kArray) { - ptr_C = static_cast(params.ptr_C)[tiled_coord.k()]; - ptr_D = static_cast(params.ptr_D)[tiled_coord.k()]; - } - // Tile iterator loading from source tensor. typename Epilogue::OutputTileIterator iterator_C( params.params_C, @@ -936,12 +949,13 @@ protected: CUTLASS_DEVICE void process_tile( TileWorkDesc tile_work, + int block_idx, int dp_start_block_idx, int block_iter_begin) { // Initialize input iterators - typename Mma::IteratorA iterator_A = init_iterator_A(tile_work); - typename Mma::IteratorB iterator_B = init_iterator_B(tile_work); + typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode); + typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode); // Initialize accumulators AccumulatorTile accumulator_tile; @@ -968,7 +982,7 @@ protected: if (!tile_work.tile_finished(params)) { // Non "finishing" SK blocks must share their partial accumulator sums through global scratch workspace - share_accumulators(accumulator_tile, first_block_idx); + share_accumulators(accumulator_tile, block_idx, first_block_idx); } else { @@ -976,7 +990,7 @@ protected: if (!tile_work.tile_started()) { // A "finishing" SK block must first aggregate its accumulator partial sums with those shared by peer threadblocks - acquire_accumulators(accumulator_tile, first_block_idx); + acquire_accumulators(accumulator_tile, block_idx, first_block_idx); } do_epilogue(tile_work, accumulator_tile); @@ -1008,11 +1022,12 @@ protected: // Initialize block's iteration range int tile_idx, block_iter_begin, block_iters_remaining; - int sk_padding_start_block_idx = params.block_mapping.sk_regions * params.block_mapping.sk_blocks_per_region; + int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region(); int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms; int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks; int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks; + int block_idx = params.block_mapping.get_block_idx(); if (block_idx < sk_padding_start_block_idx) { // This 
is a SK block @@ -1044,8 +1059,9 @@ protected: } block_iter_begin = 0; - block_iters_remaining = params.block_mapping.iters_per_tile * tile_allottment; + block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment; } + else if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed) && (block_idx < grid_padding_start_block_idx)) { @@ -1072,8 +1088,8 @@ protected: // DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1) if ((tile_idx < params.block_mapping.sk_tiles) || - (tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape.m()) || - (tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape.n())) + (tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) || + (tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n())) { break; } @@ -1084,7 +1100,7 @@ protected: } // Perform this block's share of work for this tile - process_tile(tile_work, dp_start_block_idx, block_iter_begin); + process_tile(tile_work, block_idx, dp_start_block_idx, block_iter_begin); // Update remaining work for this block block_iters_remaining -= tile_work.k_iters_remaining; @@ -1110,6 +1126,64 @@ protected: } + + /// Executes one DP-only GEMM + CUTLASS_DEVICE + void gemm_dp() + { + int block_idx = blockIdx.x; + int tile_idx = block_idx; + + TileWorkDesc tile_work; + tile_work.tile_idx = tile_idx; + tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile(); + tile_work.k_iters_remaining = params.block_mapping.iters_per_tile(); + tile_work.k_begin = 0; + tile_work.k_end = params.block_mapping.problem_size.k(); + tile_work.tiled_coord = params.block_mapping.get_tile_offset_row_major(tile_work.tile_idx); + + // Initialize input iterators + typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode); + typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode); + + // Initialize accumulators + AccumulatorTile accumulator_tile; + accumulator_tile.clear(); + + // Perform this tile's range of multiply-accumulate (MAC) iterations + Mma mma( + shared_storage.main_loop, + thread_idx, + warp_idx, + lane_idx); + + mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile); + + ElementC *ptr_D = static_cast(params.ptr_D); + + // Location of this tile in item-coords + MatrixCoord threadblock_item_begin( + tile_work.tiled_coord.m() * Mma::Shape::kM, + tile_work.tiled_coord.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Execute the epilogue operator to update the destination tensor. + epilogue( + EpilogueOutputOp(params.output_op), + iterator_D, + accumulator_tile); + } + + + public: // @@ -1138,7 +1212,6 @@ public: thread_idx(threadIdx.x), warp_idx(__shfl_sync(0xffffffff, threadIdx.x / 32, 0)), // broadcast the warp_id computed by lane 0 to ensure dependent code lane_idx(threadIdx.x % 32), - block_idx(params.block_mapping.get_block_idx()), epilogue( shared_storage.epilogue, thread_idx, @@ -1151,7 +1224,17 @@ public: CUTLASS_DEVICE void operator()() { - // Do the GEMM +#if (__CUDACC_VER_MAJOR__ > 10) + if (params.quick_dp) + { + // Simple (low-bootstrap latency) GEMM code path for data-parallel only. 
(kBatched and kArray + // modes will only be launched using a data-parallel configurations) + gemm_dp(); + return; + } +#endif + + // Generic SK code path gemm(); } diff --git a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h index 499157ea..f0ef1b06 100644 --- a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h +++ b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h @@ -115,28 +115,21 @@ struct ThreadblockSwizzleStreamK { // Member state // + /// The 3D value-extents of the GEMM computation volume (m,n,k) GemmCoord problem_size; - /// The 2D tile-extents of the output matrix (m,n) - GemmCoord tiled_shape; + /// Div/mod accelerators + FastDivmod div_mod_tiled_shape_m; + FastDivmod div_mod_tiled_shape_n; + FastDivmod div_mod_tiled_cohort_shape_n; + FastDivmod div_mod_iters_per_tile; - /// Number of iterations per output tile - int iters_per_tile; + /// Whether to perform cohort CTA rasterization + bool cohort_raster; - /// Number of reduction blocks in the grid - int reduction_blocks; - - int dp_blocks; /// Number of data-parallel thread blocks in the grid - int dp_first_wave_tiles; /// Number of output tiles each CTA in the first DP wave will produce - - int sk_tiles; - int sk_regions; - int sk_blocks_per_region; - int sk_big_blocks_per_region; - int sk_iters_per_region; - int sk_iters_per_normal_block; /// Number of iterations for normal SK-blocks - int sk_waves; /// Number of SK waves in the grid + // Whether to pad and remap block indices + bool remap_block_indices; /// CTA occupancy per SM int sm_occupancy; @@ -144,21 +137,26 @@ struct ThreadblockSwizzleStreamK { /// Number of SMs for dispatch heuristics to load-balance using Stream-K CTAs (wave size) int avail_sms; - /// Whether to perform cohort CTA rasterization - bool cohort_raster; + int dp_blocks; /// Number of data-parallel thread blocks in the grid + int dp_first_wave_tiles; /// Number of output tiles each CTA in the first DP wave will produce + + /// Number of reduction blocks in the grid + int reduction_blocks; + + int sk_waves; + int sk_tiles; + int sk_big_blocks_per_region; + int sk_iters_per_region; /// Div/mod accelerators - struct - { - FastDivmod tiled_shape_m; - FastDivmod tiled_shape_n; - FastDivmod tiled_cohort_shape_n; - FastDivmod iters_per_tile; - FastDivmod sk_iters_per_normal_block; - FastDivmod sk_iters_per_big_block; - FastDivmod sk_iters_per_region; - FastDivmod sk_blocks_per_region; - } div_mod; + FastDivmod div_mod_sk_iters_per_normal_block; + FastDivmod div_mod_sk_iters_per_big_block; + FastDivmod div_mod_sk_iters_per_region; + FastDivmod div_mod_sk_regions; //!! used in block map + FastDivmod div_mod_sk_blocks_per_region; //!! 
used in block map + + /// The batch count + int batch_count; // @@ -169,6 +167,43 @@ struct ThreadblockSwizzleStreamK { CUTLASS_HOST_DEVICE ThreadblockSwizzleStreamK() {} + /// Returns the GEMM volume in thread block tiles + CUTLASS_HOST_DEVICE + GemmCoord tiled_shape() const + { + return GemmCoord( + static_cast(div_mod_tiled_shape_m), + static_cast(div_mod_tiled_shape_n), + batch_count); + } + + /// Number of iterations per output tile + CUTLASS_HOST_DEVICE + int iters_per_tile() const + { + return static_cast(div_mod_iters_per_tile); + } + + /// Number of iterations for normal SK-blocks + CUTLASS_HOST_DEVICE + int sk_iters_per_normal_block() const + { + return static_cast(div_mod_sk_iters_per_normal_block); + } + + /// Number of SK regions + CUTLASS_HOST_DEVICE + int sk_regions() const + { + return static_cast(div_mod_sk_regions); + } + + /// Number of SK blocks per region (splitting factor) + CUTLASS_HOST_DEVICE + int sk_blocks_per_region() const + { + return static_cast(div_mod_sk_blocks_per_region); + } // @@ -179,26 +214,27 @@ struct ThreadblockSwizzleStreamK { void Print() { #ifndef __CUDA_ARCH__ - int tiles = tiled_shape.m() * tiled_shape.n(); + auto tiles = tiled_shape().mn().product(); std::cout << "problem_size: (" << problem_size.m() << "," << problem_size.n() << ")" << - ", reduction_blocks: " << reduction_blocks << - ", dp_blocks: " << dp_blocks << - ", sk_blocks_per_region: " << sk_blocks_per_region << - ", sk_regions: " << sk_regions << - ", sk_waves: " << sk_waves << - ", sk_iters_per_normal_block: " << sk_iters_per_normal_block << - ", sk_big_blocks_per_region: " << sk_big_blocks_per_region << - ", dp_first_wave_tiles: " << dp_first_wave_tiles << - ", tiled_shape: (" << tiled_shape.m() << "," << tiled_shape.n() << ")" << + ", tiled_shape: (" << tiled_shape().m() << "," << tiled_shape().n() << ")" << ", tiles: " << tiles << - ", iters_per_tile: " << iters_per_tile << ", dp_tiles: " << tiles - sk_tiles << ", sk_tiles: " << sk_tiles << - ", avail_sms: " << avail_sms << + ", iters_per_tile: " << iters_per_tile() << + ", reduction_blocks: " << reduction_blocks << + ", dp_blocks: " << dp_blocks << + ", dp_waves: " << dp_blocks / avail_sms << + ", dp_first_wave_tiles: " << dp_first_wave_tiles << + ", sk_blocks_per_region: " << sk_blocks_per_region() << + ", sk_regions: " << sk_regions() << + ", sk_waves: " << sk_waves << + ", sk_iters_per_normal_block: " << sk_iters_per_normal_block() << + ", sk_big_blocks_per_region: " << sk_big_blocks_per_region << + ", remap_block_indices: " << remap_block_indices << + ", cohort_raster: " << cohort_raster << ", sm_occupancy: " << sm_occupancy << ", avail_sms: " << avail_sms << - ", cohort_raster: " << cohort_raster << ", num_blocks: " << get_num_blocks() << "\n\n"; #endif @@ -368,30 +404,37 @@ struct ThreadblockSwizzleStreamK { GemmUniversalMode const mode_, GemmCoord const problem_size_, GemmCoord const tile_size_, - int const batch_count_, /// Batch count (when mode_ == GemmUniversalMode::kBatched) or split-K-override splitting factor (when mode_ == GemmUniversalMode::kGemm) + int const batch_split_, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K) int const sm_occupancy_, - int const avail_sms_) + int const device_sms_, + int const avail_sms_) /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) : 
problem_size(problem_size_), - tiled_shape( - (problem_size.m() + tile_size_.m() - 1) / tile_size_.m(), - (problem_size.n() + tile_size_.n() - 1) / tile_size_.n(), - (mode_ == GemmUniversalMode::kBatched) ? batch_count_ : 1), - iters_per_tile((problem_size.k() + tile_size_.k() - 1) / tile_size_.k()), + batch_count((mode_ == GemmUniversalMode::kBatched) ? batch_split_ : 1), reduction_blocks(0), dp_blocks(0), dp_first_wave_tiles(1), // Default: one tile per DP-block in the first wave of DP blocks sk_tiles(0), - sk_regions(1), // Default: a single region of iteration space (across all SK tiles) - sk_blocks_per_region(0), sk_big_blocks_per_region(0), sk_iters_per_region(0), - sk_iters_per_normal_block(0), sk_waves(0), sm_occupancy(sm_occupancy_), + remap_block_indices(false), avail_sms(fast_max(1, avail_sms_)), cohort_raster(false) { + int gpu_occupancy = device_sms_ * sm_occupancy; + int iters_per_tile = (problem_size.k() + tile_size_.k() - 1) / tile_size_.k(); + int sk_iters_per_normal_block = 0; + + int sk_regions = 1; // Default: a single region of iteration space (across all SK tiles) + int sk_blocks_per_region = 0; + + GemmCoord tiled_shape( + (problem_size.m() + tile_size_.m() - 1) / tile_size_.m(), + (problem_size.n() + tile_size_.n() - 1) / tile_size_.n(), + batch_count); + size_t problem_bytes = (sizeof(typename GemmKernel::ElementC) * problem_size.m() * problem_size.n()) + (sizeof(typename GemmKernel::ElementA) * problem_size.m() * problem_size.k()) + @@ -401,7 +444,6 @@ struct ThreadblockSwizzleStreamK { float flops_per_byte = float(problem_flops) / float(problem_bytes); - int gpu_occupancy = avail_sms * sm_occupancy; int output_tiles = tiled_shape.m() * tiled_shape.n(); int waves = (output_tiles + avail_sms - 1) / avail_sms; float dp_efficiency = float(output_tiles) / float(waves * avail_sms); @@ -414,14 +456,15 @@ struct ThreadblockSwizzleStreamK { int dp_tiles = output_tiles; // Number of data-parallel tiles int sk_blocks = 0; // Number of thread blocks to produce the remaining SK tiles - // kGemm mode allows for SK load balancing + // Only kGemm mode allows for SK load balancing if (mode_ == GemmUniversalMode::kGemm) { - if (batch_count_ > 1) + int split_factor = batch_split_; + if (split_factor > 1) { // Split-K override dp_tiles = 0; - sk_blocks = output_tiles * batch_count_; + sk_blocks = output_tiles * split_factor; } else if ((kReductionStrategy != kNone) && // Load-balancing strategy statically enabled (avail_sms > 1)) // Plurality of SMs to load balance across @@ -462,24 +505,39 @@ struct ThreadblockSwizzleStreamK { sk_big_blocks_per_region = sk_big_blocks / sk_regions; sk_iters_per_region = sk_iters / sk_regions; - div_mod.sk_iters_per_normal_block = FastDivmod(sk_iters_per_normal_block); - div_mod.sk_iters_per_big_block = FastDivmod(sk_iters_per_normal_block + 1); - div_mod.sk_iters_per_region = FastDivmod(sk_iters_per_region); - div_mod.sk_blocks_per_region = FastDivmod(sk_blocks_per_region); - - // Separate reduction heuristic + // Use a separate reduction wave when all of: + // - Non-atomic reduction stratgy + // - The number of SK waves won't fully occupy the GPU (Otherwise we don't have + // a strong-scaling case for more parallel reduction) + // - More than three peers working on an SK tile. (This occurs when the ratio of + // SK-blocks to SK-tiles > 2, as a single tile may be covered by four SK-blocks, + // e.g.:[partial-block | block | block | partial-block] ). With three or + // less peers, the two non-finishing SK-blocks are not expexted to contend. 
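+    // (Illustrative example: 10 SK-tiles covered by 24 SK-blocks gives a blocks-to-tiles
+    //  ratio above 2, so some tile may have four peer contributors; if sk_waves is also
+    //  still below sm_occupancy, the separate reduction wave below is preferred.)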
if ((kReductionStrategy == kMixed) && - (sk_blocks > 2 * sk_tiles)) // Use a separate reduction wave whenever we would have more than three - // peers working on an SK tile. (This occurs when the ratio of SK-blocks - // to SK-tiles > 2, as a single tile may be covered by four SK-blocks, - // e.g.:[partial-block | block | block | partial-block] ). With three or - // less peers, the two non-finishing SK-blocks are not expexted to contend. + (sk_waves < sm_occupancy) && + (sk_blocks > 2 * sk_tiles)) { - // Launch a reduction block every accumulator fragment in each SK-tile + // Launch a reduction block for every accumulator fragment in each SK-tile static const int kAccumulatorFragments = GemmKernel::Epilogue::kAccumulatorFragments; reduction_blocks = sk_tiles * kAccumulatorFragments; } + + // When we have a multi-occupancy kernel and at least two waves of active blocks (where + // at least one wave is SK blocks), we need to (1) dispatch at least four waves, and (2) + // remap the block indices so that we can reliably spread the SK blocks evenly across the + // device's first SM occupancy valence. Also see get_num_blocks() and get_block_idx(). + remap_block_indices = ( + (sm_occupancy > 1) && + (device_sms_ == avail_sms) && + (get_num_active_blocks() > avail_sms * 2)); + + // Initialize fast div/mod members related to SK + div_mod_sk_iters_per_normal_block = FastDivmod(sk_iters_per_normal_block); + div_mod_sk_iters_per_big_block = FastDivmod(sk_iters_per_normal_block + 1); + div_mod_sk_iters_per_region = FastDivmod(sk_iters_per_region); + div_mod_sk_regions = FastDivmod(sk_regions); + div_mod_sk_blocks_per_region = FastDivmod(sk_blocks_per_region); } // @@ -491,7 +549,7 @@ struct ThreadblockSwizzleStreamK { cutlass::gemm::GemmCoord tiled_cohort_shape( (tiled_shape.m() + kCohortCtasM - 1) / kCohortCtasM, (tiled_shape.n() + kCohortCtasN - 1) / kCohortCtasN, - batch_count_); + tiled_shape.k()); int cohort_blocks = (tiled_cohort_shape.m() * tiled_cohort_shape.n()) * kCtasPerCohort; float cohort_efficiency = float(dp_blocks) / float(cohort_blocks); @@ -511,11 +569,12 @@ struct ThreadblockSwizzleStreamK { { sk_in_range = false; } + } // Decide if we're going to be doing cohort raster if (sk_in_range && - (dp_blocks >= gpu_occupancy) && + (dp_blocks >= gpu_occupancy * 2) && (cohort_efficiency > 0.85f)) { cohort_raster = true; @@ -537,92 +596,54 @@ struct ThreadblockSwizzleStreamK { } // Setup fast-div/mod for device-side usage - div_mod.tiled_shape_m = FastDivmod(tiled_shape.m()); - div_mod.tiled_shape_n = FastDivmod(tiled_shape.n()); - div_mod.tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n()); - div_mod.iters_per_tile = FastDivmod(iters_per_tile); + div_mod_tiled_shape_m = FastDivmod(tiled_shape.m()); + div_mod_tiled_shape_n = FastDivmod(tiled_shape.n()); + div_mod_tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n()); + div_mod_iters_per_tile = FastDivmod(iters_per_tile); + } - - /// Constructor: *ImplicitGemm* Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) - template <typename KernelTraits> - ThreadblockSwizzleStreamK( - KernelTraits kernel_traits_, - GemmUniversalMode mode_, - cutlass::conv::Operator conv_operator, - cutlass::conv::Conv2dProblemSize const &problem_size_, - GemmCoord tile_size_, - int batch_count_, - int sm_occupancy_, - int avail_sms_, /// When the below are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance - int dp_tiles_ = -1, /// Dispatch override: number of output tiles to assign to independent, data-parallel CTAs - int sk_blocks_ = -1) ///
Dispatch override: number of Stream-K CTAs for cooperatively processing the remaining output tiles - : - ThreadblockSwizzleStreamK( - kernel_traits_, - mode_, - cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size_), - tile_size_, - batch_count_, - sm_occupancy_, - avail_sms_, - dp_tiles_, - sk_blocks_) - {} - - - /// Constructor: *ImplicitGemm* Conv3d problem size: conv_operator(NZPQK, NDHWC, KTRSC) - template <typename KernelTraits> - ThreadblockSwizzleStreamK( - KernelTraits kernel_traits_, - GemmUniversalMode mode_, - cutlass::conv::Operator conv_operator, - cutlass::conv::Conv3dProblemSize const &problem_size_, - GemmCoord tile_size_, - int batch_count_, - int sm_occupancy_, - int avail_sms_, /// When the below are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance - int dp_tiles_ = -1, /// Dispatch override: number of output tiles to assign to independent, data-parallel CTAs - int sk_blocks_ = -1) /// Dispatch override: number of Stream-K CTAs for cooperatively processing the remaining output tiles - : - ThreadblockSwizzleStreamK( - kernel_traits_, - mode_, - cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size_), - tile_size_, - batch_count_, - sm_occupancy_, - avail_sms_, - dp_tiles_, - sk_blocks_) - {} - + /// Number of blocks performing useful work + int get_num_active_blocks() const + { + return (sk_waves * avail_sms) + dp_blocks + reduction_blocks; + } /// Obtains number of threadblocks per GEMM int get_num_blocks() const { - int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks; - - if (work_blocks <= avail_sms * 2) + int active_blocks = get_num_active_blocks(); + if (remap_block_indices) { - return work_blocks; + // Add padding blocks if we are performing remapping in order to dispatch a grid of at least four waves + return fast_max(active_blocks, avail_sms * 4); } - return fast_max(work_blocks, avail_sms * 4); + return active_blocks; } /// Obtains grid extents in CTAs dim3 get_grid_dims() const { - return dim3(get_num_blocks(), 1, tiled_shape.k()); + return dim3(get_num_blocks(), 1, batch_count); } +// Guards needed for PyCUTLASS library generation +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) + // // Device-side interface // + /// Proves to the compiler that val is warp-uniform + CUTLASS_DEVICE + int uniform(int val) const + { + return __shfl_sync(0xffffffff, val, 0); + } + /// Obtains number of threadblocks per GEMM CUTLASS_DEVICE int device_num_blocks() const @@ -634,9 +655,16 @@ struct ThreadblockSwizzleStreamK { CUTLASS_DEVICE int get_sk_tile_idx(int iter) const { - return div_mod.iters_per_tile.div(iter); + int tile_idx = div_mod_iters_per_tile.div(iter); + return uniform(tile_idx); } + /// Obtains the batch index + CUTLASS_DEVICE + int get_batch_idx() const + { + return RematerializeBlockIdxZ(); + } /// Obtains the calling threadblock's tiled coordinates for the given tile index CUTLASS_DEVICE GemmCoord get_tile_offset(int tile_idx) const { int m, n; + // row-major raster + div_mod_tiled_shape_n(m, n, tile_idx); + + if (tiled_shape().m() < tiled_shape().n()) + { + // column-major raster + div_mod_tiled_shape_m(n, m, tile_idx); + } + if (cohort_raster) { // tiled cohort raster int cohort_tile_idx = tile_idx / kCtasPerCohort; int cohort_grid_m, cohort_grid_n; - div_mod.tiled_cohort_shape_n(cohort_grid_m, cohort_grid_n, cohort_tile_idx); + div_mod_tiled_cohort_shape_n(cohort_grid_m, cohort_grid_n, cohort_tile_idx); int block_idx_cohort = tile_idx %
kCtasPerCohort; int block_cohort_m = block_idx_cohort / kCohortCtasN; @@ -658,44 +695,46 @@ struct ThreadblockSwizzleStreamK { m = (cohort_grid_m * kCohortCtasM) + block_cohort_m; n = (cohort_grid_n * kCohortCtasN) + block_cohort_n; } - else if (tiled_shape.m() < tiled_shape.n()) - { - // column-major raster - div_mod.tiled_shape_m(n, m, tile_idx); - } - else - { - // row-major raster - div_mod.tiled_shape_n(m, n, tile_idx); - } - int block_idx_k = RematerializeBlockIdxZ(); - return GemmCoord{m, n, block_idx_k}; + return GemmCoord(m, n, get_batch_idx()); } + /// Obtains the calling threadblock's tiled coordinates for the given tile index (row-major rasterization) + CUTLASS_DEVICE + GemmCoord get_tile_offset_row_major(int tile_idx) const + { + // row-major raster + int m, n; + div_mod_tiled_shape_n(m, n, tile_idx); + return GemmCoord(m, n, get_batch_idx()); + } /// Obtains calling threadblock's linear threadblock index CUTLASS_DEVICE int get_block_idx() const { + int block_idx = RematerializeBlockIdxX(); + // Remap the block indices for the first two waves of thread blocks if // we have multi-occupancy and the grid constitutes four or more waves - - int block_idx = RematerializeBlockIdxX(); - int num_blocks = device_num_blocks(); - int dest_sm = block_idx / 2; - int dest_wave = block_idx % 2; - int remapped_block_idx = dest_sm + (dest_wave * avail_sms); - - if ((sm_occupancy > 1) && - (num_blocks >= avail_sms * 4) && - (block_idx < avail_sms * 2)) + if (remap_block_indices && (block_idx < avail_sms * 2)) { + int dest_sm = block_idx / 2; + int dest_wave = block_idx % 2; + int remapped_block_idx = dest_sm + (dest_wave * avail_sms); block_idx = remapped_block_idx; } - // Block-index is blockIdx.x for DP blocks - return block_idx; + // Remap block indices to interleave SK regions to limit intra-region waiting + if (block_idx < sk_regions() * sk_blocks_per_region()) + { + int block_in_region; + int region; + div_mod_sk_regions(block_in_region, region, block_idx); + block_idx = (region * sk_blocks_per_region()) + block_in_region; + } + + return uniform(block_idx); } @@ -705,19 +744,21 @@ struct ThreadblockSwizzleStreamK { { int region_idx; int iter_in_region; - div_mod.sk_iters_per_region(region_idx, iter_in_region, iter); + div_mod_sk_iters_per_region(region_idx, iter_in_region, iter); - int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block) + sk_big_blocks_per_region; // number of iterations in the region's big blocks + int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block()) + sk_big_blocks_per_region; // number of iterations in the region's big blocks int normal_block_iters = iter_in_region - big_block_iters; // number of iterations in the region's normal bocks - int big_block_idx_in_region = div_mod.sk_iters_per_big_block.div(iter_in_region); - int normal_block_idx_in_region = sk_big_blocks_per_region + div_mod.sk_iters_per_normal_block.div(normal_block_iters); + int big_block_idx_in_region = div_mod_sk_iters_per_big_block.div(iter_in_region); + int normal_block_idx_in_region = sk_big_blocks_per_region + div_mod_sk_iters_per_normal_block.div(normal_block_iters); int block_idx_in_region = (big_block_idx_in_region < sk_big_blocks_per_region) ?
big_block_idx_in_region : normal_block_idx_in_region; - return (sk_blocks_per_region * region_idx) + block_idx_in_region; + int owning_block_idx = (sk_blocks_per_region() * region_idx) + block_idx_in_region; + + return owning_block_idx; } /// Obtains iteration extends for the given SK block index @@ -729,12 +770,12 @@ struct ThreadblockSwizzleStreamK { { int region_idx; int block_idx_in_region; - div_mod.sk_blocks_per_region(region_idx, block_idx_in_region, sk_block_idx); + div_mod_sk_blocks_per_region(region_idx, block_idx_in_region, sk_block_idx); - block_iter_begin = (region_idx * sk_iters_per_region) + (block_idx_in_region * sk_iters_per_normal_block); + block_iter_begin = (region_idx * sk_iters_per_region) + (block_idx_in_region * sk_iters_per_normal_block()); // Adjust extents for the first "num_big_blocks" blocks that get one extra iteration - int block_iters = sk_iters_per_normal_block; + int block_iters = sk_iters_per_normal_block(); if (block_idx_in_region < sk_big_blocks_per_region) { // This is a +1 iteration block block_iter_begin += block_idx_in_region; @@ -756,10 +797,12 @@ struct ThreadblockSwizzleStreamK { return block_idx; } - int iter = tile_idx * iters_per_tile; + int iter = tile_idx * iters_per_tile(); return get_sk_block_idx(iter); } +#endif // defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) + }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 87777f5b..0aae8fce 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -50,9 +50,9 @@ def EpilogueAlignment(max_alignment, tile, epilogue_steps = 8): # def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ - swizzling_functor = SwizzlingFunctor.Identity8): +# swizzling_functor = SwizzlingFunctor.Identity8): # Use StreamK decomposition for basic GEMMs -# swizzling_functor = SwizzlingFunctor.StreamK): + swizzling_functor = SwizzlingFunctor.StreamK): if complex_transforms is None: complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] @@ -4600,6 +4600,7 @@ if __name__ == "__main__": parser.add_argument('--selected-kernel-list', type=str, default=None, required=False, help='Specify the output log file containing all enabled kernels in this build') parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels") + parser.add_argument("--disable-full-archs-compilation", action="store_true", required=False, help="Disable compilation for every arch in --architectures") args = parser.parse_args()
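
For reference, the two-wave block-index remapping that get_block_idx() applies when remap_block_indices is set can be illustrated with a small standalone sketch. The helper name remap_first_two_waves and the values in main() below are illustrative assumptions, not part of the patch; only the arithmetic mirrors the remapping code above.

// Standalone sketch (hypothetical helper): the remapping applied to the first
// two waves of physical block indices. Physical blocks 2*s and 2*s+1 take over
// the work of logical blocks s and s + avail_sms, so each SM in the first
// occupancy valence receives one block from each of the first two waves.
#include <cstdio>

int remap_first_two_waves(int block_idx, int avail_sms)
{
  if (block_idx < avail_sms * 2)
  {
    int dest_sm   = block_idx / 2;
    int dest_wave = block_idx % 2;
    return dest_sm + (dest_wave * avail_sms);
  }
  return block_idx;
}

int main()
{
  int avail_sms = 4;  // hypothetical device width, for illustration only
  for (int block_idx = 0; block_idx < avail_sms * 2; ++block_idx)
  {
    printf("physical block %d -> logical block %d\n",
           block_idx, remap_first_two_waves(block_idx, avail_sms));
  }
  return 0;  // prints 0->0, 1->4, 2->1, 3->5, 4->2, 5->6, 6->3, 7->7
}

Combined with the padding in get_num_blocks() (a grid of at least four waves whenever remapping is enabled), this is what lets the dispatch spread the SK blocks evenly across the device's first SM occupancy valence, as described in the constructor comment above.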