streamk example and performance tuning (#760)
* streamk example and performance tuning
* one missing file

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>

parent a1046d49c1
commit 764b840d6f
@@ -1,7 +1,7 @@
 # NVIDIA CUTLASS Changelog

 ## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19)
-* Stream-K, which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
+* [Stream-K](/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
 * [Fused multi-head attention Kernel](/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel.
 * [Dual GEMM](/examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. The two GEMMs have no producer-consumer dependency.
 * Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
@@ -39,7 +39,7 @@ supported at each level of the execution model hierarchy.
 # What's New in CUTLASS 2.11

 CUTLASS 2.11 is an update to CUTLASS adding:
-- Stream-K, which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
+- [Stream-K](/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
 - [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one for fixed sequence lengths, and another for variable sequence lengths.
 - [Dual GEMM](/examples/45_dual_gemm). It can run two GEMMs that share the same left input matrix in one kernel.
 - Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
@@ -115,7 +115,7 @@ any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
 |NVIDIA A100|8.0|11.0|11.0|
 |NVIDIA A10 |8.6|11.1|11.1|
 |NVIDIA GeForce 3090|8.6|11.1|11.1|
-|NVIDIA H100 PCIe|9.0|11.8|Double-precision: 11.8|
+|NVIDIA H100 PCIe|9.0|11.8|Double-precision: 11.8; Mixed precision: 12.0|

 # Documentation

examples/47_ampere_gemm_universal_streamk/CMakeLists.txt (new file, 35 lines)
@@ -0,0 +1,35 @@
# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
  47_ampere_gemm_universal_streamk
  ampere_gemm_universal_streamk.cu
  )
examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu (new file, 592 lines)
@@ -0,0 +1,592 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/***************************************************************************************************
  Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the
  "classic data-parallel" and "Split-K" decompositions.

  For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition
  for Dense Matrix-Matrix Multiplication on the GPU" <todo: link>

  Requires NVIDIA Ampere or newer device (SM80+).

  - To lock persistence mode, power (400W), clocks (1005MHz) for evaluation (assumes device 0 and A100)

      cutlass$ sudo nvidia-smi -pm 1 -i 0
      cutlass$ sudo nvidia-smi -i 0 -pl 400
      cutlass$ sudo nvidia-smi -i 0 -lgc 1005

  - Build and run:

      cutlass$ mkdir build
      cutlass$ cd build
      cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80
      cutlass/build$ make 47_ampere_gemm_universal_streamk
      cutlass/build$ ./examples/47_ampere_gemm_universal_streamk/47_ampere_gemm_universal_streamk

      10000 timing iterations of 2048 x 2048 x 2048 matrix-matrix multiply

      Basic data-parallel GEMM
        Disposition: Passed
        Avg runtime: 0.112633 ms
        GFLOPs: 152530

      StreamK GEMM with default load-balancing
        Disposition: Passed
        Avg runtime: 0.0941929 ms
        GFLOPs: 182390
        Speedup vs Basic-DP: 1.196

      StreamK emulating basic data-parallel GEMM
        Disposition: Passed
        Avg runtime: 0.113119 ms
        GFLOPs: 151875
        Speedup vs Basic-DP: 0.996

      Basic split-K GEMM with tile-splitting factor 2
        Disposition: Passed
        Avg runtime: 0.104772 ms
        GFLOPs: 163973

      StreamK emulating Split-K GEMM with tile-splitting factor 2
        Disposition: Passed
        Avg runtime: 0.105379 ms
        GFLOPs: 163029
        Speedup vs Basic-SplitK: 0.994

**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations (cutlass_tensorop_h16816gemm_128x128_32x4_nn_align8)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::half_t; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::half_t; // Element type for B matrix operand
|
||||
using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C/D matrices in units of elements (up to 16 bytes)
|
||||
|
||||
// Multiply-accumulate blocking/pipelining details
|
||||
using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape)
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape)
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape)
|
||||
constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop
|
||||
|
||||
// Epilogue output operator
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementC, // Element type for C and D matrix operands
|
||||
AlignmentC, // Memory access granularity of C and D matrix in units of elements
|
||||
ElementAccumulator, // Element type from internal accumulation
|
||||
ElementAccumulator>; // Data type used to compute linear combination
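// For reference: this epilogue computes each output tile as D = alpha * (A x B) + beta * C,
// where alpha and beta are the scalars supplied via the --alpha and --beta options below, and
// ElementAccumulator is the type used both for the multiply-accumulates and for the linear
// combination itself.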
|
||||
|
||||
// Reference device GEMM implementation type
|
||||
using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
// Classic data-parallel device GEMM implementation type
|
||||
using DeviceGemmBasic = cutlass::gemm::device::GemmUniversal<
|
||||
ElementA, LayoutA,
|
||||
ElementB, LayoutB,
|
||||
ElementC, LayoutC,
|
||||
ElementAccumulator,
|
||||
OperatorClass,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
NumStages,
|
||||
AlignmentA,
|
||||
AlignmentB>;
|
||||
|
||||
// StreamK device GEMM implementation type
|
||||
using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversal<
|
||||
ElementA, LayoutA,
|
||||
ElementB, LayoutB,
|
||||
ElementC, LayoutC,
|
||||
ElementAccumulator,
|
||||
OperatorClass,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
EpilogueOp,
|
||||
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, // <-- Only difference
|
||||
NumStages,
|
||||
AlignmentA,
|
||||
AlignmentB>;
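// Note: relative to DeviceGemmBasic above, only the threadblock swizzle functor changes. How the
// work is decomposed is then selected at run time through the kernel's Arguments: a tile-splitting
// factor of 1 gives the default StreamK load-balancing, a factor greater than 1 emulates Split-K,
// and avail_sms == 1 falls back to classic data-parallel scheduling, as exercised in main() below.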
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Result structure
|
||||
struct Result
|
||||
{
|
||||
double avg_runtime_ms;
|
||||
double gflops;
|
||||
cutlass::Status status;
|
||||
cudaError_t error;
|
||||
bool passed;
|
||||
|
||||
Result(
|
||||
double avg_runtime_ms = 0,
|
||||
double gflops = 0,
|
||||
cutlass::Status status = cutlass::Status::kSuccess,
|
||||
cudaError_t error = cudaSuccess)
|
||||
:
|
||||
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
|
||||
/// Command line options parsing
|
||||
struct Options
|
||||
{
|
||||
std::string command_name;
|
||||
bool help;
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
float alpha;
|
||||
float beta;
|
||||
int split_k_factor;
|
||||
int avail_sms;
|
||||
bool reference_check;
|
||||
int iterations;
|
||||
|
||||
cutlass::HostTensor<ElementA, LayoutA> tensor_a;
|
||||
cutlass::HostTensor<ElementB, LayoutB> tensor_b;
|
||||
cutlass::HostTensor<ElementC, LayoutC> tensor_c;
|
||||
cutlass::HostTensor<ElementC, LayoutC> tensor_d;
|
||||
cutlass::HostTensor<ElementC, LayoutC> tensor_ref_d;
|
||||
|
||||
Options(std::string command_name) :
|
||||
command_name(command_name),
|
||||
help(false),
|
||||
problem_size({2048, 2048, 2048}),
|
||||
alpha(1.0f),
|
||||
beta(0.0f),
|
||||
split_k_factor(1),
|
||||
avail_sms(-1), // Number of device SMs to use is unlimited
|
||||
reference_check(true),
|
||||
iterations(10000)
|
||||
{}
|
||||
|
||||
bool valid() const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
void parse(int argc, char const **args)
|
||||
{
|
||||
cutlass::CommandLine cmd(argc, args);
|
||||
|
||||
if (cmd.check_cmd_line_flag("help")) {
|
||||
help = true;
|
||||
}
|
||||
|
||||
cmd.get_cmd_line_argument("m", problem_size.m());
|
||||
cmd.get_cmd_line_argument("n", problem_size.n());
|
||||
cmd.get_cmd_line_argument("k", problem_size.k());
|
||||
cmd.get_cmd_line_argument("alpha", alpha);
|
||||
cmd.get_cmd_line_argument("beta", beta);
|
||||
cmd.get_cmd_line_argument("split", split_k_factor);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
}
|
||||
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const
|
||||
{
|
||||
out
|
||||
<< "Performs a GEMM computation.\n"
|
||||
<< "\n"
|
||||
<< "Options:\n"
|
||||
<< "\n"
|
||||
<< " --help If specified, displays this usage statement.\n\n"
|
||||
<< " --m=<int> GEMM M dimension\n"
|
||||
<< " --n=<int> GEMM N dimension\n"
|
||||
<< " --k=<int> GEMM K dimension\n"
|
||||
<< " --alpha=<f32> Epilogue scalar alpha\n"
|
||||
<< " --beta=<f32> Epilogue scalar beta\n\n"
|
||||
<< " --split=<int> Split-K factor to emulate\n\n"
|
||||
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Compute performance in GFLOP/s
|
||||
double gflops(double runtime_s) const
|
||||
{
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s;
|
||||
}
|
||||
};
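As a quick check of the gflops() formula above against the sample output quoted in the file header:
the 2048 x 2048 x 2048 problem performs 2 * 2048^3 = 17,179,869,184 flops, so the StreamK average
runtime of 0.0941929 ms corresponds to

    17.179869184e9 flops / 94.1929e-6 s  ~=  1.824e14 flop/s  ~=  182,390 GFLOP/s

which matches the "GFLOPs: 182390" line reported for that run.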
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM evaluation
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options
|
||||
typename DeviceGemmBasic::Arguments args_from_options(
|
||||
const DeviceGemmBasic &device_gemm,
|
||||
const Options &options,
|
||||
cutlass::HostTensor<ElementA, LayoutA> &tensor_a,
|
||||
cutlass::HostTensor<ElementB, LayoutB> &tensor_b,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_c,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_d)
|
||||
{
|
||||
return typename DeviceGemmBasic::Arguments(
|
||||
cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
|
||||
options.problem_size, // problem_size
|
||||
options.split_k_factor, // batch count / splitk slices
|
||||
{ // epilogue parameters
|
||||
ElementAccumulator(options.alpha),
|
||||
ElementAccumulator(options.beta)
|
||||
},
|
||||
tensor_a.device_data(), // ptr_A
|
||||
tensor_b.device_data(), // ptr_B
|
||||
tensor_c.device_data(), // ptr_C
|
||||
tensor_d.device_data(), // ptr_D
|
||||
options.problem_size.mk().product(), // batch_stride_A
|
||||
options.problem_size.nk().product(), // batch_stride_B
|
||||
options.problem_size.mn().product(), // batch_stride_C
|
||||
options.problem_size.mn().product(), // batch_stride_D
|
||||
tensor_a.layout().stride(0), // stride_a
|
||||
tensor_b.layout().stride(0), // stride_b
|
||||
tensor_c.layout().stride(0), // stride_c
|
||||
tensor_d.layout().stride(0)); // stride_d
|
||||
}
|
||||
|
||||
/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options
|
||||
typename DeviceGemmStreamK::Arguments args_from_options(
|
||||
const DeviceGemmStreamK &device_gemm,
|
||||
const Options &options,
|
||||
cutlass::HostTensor<ElementA, LayoutA> &tensor_a,
|
||||
cutlass::HostTensor<ElementB, LayoutB> &tensor_b,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_c,
|
||||
cutlass::HostTensor<ElementC, LayoutC> &tensor_d)
|
||||
{
|
||||
return typename DeviceGemmStreamK::Arguments(
|
||||
cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
|
||||
options.problem_size, // problem_size
|
||||
options.split_k_factor, // batch count / splitk slices
|
||||
{ // epilogue parameters
|
||||
ElementAccumulator(options.alpha),
|
||||
ElementAccumulator(options.beta)
|
||||
},
|
||||
tensor_a.device_data(), // ptr_A
|
||||
tensor_b.device_data(), // ptr_B
|
||||
tensor_c.device_data(), // ptr_C
|
||||
tensor_d.device_data(), // ptr_D
|
||||
options.problem_size.mk().product(), // batch_stride_A
|
||||
options.problem_size.nk().product(), // batch_stride_B
|
||||
options.problem_size.mn().product(), // batch_stride_C
|
||||
options.problem_size.mn().product(), // batch_stride_D
|
||||
tensor_a.layout().stride(0), // stride_a
|
||||
tensor_b.layout().stride(0), // stride_b
|
||||
tensor_c.layout().stride(0), // stride_c
|
||||
tensor_d.layout().stride(0), // stride_d
|
||||
options.avail_sms); // avail_sms
|
||||
}
|
||||
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename DeviceGemmT>
|
||||
Result run(std::string description, Options &options)
|
||||
{
|
||||
// Display test description
|
||||
std::cout << std::endl << description << std::endl;
|
||||
|
||||
// Zero-initialize test output matrix D
|
||||
cutlass::reference::host::TensorFill(options.tensor_d.host_view());
|
||||
options.tensor_d.sync_device();
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
DeviceGemmT device_gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT
|
||||
auto arguments = args_from_options(device_gemm, options, options.tensor_a, options.tensor_b, options.tensor_c, options.tensor_d);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = DeviceGemmT::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check the problem size is supported or not
|
||||
CUTLASS_CHECK(device_gemm.can_implement(arguments));
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get()));
|
||||
|
||||
// Correctness / Warmup iteration
|
||||
CUTLASS_CHECK(device_gemm());
|
||||
|
||||
// Copy output data from CUTLASS and reference kernel to host for comparison
|
||||
options.tensor_d.sync_host();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
Result result;
|
||||
result.passed = cutlass::reference::host::TensorEquals(
|
||||
options.tensor_d.host_view(),
|
||||
options.tensor_ref_d.host_view());
|
||||
|
||||
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
// Run profiling loop
|
||||
if (options.iterations > 0)
|
||||
{
|
||||
GpuTimer timer;
|
||||
timer.start();
|
||||
for (int iter = 0; iter < options.iterations; ++iter) {
|
||||
CUTLASS_CHECK(device_gemm());
|
||||
}
|
||||
timer.stop();
|
||||
|
||||
// Compute average runtime and GFLOPs.
|
||||
float elapsed_ms = timer.elapsed_millis();
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
|
||||
|
||||
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPs: " << result.gflops << std::endl;
|
||||
}
|
||||
|
||||
if (!result.passed) {
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
/// Program entrypoint
|
||||
int main(int argc, const char **argv)
|
||||
{
|
||||
// CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ >= 11)) {
|
||||
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
||||
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Current device must have compute capability at least 80
|
||||
cudaDeviceProp props;
|
||||
int current_device_id;
|
||||
CUDA_CHECK(cudaGetDevice(&current_device_id));
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
|
||||
if (!((props.major * 10 + props.minor) >= 80))
|
||||
{
|
||||
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
|
||||
<< std::endl;
|
||||
|
||||
// Returning zero so this test passes on older architectures. Its actions are a no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Parse commandline options
|
||||
Options options("ampere_streamk_gemm");
|
||||
options.parse(argc, argv);
|
||||
|
||||
if (options.help) {
|
||||
options.print_usage(std::cout) << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::cout <<
|
||||
options.iterations << " timing iterations of " <<
|
||||
options.problem_size.m() << " x " <<
|
||||
options.problem_size.n() << " x " <<
|
||||
options.problem_size.k() << " matrix-matrix multiply" << std::endl;
|
||||
|
||||
if (!options.valid()) {
|
||||
std::cerr << "Invalid problem." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Initialize GEMM datasets
|
||||
//
|
||||
|
||||
// Initialize tensors using CUTLASS helper functions
|
||||
options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
options.tensor_c.resize(options.problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
|
||||
options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel
|
||||
|
||||
// Fill matrix A on host with uniform-random data in [-2, 2]
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
options.tensor_a.host_view(),
|
||||
1,
|
||||
ElementA(2),
|
||||
ElementA(-2),
|
||||
0);
|
||||
|
||||
// Fill matrix B on host with uniform-random data in [-2, 2]
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
options.tensor_b.host_view(),
|
||||
1,
|
||||
ElementB(2),
|
||||
ElementB(-2),
|
||||
0);
|
||||
|
||||
// Fill matrix C on host with uniform-random data in [-2, 2]
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
options.tensor_c.host_view(),
|
||||
1,
|
||||
ElementC(2),
|
||||
ElementC(-2),
|
||||
0);
|
||||
|
||||
|
||||
//
|
||||
// Compute reference output
|
||||
//
|
||||
|
||||
// Copy data from host to GPU
|
||||
options.tensor_a.sync_device();
|
||||
options.tensor_b.sync_device();
|
||||
options.tensor_c.sync_device();
|
||||
|
||||
// Zero-initialize reference output matrix D
|
||||
cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view());
|
||||
options.tensor_ref_d.sync_device();
|
||||
|
||||
// Create instantiation for device reference gemm kernel
|
||||
DeviceGemmReference gemm_reference;
|
||||
|
||||
// Launch device reference gemm kernel
|
||||
gemm_reference(
|
||||
options.problem_size,
|
||||
ElementAccumulator(options.alpha),
|
||||
options.tensor_a.device_ref(),
|
||||
options.tensor_b.device_ref(),
|
||||
ElementAccumulator(options.beta),
|
||||
options.tensor_c.device_ref(),
|
||||
options.tensor_ref_d.device_ref());
|
||||
|
||||
// Wait for kernels to finish
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// Copy output data from reference kernel to host for comparison
|
||||
options.tensor_ref_d.sync_host();
|
||||
|
||||
|
||||
//
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
// Test default operation
|
||||
if (options.split_k_factor == 1)
|
||||
{
|
||||
// Compare basic data-parallel version versus StreamK version using default load-balancing heuristics
|
||||
Result basic_dp = run<DeviceGemmBasic>("Basic data-parallel GEMM", options);
|
||||
Result streamk_default = run<DeviceGemmStreamK>("StreamK GEMM with default load-balancing", options);
|
||||
|
||||
printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms));
|
||||
|
||||
// Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1
|
||||
options.avail_sms = 1; // Set load-balancing width to 1 SM (no load balancing)
|
||||
Result streamk_dp = run<DeviceGemmStreamK>("StreamK emulating basic data-parallel GEMM", options);
|
||||
options.avail_sms = -1; // Reset load-balancing width to unspecified SMs (i.e., the number of device SMs)
|
||||
|
||||
printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms));
|
||||
|
||||
options.split_k_factor++; // Increment splitting factor for next evaluation
|
||||
|
||||
}
|
||||
|
||||
// Show that StreamK can emulate "Split-K" with a tile-splitting factor
|
||||
Result basic_splitk = run<DeviceGemmBasic>(
|
||||
std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
|
||||
options);
|
||||
|
||||
Result streamk_splitk = run<DeviceGemmStreamK>(
|
||||
std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
|
||||
options);
|
||||
|
||||
printf(" Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms));
|
||||
|
||||
return 0;
|
||||
}
|
||||
examples/CMakeLists.txt
@@ -124,6 +124,7 @@ foreach(EXAMPLE
 43_ell_block_sparse_gemm
 45_dual_gemm
 46_depthwise_simt_conv2dfprop
+47_ampere_gemm_universal_streamk
 )

 add_subdirectory(${EXAMPLE})
helper.h
@@ -2,6 +2,9 @@

#include "cuda_runtime.h"

/**
 * Panic wrapper for unwinding CUTLASS errors
 */
#define CUTLASS_CHECK(status)                            \
  {                                                      \
    cutlass::Status error = status;                      \
@@ -12,6 +15,10 @@
    }                                                    \
  }

/**
 * Panic wrapper for unwinding CUDA runtime errors
 */
#define CUDA_CHECK(status)                               \
  {                                                      \
    cudaError_t error = status;                          \
@@ -21,3 +28,50 @@
      exit(EXIT_FAILURE);                                \
    }                                                    \
  }

/**
 * GPU timer for recording the elapsed time across kernel(s) launched in GPU stream
 */
struct GpuTimer
{
    cudaStream_t _stream_id;
    cudaEvent_t _start;
    cudaEvent_t _stop;

    /// Constructor
    GpuTimer() : _stream_id(0)
    {
        CUDA_CHECK(cudaEventCreate(&_start));
        CUDA_CHECK(cudaEventCreate(&_stop));
    }

    /// Destructor
    ~GpuTimer()
    {
        CUDA_CHECK(cudaEventDestroy(_start));
        CUDA_CHECK(cudaEventDestroy(_stop));
    }

    /// Start the timer for a given stream (defaults to the default stream)
    void start(cudaStream_t stream_id = 0)
    {
        _stream_id = stream_id;
        CUDA_CHECK(cudaEventRecord(_start, _stream_id));
    }

    /// Stop the timer
    void stop()
    {
        CUDA_CHECK(cudaEventRecord(_stop, _stream_id));
    }

    /// Return the elapsed time (in milliseconds)
    float elapsed_millis()
    {
        float elapsed = 0.0;
        CUDA_CHECK(cudaEventSynchronize(_stop));
        CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop));
        return elapsed;
    }
};
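A minimal usage sketch for the GpuTimer above (not part of the commit; noop_kernel and time_noop are
hypothetical names introduced here for illustration, and the snippet assumes this header is included),
mirroring the warmup-then-profile pattern used by run() in the new example:

#include <cstdio>

// Trivial kernel used only so there is something to time.
__global__ void noop_kernel() {}

// Launches `iterations` copies of the kernel on the default stream and reports the
// average per-iteration time, the same way the example's profiling loop does.
float time_noop(int iterations)
{
    noop_kernel<<<1, 1>>>();                     // warmup launch
    CUDA_CHECK(cudaDeviceSynchronize());

    GpuTimer timer;
    timer.start();                               // records the start event
    for (int i = 0; i < iterations; ++i) {
        noop_kernel<<<1, 1>>>();
    }
    timer.stop();                                // records the stop event

    float avg_ms = timer.elapsed_millis() / float(iterations);  // synchronizes on the stop event
    printf("avg time per launch: %f ms\n", avg_ms);
    return avg_ms;
}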
|
||||
|
||||
@ -57,34 +57,25 @@ public:
|
||||
|
||||
protected:
|
||||
|
||||
/// Load flag, as a strong operation (int specialization)
|
||||
/// Load flag, as a strong acquire operation (int specialization)
|
||||
CUTLASS_DEVICE
|
||||
static int ld_strong(int *ptr)
|
||||
static int ld_acquire(int *ptr)
|
||||
{
|
||||
int state = 0;
|
||||
|
||||
#if (__CUDA_ARCH__ >= 700)
|
||||
/// SM70 and newer use memory consistency qualifiers
|
||||
asm volatile ("ld.global.relaxed.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
|
||||
/// SM70 and newer use memory consistency qualifiers
|
||||
|
||||
// Acquire pattern using acquire modifier
|
||||
asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
|
||||
|
||||
#else
|
||||
asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
|
||||
asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
|
||||
#endif // (__CUDA_ARCH__ >= 700)
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
/// Store flag, as a strong operation (int specialization)
|
||||
CUTLASS_DEVICE
|
||||
static void st_strong(int *ptr, int val)
|
||||
{
|
||||
#if (__CUDA_ARCH__ >= 700)
|
||||
/// SM70 and newer use memory consistency qualifiers
|
||||
asm volatile ("st.global.relaxed.gpu.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
|
||||
#else
|
||||
asm volatile ("st.cg.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
|
||||
#endif // (__CUDA_ARCH__ >= 700)
|
||||
}
|
||||
|
||||
|
||||
/// Reduce into flag, with release pattern (int specialization)
|
||||
CUTLASS_DEVICE
|
||||
@ -92,11 +83,16 @@ protected:
|
||||
{
|
||||
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
|
||||
#if (__CUDA_ARCH__ >= 700)
|
||||
/// SM70 and newer use memory consistency qualifiers
|
||||
asm volatile ("red.release.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val));
|
||||
/// SM70 and newer use memory consistency qualifiers
|
||||
|
||||
// Release pattern using acq_rel fence + relaxed modifier. (The fence also releases data
|
||||
// that was weakly-written by other threads prior to the last syncthreads)
|
||||
asm volatile ("fence.acq_rel.gpu;\n");
|
||||
asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val));
|
||||
|
||||
#else
|
||||
__threadfence();
|
||||
atomicAdd(ptr, val);
|
||||
__threadfence();
|
||||
atomicAdd(ptr, val);
|
||||
#endif // (__CUDA_ARCH__ >= 700)
|
||||
#endif
|
||||
}
|
||||
@ -115,7 +111,7 @@ public:
|
||||
{
|
||||
// Spin-loop
|
||||
#pragma unroll 1
|
||||
while(ld_strong(flag_ptr) < count) {}
|
||||
while(ld_acquire(flag_ptr) < count) {}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@ -133,9 +129,8 @@ public:
|
||||
{
|
||||
// Spin-loop
|
||||
#pragma unroll 1
|
||||
while(ld_strong(flag_ptr) != val) {}
|
||||
while(ld_acquire(flag_ptr) != val) {}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
#endif
|
||||
}
|
||||
@ -166,7 +161,8 @@ public:
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (thread_idx == 0) {
|
||||
if (thread_idx == 0)
|
||||
{
|
||||
red_release(flag_ptr, 1);
|
||||
}
|
||||
#endif
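The hunks above replace the relaxed flag load with an acquire load (ld_acquire) and give the arrival
reduction release semantics. A minimal sketch of the producer/consumer hand-off these primitives
implement, written here with libcu++ atomics instead of CUTLASS's inline PTX (assumes CUDA 11+ and
<cuda/atomic>; the function and variable names below are illustrative, not CUTLASS APIs):

#include <cuda/atomic>

using DeviceFlag = cuda::atomic<int, cuda::thread_scope_device>;

// Producer side: publish a partial result, then signal arrival. The release order guarantees the
// payload written above is visible to any thread that later observes the incremented counter.
__device__ void signal_partial(DeviceFlag &flag, int *partials, int block_idx, int value)
{
    partials[block_idx] = value;
    flag.fetch_add(1, cuda::std::memory_order_release);
}

// Consumer side: spin with acquire semantics until `expected` peers have arrived; afterwards their
// partial results can be read safely and reduced.
__device__ int reduce_partials(DeviceFlag &flag, int const *partials, int expected)
{
    while (flag.load(cuda::std::memory_order_acquire) < expected) {}

    int total = 0;
    for (int i = 0; i < expected; ++i) {
        total += partials[i];
    }
    return total;
}

This is the same ordering contract the Stream-K accumulator hand-off relies on: the fence.acq_rel plus
the relaxed reduction acts as the release, and the spin on ld_acquire acts as the matching acquire.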
|
||||
|
||||
@ -124,7 +124,7 @@ public:
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size;
|
||||
int batch_count;
|
||||
int batch_count; // Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor
|
||||
|
||||
typename EpilogueOutputOp::Params epilogue;
|
||||
|
||||
@ -148,7 +148,7 @@ public:
|
||||
typename LayoutC::Stride::LongIndex ldc;
|
||||
typename LayoutC::Stride::LongIndex ldd;
|
||||
|
||||
int sm_limit; /// Carvout override: when the above are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance
|
||||
int avail_sms; /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
|
||||
|
||||
|
||||
//
|
||||
@ -159,15 +159,18 @@ public:
|
||||
Arguments():
|
||||
mode(GemmUniversalMode::kGemm),
|
||||
batch_count(1),
|
||||
ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
|
||||
sm_limit(-1)
|
||||
ptr_A(nullptr),
|
||||
ptr_B(nullptr),
|
||||
ptr_C(nullptr),
|
||||
ptr_D(nullptr),
|
||||
avail_sms(-1)
|
||||
{}
|
||||
|
||||
/// Constructor
|
||||
Arguments(
|
||||
GemmUniversalMode mode,
|
||||
GemmCoord problem_size,
|
||||
int batch_count,
|
||||
int batch_split, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K)
|
||||
typename EpilogueOutputOp::Params epilogue,
|
||||
void const * ptr_A,
|
||||
void const * ptr_B,
|
||||
@ -181,15 +184,15 @@ public:
|
||||
typename LayoutB::Stride stride_b,
|
||||
typename LayoutC::Stride stride_c,
|
||||
typename LayoutC::Stride stride_d,
|
||||
int sm_limit = -1 /// Carvout override: when the above are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance
|
||||
int avail_sms = -1 /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
batch_count(batch_split),
|
||||
epilogue(epilogue),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), sm_limit(sm_limit)
|
||||
stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), avail_sms(avail_sms)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("GemmUniversalStreamk::Arguments::Arguments() - problem_size: " << problem_size);
|
||||
}
|
||||
@ -198,7 +201,7 @@ public:
|
||||
Arguments(
|
||||
GemmUniversalMode mode,
|
||||
GemmCoord problem_size,
|
||||
int batch_count,
|
||||
int batch_split, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K)
|
||||
typename EpilogueOutputOp::Params epilogue,
|
||||
void const * ptr_A,
|
||||
void const * ptr_B,
|
||||
@ -212,15 +215,15 @@ public:
|
||||
typename LayoutB::Stride::LongIndex ldb,
|
||||
typename LayoutC::Stride::LongIndex ldc,
|
||||
typename LayoutC::Stride::LongIndex ldd,
|
||||
int sm_limit = -1 /// Carvout override: when the above are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance
|
||||
int avail_sms = -1 /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
|
||||
):
|
||||
mode(mode),
|
||||
problem_size(problem_size),
|
||||
batch_count(batch_count),
|
||||
batch_count(batch_split),
|
||||
epilogue(epilogue),
|
||||
ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
|
||||
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
||||
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), sm_limit(sm_limit)
|
||||
lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), avail_sms(avail_sms)
|
||||
{
|
||||
stride_a = make_Coord(lda);
|
||||
stride_b = make_Coord(ldb);
|
||||
@ -254,29 +257,36 @@ public:
|
||||
// Data members
|
||||
//
|
||||
|
||||
ThreadblockSwizzle block_mapping;
|
||||
void * ptr_A;
|
||||
void * ptr_B;
|
||||
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename Epilogue::OutputTileIterator::Params params_C;
|
||||
typename Epilogue::OutputTileIterator::Params params_D;
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
GemmUniversalMode mode;
|
||||
|
||||
void * ptr_A;
|
||||
void * ptr_B;
|
||||
void * ptr_C;
|
||||
void * ptr_D;
|
||||
|
||||
int64_t batch_stride_A;
|
||||
int64_t batch_stride_B;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
GemmUniversalMode mode;
|
||||
|
||||
ThreadblockSwizzle block_mapping;
|
||||
|
||||
bool quick_dp;
|
||||
|
||||
void *barrier_workspace;
|
||||
void *partials_workspace;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
void * ptr_D;
|
||||
void * ptr_C;
|
||||
|
||||
typename Epilogue::OutputTileIterator::Params params_D;
|
||||
typename Epilogue::OutputTileIterator::Params params_C;
|
||||
|
||||
int64_t batch_stride_D;
|
||||
int64_t batch_stride_C;
|
||||
|
||||
|
||||
protected:
|
||||
|
||||
//
|
||||
@ -295,7 +305,7 @@ public:
|
||||
{
|
||||
// For atomic reduction, each SK-block needs a synchronization flag. For parallel reduction,
|
||||
// each reduction block needs its own synchronization flag.
|
||||
int sk_blocks = block_mapping.sk_regions * block_mapping.sk_blocks_per_region;
|
||||
int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region();
|
||||
int num_flags = fast_max(sk_blocks, block_mapping.reduction_blocks);
|
||||
|
||||
return cacheline_align_up(sizeof(typename Barrier::T) * num_flags);
|
||||
@ -304,7 +314,7 @@ public:
|
||||
/// Get the workspace size needed for intermediate partial sums
|
||||
size_t get_partials_workspace_size() const
|
||||
{
|
||||
int sk_blocks = block_mapping.sk_regions * block_mapping.sk_blocks_per_region;
|
||||
int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region();
|
||||
return cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks);
|
||||
}
|
||||
|
||||
@ -343,9 +353,9 @@ public:
|
||||
partials_workspace(nullptr)
|
||||
{
|
||||
// Number of SMs to make available for StreamK decomposition
|
||||
int avail_sms = (args.sm_limit == -1) ?
|
||||
int avail_sms = (args.avail_sms == -1) ?
|
||||
device_sms :
|
||||
fast_min(args.sm_limit, device_sms);
|
||||
fast_min(args.avail_sms, device_sms);
|
||||
|
||||
// Initialize the block mapping structure
|
||||
block_mapping = ThreadblockSwizzle(
|
||||
@ -355,7 +365,15 @@ public:
|
||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||
args.batch_count,
|
||||
sm_occupancy,
|
||||
device_sms,
|
||||
avail_sms);
|
||||
|
||||
quick_dp =
|
||||
(block_mapping.sk_waves == 0) &&
|
||||
(mode == GemmUniversalMode::kGemm) &&
|
||||
!block_mapping.cohort_raster &&
|
||||
!EpilogueOutputOp(output_op).is_source_needed();
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ -426,7 +444,7 @@ public:
|
||||
/// Returns the GEMM volume in thread block tiles
|
||||
cutlass::gemm::GemmCoord get_tiled_shape() const
|
||||
{
|
||||
return block_mapping.tiled_shape;
|
||||
return block_mapping.tiled_shape();
|
||||
}
|
||||
|
||||
|
||||
@ -533,9 +551,6 @@ protected:
|
||||
/// ID of each thread within a warp
|
||||
int lane_idx;
|
||||
|
||||
/// Block index
|
||||
int block_idx;
|
||||
|
||||
/// Threadblock scoped epilogue
|
||||
Epilogue epilogue;
|
||||
|
||||
@ -640,16 +655,18 @@ protected:
|
||||
|
||||
/// Iterator for fetching tile fragments from A
|
||||
CUTLASS_DEVICE
|
||||
typename Mma::IteratorA init_iterator_A(TileWorkDesc &tile_work)
|
||||
typename Mma::IteratorA init_iterator_A(
|
||||
TileWorkDesc &tile_work,
|
||||
GemmUniversalMode mode)
|
||||
{
|
||||
// The input A matrix
|
||||
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
|
||||
|
||||
// Update input pointers based on batched/array mode
|
||||
if (params.mode == GemmUniversalMode::kBatched) {
|
||||
if (mode == GemmUniversalMode::kBatched) {
|
||||
ptr_A += tile_work.tiled_coord.k() * params.batch_stride_A;
|
||||
}
|
||||
if (params.mode == GemmUniversalMode::kArray) {
|
||||
if (mode == GemmUniversalMode::kArray) {
|
||||
ptr_A = static_cast<ElementA * const *>(params.ptr_A)[tile_work.tiled_coord.k()];
|
||||
}
|
||||
|
||||
@ -667,16 +684,18 @@ protected:
|
||||
|
||||
/// Iterator for fetching tile fragments from B
|
||||
CUTLASS_DEVICE
|
||||
typename Mma::IteratorB init_iterator_B(TileWorkDesc &tile_work)
|
||||
typename Mma::IteratorB init_iterator_B(
|
||||
TileWorkDesc &tile_work,
|
||||
GemmUniversalMode mode)
|
||||
{
|
||||
// The input B matrix
|
||||
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
|
||||
|
||||
// Update input pointers based on batched/array mode
|
||||
if (params.mode == GemmUniversalMode::kBatched) {
|
||||
if (mode == GemmUniversalMode::kBatched) {
|
||||
ptr_B += tile_work.tiled_coord.k() * params.batch_stride_B;
|
||||
}
|
||||
if (params.mode == GemmUniversalMode::kArray) {
|
||||
if (mode == GemmUniversalMode::kArray) {
|
||||
ptr_B = static_cast<ElementB * const *>(params.ptr_B)[tile_work.tiled_coord.k()];
|
||||
}
|
||||
|
||||
@ -700,10 +719,10 @@ protected:
|
||||
tile_work.tile_idx = tile_idx;
|
||||
|
||||
// The first global-scoped MAC-iteration this threadblock will perform for this tile
|
||||
tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile;
|
||||
tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile();
|
||||
|
||||
// The number of MAC-iterations this threadblock will perform for this tile
|
||||
tile_work.k_iters_remaining = params.block_mapping.iters_per_tile;
|
||||
tile_work.k_iters_remaining = params.block_mapping.iters_per_tile();
|
||||
|
||||
// The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile
|
||||
tile_work.k_begin = 0;
|
||||
@ -727,7 +746,7 @@ protected:
|
||||
tile_work.tile_idx = tile_idx;
|
||||
|
||||
// The first global-scoped MAC-iteration for this tile
|
||||
int tile_iter_begin = tile_idx * params.block_mapping.iters_per_tile;
|
||||
int tile_iter_begin = tile_idx * params.block_mapping.iters_per_tile();
|
||||
|
||||
// The first global-scoped MAC-iteration this threadblock will perform for this tile
|
||||
tile_work.iter_begin = max(block_iter_begin, tile_iter_begin);
|
||||
@ -756,7 +775,10 @@ protected:
|
||||
|
||||
/// Share accumulators with peers
|
||||
CUTLASS_DEVICE
|
||||
void share_accumulators(AccumulatorTile const &accumulator_tile, int first_block_idx)
|
||||
void share_accumulators(
|
||||
AccumulatorTile const &accumulator_tile,
|
||||
int block_idx,
|
||||
int first_block_idx)
|
||||
{
|
||||
AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
|
||||
|
||||
@ -795,6 +817,7 @@ protected:
|
||||
CUTLASS_DEVICE
|
||||
void acquire_accumulators(
|
||||
AccumulatorTile &accumulator_tile,
|
||||
int block_idx,
|
||||
int first_block_idx)
|
||||
{
|
||||
AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
|
||||
@ -868,8 +891,8 @@ protected:
|
||||
reduce_tile_idx = reduce_idx / Epilogue::kAccumulatorFragments;
|
||||
reduce_fragment_idx = reduce_idx % Epilogue::kAccumulatorFragments;
|
||||
|
||||
int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile;
|
||||
int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile - 1;
|
||||
int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile();
|
||||
int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile() - 1;
|
||||
|
||||
peer_idx_begin = params.block_mapping.get_sk_block_idx(iter_tile_first);
|
||||
peer_idx_last = params.block_mapping.get_sk_block_idx(iter_tile_last);
|
||||
@ -895,16 +918,6 @@ protected:
|
||||
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
|
||||
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
|
||||
|
||||
// Update pointers for batched/array mode(s)
|
||||
if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_C += tiled_coord.k() * params.batch_stride_C;
|
||||
ptr_D += tiled_coord.k() * params.batch_stride_D;
|
||||
}
|
||||
if (params.mode == GemmUniversalMode::kArray) {
|
||||
ptr_C = static_cast<ElementC * const *>(params.ptr_C)[tiled_coord.k()];
|
||||
ptr_D = static_cast<ElementC * const *>(params.ptr_D)[tiled_coord.k()];
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params.params_C,
|
||||
@ -936,12 +949,13 @@ protected:
|
||||
CUTLASS_DEVICE
|
||||
void process_tile(
|
||||
TileWorkDesc tile_work,
|
||||
int block_idx,
|
||||
int dp_start_block_idx,
|
||||
int block_iter_begin)
|
||||
{
|
||||
// Initialize input iterators
|
||||
typename Mma::IteratorA iterator_A = init_iterator_A(tile_work);
|
||||
typename Mma::IteratorB iterator_B = init_iterator_B(tile_work);
|
||||
typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode);
|
||||
typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode);
|
||||
|
||||
// Initialize accumulators
|
||||
AccumulatorTile accumulator_tile;
|
||||
@ -968,7 +982,7 @@ protected:
|
||||
|
||||
if (!tile_work.tile_finished(params)) {
|
||||
// Non "finishing" SK blocks must share their partial accumulator sums through global scratch workspace
|
||||
share_accumulators(accumulator_tile, first_block_idx);
|
||||
share_accumulators(accumulator_tile, block_idx, first_block_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -976,7 +990,7 @@ protected:
|
||||
if (!tile_work.tile_started())
|
||||
{
|
||||
// A "finishing" SK block must first aggregate its accumulator partial sums with those shared by peer threadblocks
|
||||
acquire_accumulators(accumulator_tile, first_block_idx);
|
||||
acquire_accumulators(accumulator_tile, block_idx, first_block_idx);
|
||||
}
|
||||
|
||||
do_epilogue(tile_work, accumulator_tile);
|
||||
@ -1008,11 +1022,12 @@ protected:
|
||||
// Initialize block's iteration range
|
||||
int tile_idx, block_iter_begin, block_iters_remaining;
|
||||
|
||||
int sk_padding_start_block_idx = params.block_mapping.sk_regions * params.block_mapping.sk_blocks_per_region;
|
||||
int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region();
|
||||
int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms;
|
||||
int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks;
|
||||
int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks;
|
||||
|
||||
int block_idx = params.block_mapping.get_block_idx();
|
||||
if (block_idx < sk_padding_start_block_idx)
|
||||
{
|
||||
// This is a SK block
|
||||
@ -1044,8 +1059,9 @@ protected:
|
||||
}
|
||||
|
||||
block_iter_begin = 0;
|
||||
block_iters_remaining = params.block_mapping.iters_per_tile * tile_allottment;
|
||||
block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment;
|
||||
}
|
||||
|
||||
else if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed) &&
|
||||
(block_idx < grid_padding_start_block_idx))
|
||||
{
|
||||
@ -1072,8 +1088,8 @@ protected:
|
||||
|
||||
// DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1)
|
||||
if ((tile_idx < params.block_mapping.sk_tiles) ||
|
||||
(tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape.m()) ||
|
||||
(tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape.n()))
|
||||
(tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) ||
|
||||
(tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n()))
|
||||
{
|
||||
break;
|
||||
}
|
||||
@ -1084,7 +1100,7 @@ protected:
|
||||
}
|
||||
|
||||
// Perform this block's share of work for this tile
|
||||
process_tile(tile_work, dp_start_block_idx, block_iter_begin);
|
||||
process_tile(tile_work, block_idx, dp_start_block_idx, block_iter_begin);
|
||||
|
||||
// Update remaining work for this block
|
||||
block_iters_remaining -= tile_work.k_iters_remaining;
|
||||
@ -1110,6 +1126,64 @@ protected:
|
||||
|
||||
}
|
||||
|
||||
|
||||
/// Executes one DP-only GEMM
|
||||
CUTLASS_DEVICE
|
||||
void gemm_dp()
|
||||
{
|
||||
int block_idx = blockIdx.x;
|
||||
int tile_idx = block_idx;
|
||||
|
||||
TileWorkDesc tile_work;
|
||||
tile_work.tile_idx = tile_idx;
|
||||
tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile();
|
||||
tile_work.k_iters_remaining = params.block_mapping.iters_per_tile();
|
||||
tile_work.k_begin = 0;
|
||||
tile_work.k_end = params.block_mapping.problem_size.k();
|
||||
tile_work.tiled_coord = params.block_mapping.get_tile_offset_row_major(tile_work.tile_idx);
|
||||
|
||||
// Initialize input iterators
|
||||
typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode);
|
||||
typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode);
|
||||
|
||||
// Initialize accumulators
|
||||
AccumulatorTile accumulator_tile;
|
||||
accumulator_tile.clear();
|
||||
|
||||
// Perform this tile's range of multiply-accumulate (MAC) iterations
|
||||
Mma mma(
|
||||
shared_storage.main_loop,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);
|
||||
|
||||
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
|
||||
|
||||
// Location of this tile in item-coords
|
||||
MatrixCoord threadblock_item_begin(
|
||||
tile_work.tiled_coord.m() * Mma::Shape::kM,
|
||||
tile_work.tiled_coord.n() * Mma::Shape::kN
|
||||
);
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(
|
||||
params.params_D,
|
||||
ptr_D,
|
||||
params.block_mapping.problem_size.mn(),
|
||||
thread_idx,
|
||||
threadblock_item_begin);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(
|
||||
EpilogueOutputOp(params.output_op),
|
||||
iterator_D,
|
||||
accumulator_tile);
|
||||
}
|
||||
|
||||
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
@ -1138,7 +1212,6 @@ public:
|
||||
thread_idx(threadIdx.x),
|
||||
warp_idx(__shfl_sync(0xffffffff, threadIdx.x / 32, 0)), // broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
lane_idx(threadIdx.x % 32),
|
||||
block_idx(params.block_mapping.get_block_idx()),
|
||||
epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
@ -1151,7 +1224,17 @@ public:
|
||||
CUTLASS_DEVICE
|
||||
void operator()()
|
||||
{
|
||||
// Do the GEMM
|
||||
#if (__CUDACC_VER_MAJOR__ > 10)
|
||||
if (params.quick_dp)
|
||||
{
|
||||
// Simple (low-bootstrap latency) GEMM code path for data-parallel only. (kBatched and kArray
|
||||
// modes will only be launched using a data-parallel configuration)
|
||||
gemm_dp();
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Generic SK code path
|
||||
gemm();
|
||||
|
||||
}
|
||||
|
||||
@ -115,28 +115,21 @@ struct ThreadblockSwizzleStreamK {
|
||||
// Member state
|
||||
//
|
||||
|
||||
|
||||
/// The 3D value-extents of the GEMM computation volume (m,n,k)
|
||||
GemmCoord problem_size;
|
||||
|
||||
/// The 2D tile-extents of the output matrix (m,n)
|
||||
GemmCoord tiled_shape;
|
||||
/// Div/mod accelerators
|
||||
FastDivmod div_mod_tiled_shape_m;
|
||||
FastDivmod div_mod_tiled_shape_n;
|
||||
FastDivmod div_mod_tiled_cohort_shape_n;
|
||||
FastDivmod div_mod_iters_per_tile;
|
||||
|
||||
/// Number of iterations per output tile
|
||||
int iters_per_tile;
|
||||
/// Whether to perform cohort CTA rasterization
|
||||
bool cohort_raster;
|
||||
|
||||
/// Number of reduction blocks in the grid
|
||||
int reduction_blocks;
|
||||
|
||||
int dp_blocks; /// Number of data-parallel thread blocks in the grid
|
||||
int dp_first_wave_tiles; /// Number of output tiles each CTA in the first DP wave will produce
|
||||
|
||||
int sk_tiles;
|
||||
int sk_regions;
|
||||
int sk_blocks_per_region;
|
||||
int sk_big_blocks_per_region;
|
||||
int sk_iters_per_region;
|
||||
int sk_iters_per_normal_block; /// Number of iterations for normal SK-blocks
|
||||
int sk_waves; /// Number of SK waves in the grid
|
||||
// Whether to pad and remap block indices
|
||||
bool remap_block_indices;
|
||||
|
||||
/// CTA occupancy per SM
|
||||
int sm_occupancy;
|
||||
@ -144,21 +137,26 @@ struct ThreadblockSwizzleStreamK {
|
||||
/// Number of SMs for dispatch heuristics to load-balance using Stream-K CTAs (wave size)
|
||||
int avail_sms;
|
||||
|
||||
/// Whether to perform cohort CTA rasterization
|
||||
bool cohort_raster;
|
||||
int dp_blocks; /// Number of data-parallel thread blocks in the grid
|
||||
int dp_first_wave_tiles; /// Number of output tiles each CTA in the first DP wave will produce
|
||||
|
||||
/// Number of reduction blocks in the grid
|
||||
int reduction_blocks;
|
||||
|
||||
int sk_waves;
|
||||
int sk_tiles;
|
||||
int sk_big_blocks_per_region;
|
||||
int sk_iters_per_region;
|
||||
|
||||
/// Div/mod accelerators
|
||||
struct
|
||||
{
|
||||
FastDivmod tiled_shape_m;
|
||||
FastDivmod tiled_shape_n;
|
||||
FastDivmod tiled_cohort_shape_n;
|
||||
FastDivmod iters_per_tile;
|
||||
FastDivmod sk_iters_per_normal_block;
|
||||
FastDivmod sk_iters_per_big_block;
|
||||
FastDivmod sk_iters_per_region;
|
||||
FastDivmod sk_blocks_per_region;
|
||||
} div_mod;
|
||||
FastDivmod div_mod_sk_iters_per_normal_block;
|
||||
FastDivmod div_mod_sk_iters_per_big_block;
|
||||
FastDivmod div_mod_sk_iters_per_region;
|
||||
FastDivmod div_mod_sk_regions; //!! used in block map
|
||||
FastDivmod div_mod_sk_blocks_per_region; //!! used in block map
|
||||
|
||||
/// The batch count
|
||||
int batch_count;
|
||||
|
||||
|
||||
//
|
||||
@ -169,6 +167,43 @@ struct ThreadblockSwizzleStreamK {
  CUTLASS_HOST_DEVICE
  ThreadblockSwizzleStreamK() {}

  /// Returns the GEMM volume in thread block tiles
  CUTLASS_HOST_DEVICE
  GemmCoord tiled_shape() const
  {
    return GemmCoord(
      static_cast<int>(div_mod_tiled_shape_m),
      static_cast<int>(div_mod_tiled_shape_n),
      batch_count);
  }

  /// Number of iterations per output tile
  CUTLASS_HOST_DEVICE
  int iters_per_tile() const
  {
    return static_cast<int>(div_mod_iters_per_tile);
  }

  /// Number of iterations for normal SK-blocks
  CUTLASS_HOST_DEVICE
  int sk_iters_per_normal_block() const
  {
    return static_cast<int>(div_mod_sk_iters_per_normal_block);
  }

  /// Number of SK regions
  CUTLASS_HOST_DEVICE
  int sk_regions() const
  {
    return static_cast<int>(div_mod_sk_regions);
  }

  /// Number of SK blocks per region (splitting factor)
  CUTLASS_HOST_DEVICE
  int sk_blocks_per_region() const
  {
    return static_cast<int>(div_mod_sk_blocks_per_region);
  }

  //
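
The accessors above recover the integer extents straight from the FastDivmod members via static_cast<int>, so each divisor is stored once and reused for both division and extent queries. The standalone sketch below illustrates that idea; FastDivmodLite, its reciprocal scheme, and main() are illustrative assumptions, not CUTLASS's FastDivmod implementation.

// fast_divmod_sketch.cu  (illustrative only; FastDivmodLite is a hypothetical
// stand-in, not CUTLASS's FastDivmod)
#include <cstdint>
#include <cassert>
#include <cstdio>

// Precomputes a multiplicative reciprocal so quotient/remainder can be derived
// without a hardware divide, and keeps the divisor so callers can recover the
// original extent via a cast (as the accessors above do).
struct FastDivmodLite {
  int divisor = 1;
  uint64_t multiplier = 0;
  unsigned shift = 0;

  FastDivmodLite() = default;
  explicit FastDivmodLite(int d) : divisor(d) {
    // shift = ceil(log2(d))
    for (shift = 0; (1u << shift) < static_cast<unsigned>(d); ++shift) {}
    multiplier = ((1ull << (32 + shift)) + d - 1) / d;   // ceil(2^(32+shift) / d)
  }

  int div(int dividend) const {
    return static_cast<int>((dividend * multiplier) >> (32 + shift));
  }

  void operator()(int &quotient, int &remainder, int dividend) const {
    quotient = div(dividend);
    remainder = dividend - (quotient * divisor);
  }

  // Recover the stored extent, mirroring static_cast<int>(div_mod_...) above
  explicit operator int() const { return divisor; }
};

int main() {
  FastDivmodLite div_mod_iters_per_tile(7);   // e.g., 7 mainloop iterations per tile
  int tile_idx, iter_in_tile;
  div_mod_iters_per_tile(tile_idx, iter_in_tile, 23);
  assert(tile_idx == 3 && iter_in_tile == 2);
  printf("iters_per_tile = %d\n", static_cast<int>(div_mod_iters_per_tile));
  return 0;
}
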
@ -179,26 +214,27 @@ struct ThreadblockSwizzleStreamK {
  void Print()
  {
#ifndef __CUDA_ARCH__
    int tiles = tiled_shape.m() * tiled_shape.n();
    auto tiles = tiled_shape().mn().product();
    std::cout <<
      "problem_size: (" << problem_size.m() << "," << problem_size.n() << ")" <<
      ", reduction_blocks: " << reduction_blocks <<
      ", dp_blocks: " << dp_blocks <<
      ", sk_blocks_per_region: " << sk_blocks_per_region <<
      ", sk_regions: " << sk_regions <<
      ", sk_waves: " << sk_waves <<
      ", sk_iters_per_normal_block: " << sk_iters_per_normal_block <<
      ", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
      ", dp_first_wave_tiles: " << dp_first_wave_tiles <<
      ", tiled_shape: (" << tiled_shape.m() << "," << tiled_shape.n() << ")" <<
      ", tiled_shape: (" << tiled_shape().m() << "," << tiled_shape().n() << ")" <<
      ", tiles: " << tiles <<
      ", iters_per_tile: " << iters_per_tile <<
      ", dp_tiles: " << tiles - sk_tiles <<
      ", sk_tiles: " << sk_tiles <<
      ", avail_sms: " << avail_sms <<
      ", iters_per_tile: " << iters_per_tile() <<
      ", reduction_blocks: " << reduction_blocks <<
      ", dp_blocks: " << dp_blocks <<
      ", dp_waves: " << dp_blocks / avail_sms <<
      ", dp_first_wave_tiles: " << dp_first_wave_tiles <<
      ", sk_blocks_per_region: " << sk_blocks_per_region() <<
      ", sk_regions: " << sk_regions() <<
      ", sk_waves: " << sk_waves <<
      ", sk_iters_per_normal_block: " << sk_iters_per_normal_block() <<
      ", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
      ", remap_block_indices: " << remap_block_indices <<
      ", cohort_raster: " << cohort_raster <<
      ", sm_occupancy: " << sm_occupancy <<
      ", avail_sms: " << avail_sms <<
      ", cohort_raster: " << cohort_raster <<
      ", num_blocks: " << get_num_blocks() <<
      "\n\n";
#endif
@ -368,30 +404,37 @@ struct ThreadblockSwizzleStreamK {
    GemmUniversalMode const mode_,
    GemmCoord const problem_size_,
    GemmCoord const tile_size_,
    int const batch_count_,     /// Batch count (when mode_ == GemmUniversalMode::kBatched) or split-K-override splitting factor (when mode_ == GemmUniversalMode::kGemm)
    int const batch_split_,     /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K)
    int const sm_occupancy_,
    int const avail_sms_)
    int const device_sms_,
    int const avail_sms_)       /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
  :
    problem_size(problem_size_),
    tiled_shape(
      (problem_size.m() + tile_size_.m() - 1) / tile_size_.m(),
      (problem_size.n() + tile_size_.n() - 1) / tile_size_.n(),
      (mode_ == GemmUniversalMode::kBatched) ? batch_count_ : 1),
    iters_per_tile((problem_size.k() + tile_size_.k() - 1) / tile_size_.k()),
    batch_count((mode_ == GemmUniversalMode::kBatched) ? batch_split_ : 1),
    reduction_blocks(0),
    dp_blocks(0),
    dp_first_wave_tiles(1),     // Default: one tile per DP-block in the first wave of DP blocks
    sk_tiles(0),
    sk_regions(1),              // Default: a single region of iteration space (across all SK tiles)
    sk_blocks_per_region(0),
    sk_big_blocks_per_region(0),
    sk_iters_per_region(0),
    sk_iters_per_normal_block(0),
    sk_waves(0),
    sm_occupancy(sm_occupancy_),
    remap_block_indices(false),
    avail_sms(fast_max(1, avail_sms_)),
    cohort_raster(false)
  {
    int gpu_occupancy = device_sms_ * sm_occupancy;
    int iters_per_tile = (problem_size.k() + tile_size_.k() - 1) / tile_size_.k();
    int sk_iters_per_normal_block = 0;

    int sk_regions = 1;         // Default: a single region of iteration space (across all SK tiles)
    int sk_blocks_per_region = 0;

    GemmCoord tiled_shape(
      (problem_size.m() + tile_size_.m() - 1) / tile_size_.m(),
      (problem_size.n() + tile_size_.n() - 1) / tile_size_.n(),
      batch_count);

    size_t problem_bytes =
      (sizeof(typename GemmKernel::ElementC) * problem_size.m() * problem_size.n()) +
      (sizeof(typename GemmKernel::ElementA) * problem_size.m() * problem_size.k()) +
@ -401,7 +444,6 @@ struct ThreadblockSwizzleStreamK {

    float flops_per_byte = float(problem_flops) / float(problem_bytes);

    int gpu_occupancy = avail_sms * sm_occupancy;
    int output_tiles = tiled_shape.m() * tiled_shape.n();
    int waves = (output_tiles + avail_sms - 1) / avail_sms;
    float dp_efficiency = float(output_tiles) / float(waves * avail_sms);
@ -414,14 +456,15 @@ struct ThreadblockSwizzleStreamK {
    int dp_tiles = output_tiles;    // Number of data-parallel tiles
    int sk_blocks = 0;              // Number of thread blocks to produce the remaining SK tiles

    // kGemm mode allows for SK load balancing
    // Only kGemm mode allows for SK load balancing
    if (mode_ == GemmUniversalMode::kGemm)
    {
      if (batch_count_ > 1)
      int split_factor = batch_split_;
      if (split_factor > 1)
      {
        // Split-K override
        dp_tiles = 0;
        sk_blocks = output_tiles * batch_count_;
        sk_blocks = output_tiles * split_factor;
      }
      else if ((kReductionStrategy != kNone) &&   // Load-balancing strategy statically enabled
               (avail_sms > 1))                   // Plurality of SMs to load balance across
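
The branch above turns a tile-splitting factor greater than one into a classic split-K dispatch: no data-parallel tiles remain and every output tile is covered by split_factor cooperating blocks. Below is a minimal sketch of that decomposition decision; choose_decomposition and Decomposition are hypothetical names, and the split_factor == 1 branch is only a placeholder for the stream-K heuristics that follow in the real code.

// split_dispatch_sketch.cu  (illustrative decomposition only, not the library's dispatch code)
#include <cstdio>

struct Decomposition {
  int dp_tiles;    // output tiles produced by independent data-parallel blocks
  int sk_blocks;   // blocks cooperating on the remaining (stream-K) tiles
};

// split_factor > 1 emulates split-K by covering every output tile with
// split_factor cooperating blocks; split_factor == 1 leaves the choice to the
// stream-K load-balancing heuristics (not reproduced here).
Decomposition choose_decomposition(int output_tiles, int split_factor) {
  if (split_factor > 1) {
    return {0, output_tiles * split_factor};
  }
  return {output_tiles, 0};   // placeholder: the real heuristic may still carve out SK tiles
}

int main() {
  Decomposition d = choose_decomposition(/*output_tiles=*/108, /*split_factor=*/4);
  printf("dp_tiles=%d sk_blocks=%d\n", d.dp_tiles, d.sk_blocks);   // 0, 432
  return 0;
}
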
@ -462,24 +505,39 @@ struct ThreadblockSwizzleStreamK {
      sk_big_blocks_per_region = sk_big_blocks / sk_regions;
      sk_iters_per_region = sk_iters / sk_regions;

      div_mod.sk_iters_per_normal_block = FastDivmod(sk_iters_per_normal_block);
      div_mod.sk_iters_per_big_block = FastDivmod(sk_iters_per_normal_block + 1);
      div_mod.sk_iters_per_region = FastDivmod(sk_iters_per_region);
      div_mod.sk_blocks_per_region = FastDivmod(sk_blocks_per_region);

      // Separate reduction heuristic
      // Use a separate reduction wave when all of:
      //   - Non-atomic reduction strategy
      //   - The number of SK waves won't fully occupy the GPU (otherwise we don't have
      //     a strong-scaling case for more parallel reduction)
      //   - More than three peers working on an SK tile.  (This occurs when the ratio of
      //     SK-blocks to SK-tiles > 2, as a single tile may be covered by four SK-blocks,
      //     e.g.: [partial-block | block | block | partial-block].)  With three or
      //     fewer peers, the two non-finishing SK-blocks are not expected to contend.
      if ((kReductionStrategy == kMixed) &&
          (sk_blocks > 2 * sk_tiles))     // Use a separate reduction wave whenever we would have more than three
                                          // peers working on an SK tile.  (This occurs when the ratio of SK-blocks
                                          // to SK-tiles > 2, as a single tile may be covered by four SK-blocks,
                                          // e.g.: [partial-block | block | block | partial-block].)  With three or
                                          // fewer peers, the two non-finishing SK-blocks are not expected to contend.
          (sk_waves < sm_occupancy) &&
          (sk_blocks > 2 * sk_tiles))
      {
        // Launch a reduction block every accumulator fragment in each SK-tile
        // Launch a reduction block for every accumulator fragment in each SK-tile
        static const int kAccumulatorFragments = GemmKernel::Epilogue::kAccumulatorFragments;
        reduction_blocks = sk_tiles * kAccumulatorFragments;
      }

      // When we have a multi-occupancy kernel and at least two waves of active blocks (where
      // at least one wave is SK blocks), we need to (1) dispatch at least four waves, and (2)
      // remap the block indices so that we can reliably spread the SK blocks evenly across the
      // device's first SM occupancy valence.  Also see get_num_blocks() and get_block_idx().
      remap_block_indices = (
        (sm_occupancy > 1) &&
        (device_sms_ == avail_sms) &&
        (get_num_active_blocks() > avail_sms * 2));

      // Initialize fast div/mod members related to SK
      div_mod_sk_iters_per_normal_block = FastDivmod(sk_iters_per_normal_block);
      div_mod_sk_iters_per_big_block = FastDivmod(sk_iters_per_normal_block + 1);
      div_mod_sk_iters_per_region = FastDivmod(sk_iters_per_region);
      div_mod_sk_regions = FastDivmod(sk_regions);
      div_mod_sk_blocks_per_region = FastDivmod(sk_blocks_per_region);
    }

    //
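
The two heuristics above read as simple predicates: a separate reduction wave is used only for the non-atomic (mixed) strategy when SK waves under-occupy the GPU and more than three peers share a tile, and block-index remapping is enabled only for multi-occupancy kernels dispatched across the full device with more than two waves of useful blocks. A hedged restatement with hypothetical helper names and example numbers:

// dispatch_heuristics_sketch.cu  (illustrative restatement of the two predicates above)
#include <cstdio>

// A separate reduction wave is worthwhile only when SK blocks under-occupy the
// GPU and more than three peers share a tile (sk_blocks > 2 * sk_tiles).
bool use_separate_reduction(bool mixed_strategy, int sk_waves, int sm_occupancy,
                            int sk_blocks, int sk_tiles) {
  return mixed_strategy && (sk_waves < sm_occupancy) && (sk_blocks > 2 * sk_tiles);
}

// Block-index remapping only pays off for multi-occupancy kernels dispatched
// across the whole device with more than two waves of useful blocks.
bool use_block_remap(int sm_occupancy, int device_sms, int avail_sms, int active_blocks) {
  return (sm_occupancy > 1) && (device_sms == avail_sms) && (active_blocks > avail_sms * 2);
}

int main() {
  // e.g., 108 SMs, 2 CTAs/SM, 60 SK tiles covered by 216 SK blocks in one wave
  printf("separate reduction: %d\n",
         use_separate_reduction(true, /*sk_waves=*/1, /*sm_occupancy=*/2,
                                /*sk_blocks=*/216, /*sk_tiles=*/60));   // 1
  printf("remap blocks: %d\n",
         use_block_remap(/*sm_occupancy=*/2, /*device_sms=*/108,
                         /*avail_sms=*/108, /*active_blocks=*/324));    // 1
  return 0;
}
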
@ -491,7 +549,7 @@ struct ThreadblockSwizzleStreamK {
      cutlass::gemm::GemmCoord tiled_cohort_shape(
        (tiled_shape.m() + kCohortCtasM - 1) / kCohortCtasM,
        (tiled_shape.n() + kCohortCtasN - 1) / kCohortCtasN,
        batch_count_);
        tiled_shape.k());
      int cohort_blocks = (tiled_cohort_shape.m() * tiled_cohort_shape.n()) * kCtasPerCohort;
      float cohort_efficiency = float(dp_blocks) / float(cohort_blocks);
@ -511,11 +569,12 @@ struct ThreadblockSwizzleStreamK {
          {
            sk_in_range = false;
          }
        }

        // Decide if we're going to be doing cohort raster
        if (sk_in_range &&
          (dp_blocks >= gpu_occupancy) &&
          (dp_blocks >= gpu_occupancy * 2) &&
          (cohort_efficiency > 0.85f))
        {
          cohort_raster = true;
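
Cohort rasterization is only turned on when the data-parallel grid fills the padded cohort grid well (efficiency above 0.85) and there are at least two full waves of DP blocks. A small worked computation of cohort_efficiency follows; the cohort extents used here are example assumptions rather than the kernel's actual constants.

// cohort_raster_sketch.cu  (illustrative; kCohortCtasM/N below are assumed values)
#include <cstdio>

int main() {
  const int kCohortCtasM = 2, kCohortCtasN = 4;              // assumed cohort footprint
  const int kCtasPerCohort = kCohortCtasM * kCohortCtasN;    // 8 CTAs per cohort

  int tiled_m = 20, tiled_n = 30;                            // DP tile grid
  int dp_blocks = tiled_m * tiled_n;                         // 600 data-parallel blocks

  // Cohort grid rounds each dimension up to whole cohorts
  int cohort_m = (tiled_m + kCohortCtasM - 1) / kCohortCtasM;   // 10
  int cohort_n = (tiled_n + kCohortCtasN - 1) / kCohortCtasN;   // 8
  int cohort_blocks = cohort_m * cohort_n * kCtasPerCohort;     // 640

  float cohort_efficiency = float(dp_blocks) / float(cohort_blocks);  // 0.9375
  printf("cohort_efficiency = %.4f -> cohort raster %s\n",
         cohort_efficiency, (cohort_efficiency > 0.85f) ? "on" : "off");
  return 0;
}
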
@ -537,92 +596,54 @@ struct ThreadblockSwizzleStreamK {
    }

    // Setup fast-div/mod for device-side usage
    div_mod.tiled_shape_m = FastDivmod(tiled_shape.m());
    div_mod.tiled_shape_n = FastDivmod(tiled_shape.n());
    div_mod.tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
    div_mod.iters_per_tile = FastDivmod(iters_per_tile);
    div_mod_tiled_shape_m = FastDivmod(tiled_shape.m());
    div_mod_tiled_shape_n = FastDivmod(tiled_shape.n());
    div_mod_tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
    div_mod_iters_per_tile = FastDivmod(iters_per_tile);
  }


  /// Constructor: *ImplicitGemm* Conv2d problem size: conv_operator(NPQK, NHWC, KRSC)
  template <typename GemmKernel>
  ThreadblockSwizzleStreamK(
    KernelTraits<GemmKernel> kernel_traits_,
    GemmUniversalMode mode_,
    cutlass::conv::Operator conv_operator,
    cutlass::conv::Conv2dProblemSize const &problem_size_,
    GemmCoord tile_size_,
    int batch_count_,
    int sm_occupancy_,
    int avail_sms_,       /// When the below are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance
    int dp_tiles_ = -1,   /// Dispatch override: number of output tiles to assign to independent, data-parallel CTAs
    int sk_blocks_ = -1)  /// Dispatch override: number of Stream-K CTAs for cooperatively processing the remaining output tiles
  :
    ThreadblockSwizzleStreamK(
      kernel_traits_,
      mode_,
      cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size_),
      tile_size_,
      batch_count_,
      sm_occupancy_,
      avail_sms_,
      dp_tiles_,
      sk_blocks_)
  {}


  /// Constructor: *ImplicitGemm* Conv3d problem size: conv_operator(NZPQK, NDHWC, KTRSC)
  template <typename GemmKernel>
  ThreadblockSwizzleStreamK(
    KernelTraits<GemmKernel> kernel_traits_,
    GemmUniversalMode mode_,
    cutlass::conv::Operator conv_operator,
    cutlass::conv::Conv3dProblemSize const &problem_size_,
    GemmCoord tile_size_,
    int batch_count_,
    int sm_occupancy_,
    int avail_sms_,       /// When the below are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance
    int dp_tiles_ = -1,   /// Dispatch override: number of output tiles to assign to independent, data-parallel CTAs
    int sk_blocks_ = -1)  /// Dispatch override: number of Stream-K CTAs for cooperatively processing the remaining output tiles
  :
    ThreadblockSwizzleStreamK(
      kernel_traits_,
      mode_,
      cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size_),
      tile_size_,
      batch_count_,
      sm_occupancy_,
      avail_sms_,
      dp_tiles_,
      sk_blocks_)
  {}

  /// Number of blocks performing useful work
  int get_num_active_blocks() const
  {
    return (sk_waves * avail_sms) + dp_blocks + reduction_blocks;
  }

  /// Obtains number of threadblocks per GEMM
  int get_num_blocks() const
  {
    int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;

    if (work_blocks <= avail_sms * 2)
    int active_blocks = get_num_active_blocks();
    if (remap_block_indices)
    {
      return work_blocks;
      // Add padding blocks if we are performing remapping in order to dispatch a grid of at least four waves
      return fast_max(active_blocks, avail_sms * 4);
    }

    return fast_max(work_blocks, avail_sms * 4);
    return active_blocks;
  }


  /// Obtains grid extents in CTAs
  dim3 get_grid_dims() const
  {
    return dim3(get_num_blocks(), 1, tiled_shape.k());
    return dim3(get_num_blocks(), 1, batch_count);
  }


  // Guards needed for PyCUTLASS library generation
#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)

  //
  // Device-side interface
  //

  /// Proves to the compiler that val is warp-uniform
  CUTLASS_DEVICE
  int uniform(int val) const
  {
    return __shfl_sync(0xffffffff, val, 0);
  }

  /// Obtains number of threadblocks per GEMM
  CUTLASS_DEVICE
  int device_num_blocks() const
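
get_num_blocks() now pads the grid to at least four waves whenever block-index remapping is active, and the z grid dimension carries the batch count rather than the tiled k extent. A standalone sketch of the grid-sizing arithmetic, with num_blocks as a hypothetical helper:

// grid_size_sketch.cu  (illustrative restatement of the grid-sizing logic above)
#include <cstdio>

int num_blocks(int sk_waves, int avail_sms, int dp_blocks, int reduction_blocks,
               bool remap_block_indices) {
  int active_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;
  if (remap_block_indices) {
    // Pad to at least four waves so the remapped first two waves land on distinct SMs
    return (active_blocks > avail_sms * 4) ? active_blocks : avail_sms * 4;
  }
  return active_blocks;
}

int main() {
  int avail_sms = 108;
  int blocks = num_blocks(/*sk_waves=*/1, avail_sms, /*dp_blocks=*/96,
                          /*reduction_blocks=*/0, /*remap_block_indices=*/true);
  printf("gridDim = (%d, 1, batch_count)\n", blocks);   // 204 padded up to 432
  return 0;
}
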
@ -634,9 +655,16 @@ struct ThreadblockSwizzleStreamK {
  CUTLASS_DEVICE
  int get_sk_tile_idx(int iter) const
  {
    return div_mod.iters_per_tile.div(iter);
    int tile_idx = div_mod_iters_per_tile.div(iter);
    return uniform(tile_idx);
  }

  /// Obtains the batch index
  CUTLASS_DEVICE
  int get_batch_idx() const
  {
    return RematerializeBlockIdxZ();
  }

  /// Obtains the calling threadblock's tiled coordinates for the given tile index
  CUTLASS_DEVICE
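
uniform() broadcasts lane 0's value across the warp with __shfl_sync so the compiler can treat the returned tile and block indices as warp-uniform. A minimal CUDA sketch of that broadcast; the kernel and buffer handling here are illustrative, not the library's code.

// warp_uniform_sketch.cu  (minimal sketch of broadcasting a lane-0 value)
#include <cstdio>
#include <cuda_runtime.h>

__device__ int warp_uniform(int val) {
  // Broadcast lane 0's value across the full warp mask
  return __shfl_sync(0xffffffff, val, 0);
}

__global__ void demo(int *out, int iters_per_tile) {
  int iter = blockIdx.x * blockDim.x + threadIdx.x;
  int tile_idx = iter / iters_per_tile;       // stand-in for div_mod_iters_per_tile.div(iter)
  out[threadIdx.x] = warp_uniform(tile_idx);  // all 32 lanes now hold lane 0's tile_idx
}

int main() {
  int *d_out = nullptr;
  cudaMalloc((void **)&d_out, 32 * sizeof(int));
  demo<<<1, 32>>>(d_out, /*iters_per_tile=*/7);
  cudaDeviceSynchronize();
  int h_out[32];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("uniform tile_idx = %d\n", h_out[31]);   // 0 (lane 0 computed 0/7)
  cudaFree(d_out);
  return 0;
}
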
@ -644,12 +672,21 @@ struct ThreadblockSwizzleStreamK {
  {
    int m, n;

    // row-major raster
    div_mod_tiled_shape_n(m, n, tile_idx);

    if (tiled_shape().m() < tiled_shape().n())
    {
      // column-major raster
      div_mod_tiled_shape_m(n, m, tile_idx);
    }

    if (cohort_raster)
    {
      // tiled cohort raster
      int cohort_tile_idx = tile_idx / kCtasPerCohort;
      int cohort_grid_m, cohort_grid_n;
      div_mod.tiled_cohort_shape_n(cohort_grid_m, cohort_grid_n, cohort_tile_idx);
      div_mod_tiled_cohort_shape_n(cohort_grid_m, cohort_grid_n, cohort_tile_idx);

      int block_idx_cohort = tile_idx % kCtasPerCohort;
      int block_cohort_m = block_idx_cohort / kCohortCtasN;
@ -658,44 +695,46 @@ struct ThreadblockSwizzleStreamK {
      m = (cohort_grid_m * kCohortCtasM) + block_cohort_m;
      n = (cohort_grid_n * kCohortCtasN) + block_cohort_n;
    }
    else if (tiled_shape.m() < tiled_shape.n())
    {
      // column-major raster
      div_mod.tiled_shape_m(n, m, tile_idx);
    }
    else
    {
      // row-major raster
      div_mod.tiled_shape_n(m, n, tile_idx);
    }

    int block_idx_k = RematerializeBlockIdxZ();
    return GemmCoord{m, n, block_idx_k};
    return GemmCoord(m, n, get_batch_idx());
  }

  /// Obtains the calling threadblock's tiled coordinates for the given tile index (row-major rasterization)
  CUTLASS_DEVICE
  GemmCoord get_tile_offset_row_major(int tile_idx) const
  {
    // row-major raster
    int m, n;
    div_mod_tiled_shape_n(m, n, tile_idx);
    return GemmCoord(m, n, get_batch_idx());
  }

  /// Obtains calling threadblock's linear threadblock index
  CUTLASS_DEVICE
  int get_block_idx() const
  {
    int block_idx = RematerializeBlockIdxX();

    // Remap the block indices for the first two waves of thread blocks if
    // we have multi-occupancy and the grid constitutes four or more waves

    int block_idx = RematerializeBlockIdxX();
    int num_blocks = device_num_blocks();
    int dest_sm = block_idx / 2;
    int dest_wave = block_idx % 2;
    int remapped_block_idx = dest_sm + (dest_wave * avail_sms);

    if ((sm_occupancy > 1) &&
        (num_blocks >= avail_sms * 4) &&
        (block_idx < avail_sms * 2))
    if (remap_block_indices && (block_idx < avail_sms * 2))
    {
      int dest_sm = block_idx / 2;
      int dest_wave = block_idx % 2;
      int remapped_block_idx = dest_sm + (dest_wave * avail_sms);
      block_idx = remapped_block_idx;
    }

    // Block-index is blockIdx.x for DP blocks
    return block_idx;

    // Remap block indices to interleave SK regions to limit intra-region waiting
    if (block_idx < sk_regions() * sk_blocks_per_region())
    {
      int block_in_region;
      int region;
      div_mod_sk_regions(block_in_region, region, block_idx);
      block_idx = (region * sk_blocks_per_region()) + block_in_region;
    }

    return uniform(block_idx);
  }

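
The first-two-wave remapping above folds block indices so that blocks 2*s and 2*s+1 resolve to the same SM slot in waves 0 and 1, which spreads SK blocks evenly across the device's first occupancy valence. A standalone restatement of that arithmetic with a hypothetical helper name and a small example:

// block_remap_sketch.cu  (illustrative restatement of the index remapping above)
#include <cstdio>

// Blocks in the first two waves are paired up so that block 2*s and 2*s+1 both
// resolve to SM s: one in wave 0 (index s) and one in wave 1 (index s + avail_sms).
int remap_first_waves(int block_idx, int avail_sms) {
  if (block_idx < avail_sms * 2) {
    int dest_sm = block_idx / 2;
    int dest_wave = block_idx % 2;
    return dest_sm + (dest_wave * avail_sms);
  }
  return block_idx;
}

int main() {
  int avail_sms = 4;
  for (int b = 0; b < 8; ++b) {
    printf("block %d -> %d\n", b, remap_first_waves(b, avail_sms));
  }
  // 0->0, 1->4, 2->1, 3->5, 4->2, 5->6, 6->3, 7->7
  return 0;
}
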
@ -705,19 +744,21 @@ struct ThreadblockSwizzleStreamK {
  {
    int region_idx;
    int iter_in_region;
    div_mod.sk_iters_per_region(region_idx, iter_in_region, iter);
    div_mod_sk_iters_per_region(region_idx, iter_in_region, iter);

    int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block) + sk_big_blocks_per_region;     // number of iterations in the region's big blocks
    int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block()) + sk_big_blocks_per_region;   // number of iterations in the region's big blocks
    int normal_block_iters = iter_in_region - big_block_iters;                                                    // number of iterations in the region's normal blocks

    int big_block_idx_in_region = div_mod.sk_iters_per_big_block.div(iter_in_region);
    int normal_block_idx_in_region = sk_big_blocks_per_region + div_mod.sk_iters_per_normal_block.div(normal_block_iters);
    int big_block_idx_in_region = div_mod_sk_iters_per_big_block.div(iter_in_region);
    int normal_block_idx_in_region = sk_big_blocks_per_region + div_mod_sk_iters_per_normal_block.div(normal_block_iters);

    int block_idx_in_region = (big_block_idx_in_region < sk_big_blocks_per_region) ?
      big_block_idx_in_region :
      normal_block_idx_in_region;

    return (sk_blocks_per_region * region_idx) + block_idx_in_region;
    int owning_block_idx = (sk_blocks_per_region() * region_idx) + block_idx_in_region;

    return owning_block_idx;
  }

  /// Obtains iteration extents for the given SK block index
@ -729,12 +770,12 @@ struct ThreadblockSwizzleStreamK {
  {
    int region_idx;
    int block_idx_in_region;
    div_mod.sk_blocks_per_region(region_idx, block_idx_in_region, sk_block_idx);
    div_mod_sk_blocks_per_region(region_idx, block_idx_in_region, sk_block_idx);

    block_iter_begin = (region_idx * sk_iters_per_region) + (block_idx_in_region * sk_iters_per_normal_block);
    block_iter_begin = (region_idx * sk_iters_per_region) + (block_idx_in_region * sk_iters_per_normal_block());

    // Adjust extents for the first "num_big_blocks" blocks that get one extra iteration
    int block_iters = sk_iters_per_normal_block;
    int block_iters = sk_iters_per_normal_block();
    if (block_idx_in_region < sk_big_blocks_per_region) {
      // This is a +1 iteration block
      block_iter_begin += block_idx_in_region;
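
The extents logic above assigns each SK block sk_iters_per_normal_block() iterations and gives the first sk_big_blocks_per_region blocks in a region one extra iteration. The hunk ends before the remaining adjustment, so the else branch in the sketch below is an assumption about how later blocks skip past those extra iterations; iter_extents and its parameters are hypothetical names.

// sk_extents_sketch.cu  (worked example of the big/normal block split; assumptions noted above)
#include <cstdio>

// With iters_per_normal_block = 4 and big_blocks = 2, blocks 0 and 1 own 5
// iterations each and the remaining blocks own 4.
void iter_extents(int block_idx_in_region, int region_iter_begin,
                  int iters_per_normal_block, int big_blocks,
                  int &block_iter_begin, int &block_iter_end) {
  block_iter_begin = region_iter_begin + (block_idx_in_region * iters_per_normal_block);
  int block_iters = iters_per_normal_block;
  if (block_idx_in_region < big_blocks) {
    block_iter_begin += block_idx_in_region;   // earlier big blocks each added one iteration
    block_iters += 1;                          // this is a +1 iteration block
  }
  else {
    block_iter_begin += big_blocks;            // assumed: skip past all of the extra iterations
  }
  block_iter_end = block_iter_begin + block_iters;
}

int main() {
  for (int b = 0; b < 4; ++b) {
    int begin, end;
    iter_extents(b, /*region_iter_begin=*/0, /*iters_per_normal_block=*/4,
                 /*big_blocks=*/2, begin, end);
    printf("block %d: iters [%d, %d)\n", b, begin, end);
  }
  // block 0: [0, 5), block 1: [5, 10), block 2: [10, 14), block 3: [14, 18)
  return 0;
}
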
@ -756,10 +797,12 @@ struct ThreadblockSwizzleStreamK {
      return block_idx;
    }

    int iter = tile_idx * iters_per_tile;
    int iter = tile_idx * iters_per_tile();
    return get_sk_block_idx(iter);
  }

#endif // defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)

};

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -50,9 +50,9 @@ def EpilogueAlignment(max_alignment, tile, epilogue_steps = 8):
#
def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
  alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \
  swizzling_functor = SwizzlingFunctor.Identity8):
#  swizzling_functor = SwizzlingFunctor.Identity8):
  # Use StreamK decomposition for basic GEMMs
#  swizzling_functor = SwizzlingFunctor.StreamK):
  swizzling_functor = SwizzlingFunctor.StreamK):

  if complex_transforms is None:
    complex_transforms = [(ComplexTransform.none, ComplexTransform.none),]
@ -4600,6 +4600,7 @@ if __name__ == "__main__":
  parser.add_argument('--selected-kernel-list', type=str, default=None, required=False,
                      help='Specify the output log file containing all enabled kernels in this build')
  parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels")
  parser.add_argument("--disable-full-archs-compilation", action="store_true", required=False, help="Disable compilation for every arch in --architectures")

  args = parser.parse_args()