streamk example and performance tuning (#760)

* streamk example and performance tuning
* one missing file

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>

This commit is contained in:
parent a1046d49c1
commit 764b840d6f
@@ -1,7 +1,7 @@
 # NVIDIA CUTLASS Changelog
 
 ## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19)
-* Stream-K, which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
+* [Stream-K](/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
 * [Fused multi-head attention Kernel](/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel.
 * [Dual GEMM](/examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. Two GEMMs has no producer-consumer dependency.
 * Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
@@ -39,7 +39,7 @@ supported at each level of the execution model hierarchy.
 # What's New in CUTLASS 2.11
 
 CUTLASS 2.11 is an update to CUTLASS adding:
-- Stream-K, which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
+- [Stream-K](/examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one.
 - [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one for fixed sequence lengths, and another for variable sequence lengths.
 - [Dual GEMM](/examples/45_dual_gemm). It can run two GEMMs that share the same left input matrix in one kernel.
 - Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.

@@ -115,7 +115,7 @@ any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
 |NVIDIA A100|8.0|11.0|11.0|
 |NVIDIA A10 |8.6|11.1|11.1|
 |NVIDIA GeForce 3090|8.6|11.1|11.1|
-|NVIDIA H100 PCIe|9.0|11.8|Double-precision: 11.8|
+|NVIDIA H100 PCIe|9.0|11.8|Double-precision: 11.8; Mixed precision: 12.0|
 
 # Documentation
 
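Editor's note: the Stream-K bullet above is easiest to picture next to the decompositions it generalizes. The toy host program below is illustration only (not CUTLASS code; all names are made up): data-parallel gives each output tile to one CTA, Split-K cuts every tile's K range by a fixed factor chosen at launch, and Stream-K hands each CTA a contiguous slice of the global MAC-iteration space, which may start or end in the middle of a tile.

// Toy illustration (not CUTLASS code): how classic data-parallel, Split-K, and
// Stream-K carve the same tiled GEMM into per-CTA work. All names are made up.
#include <cstdio>

int main() {
  int tiles          = 9;    // number of output tiles (e.g. a 3x3 grid of 128x128 tiles)
  int iters_per_tile = 16;   // MAC-iterations in the K dimension per tile
  int ctas           = 4;    // CTAs that can run concurrently in one wave

  // Classic data-parallel: one CTA per tile; a partial last wave runs under-occupied.
  printf("data-parallel : %d CTAs, %d iters each\n", tiles, iters_per_tile);

  // Split-K with factor s: every tile's K range is cut into s fixed pieces.
  int s = 2;
  printf("split-K (s=%d) : %d CTAs, %d iters each, plus a reduction pass\n",
         s, tiles * s, iters_per_tile / s);

  // Stream-K: each CTA gets a contiguous slice of the *global* iteration space,
  // so a slice may cross tile boundaries; partial tiles are reduced with peers.
  int total = tiles * iters_per_tile;
  for (int cta = 0; cta < ctas; ++cta) {
    int begin = total * cta / ctas;
    int end   = total * (cta + 1) / ctas;
    printf("stream-K CTA %d: iters [%d, %d) covering tiles %d..%d\n",
           cta, begin, end, begin / iters_per_tile, (end - 1) / iters_per_tile);
  }
  return 0;
}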
examples/47_ampere_gemm_universal_streamk/CMakeLists.txt (new file, 35 lines)
@@ -0,0 +1,35 @@

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
  47_ampere_gemm_universal_streamk
  ampere_gemm_universal_streamk.cu
  )
@@ -0,0 +1,592 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/***************************************************************************************************
  Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the
  "classic data-parallel" and "Split-K" decompositions.

  For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition
  for Dense Matrix-Matrix Multiplication on the GPU" <todo: link>

  Requires NVIDIA Ampere or newer device (SM80+).

  - To lock persistence mode, power (400W), clocks (1005MHz) for evaluation (assumes device 0 and A100)

      cutlass$ sudo nvidia-smi -pm 1 -i 0

      cutlass$ sudo nvidia-smi -i 0 -pl 400

      cutlass$ sudo nvidia-smi -i 0 -lgc 1005

  - Build and run:

      cutlass$ mkdir build

      cutlass$ cd build

      cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80

      cutlass/build$ make 47_ampere_gemm_universal_streamk

      cutlass/build$ ./examples/47_ampere_gemm_universal_streamk/47_ampere_gemm_universal_streamk

        10000 timing iterations of 2048 x 2048 x 2048 matrix-matrix multiply

        Basic data-parallel GEMM
          Disposition: Passed
          Avg runtime: 0.112633 ms
          GFLOPs: 152530

        StreamK GEMM with default load-balancing
          Disposition: Passed
          Avg runtime: 0.0941929 ms
          GFLOPs: 182390
          Speedup vs Basic-DP: 1.196

        StreamK emulating basic data-parallel GEMM
          Disposition: Passed
          Avg runtime: 0.113119 ms
          GFLOPs: 151875
          Speedup vs Basic-DP: 0.996

        Basic split-K GEMM with tile-splitting factor 2
          Disposition: Passed
          Avg runtime: 0.104772 ms
          GFLOPs: 163973

        StreamK emulating Split-K GEMM with tile-splitting factor 2
          Disposition: Passed
          Avg runtime: 0.105379 ms
          GFLOPs: 163029
          Speedup vs Basic-SplitK: 0.994

**************************************************************************************************/

#include <iostream>
#include <string>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "helper.h"


/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations (cutlass_tensorop_h16816gemm_128x128_32x4_nn_align8)
/////////////////////////////////////////////////////////////////////////////////////////////////

// A matrix configuration
using ElementA = cutlass::half_t;                 // Element type for A matrix operand
using LayoutA  = cutlass::layout::RowMajor;       // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;   // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

// B matrix configuration
using ElementB = cutlass::half_t;                 // Element type for B matrix operand
using LayoutB  = cutlass::layout::RowMajor;       // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;   // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// C/D matrix configuration
using ElementC = cutlass::half_t;                 // Element type for C and D matrix operands
using LayoutC  = cutlass::layout::RowMajor;       // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;   // Memory access granularity/alignment of C/D matrices in units of elements (up to 16 bytes)

// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = cutlass::half_t;                         // Element type for internal accumulation
using ArchTag            = cutlass::arch::Sm80;                     // Tag indicating the minimum SM that supports the intended feature
using OperatorClass      = cutlass::arch::OpClassTensorOp;          // Operator class tag
using ThreadblockShape   = cutlass::gemm::GemmShape<128, 128, 32>;  // Threadblock-level tile size (concept: GemmShape)
using WarpShape          = cutlass::gemm::GemmShape<64, 64, 32>;    // Warp-level tile size (concept: GemmShape)
using InstructionShape   = cutlass::gemm::GemmShape<16, 8, 16>;     // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages  = 4;                                       // Number of global->shared pipeline stages used in the GEMM mainloop

// Epilogue output operator
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementC,             // Element type for C and D matrix operands
    AlignmentC,           // Memory access granularity of C and D matrix in units of elements
    ElementAccumulator,   // Element type for internal accumulation
    ElementAccumulator>;  // Data type used to compute linear combination

// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
    ElementA,
    LayoutA,
    ElementB,
    LayoutB,
    ElementC,
    LayoutC,
    ElementAccumulator,
    ElementAccumulator>;

// Classic data-parallel device GEMM implementation type
using DeviceGemmBasic = cutlass::gemm::device::GemmUniversal<
    ElementA, LayoutA,
    ElementB, LayoutB,
    ElementC, LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    NumStages,
    AlignmentA,
    AlignmentB>;

// StreamK device GEMM implementation type
using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversal<
    ElementA, LayoutA,
    ElementB, LayoutB,
    ElementC, LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,   // <-- Only difference
    NumStages,
    AlignmentA,
    AlignmentB>;


/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Result structure
struct Result
{
  double avg_runtime_ms;
  double gflops;
  cutlass::Status status;
  cudaError_t error;
  bool passed;

  Result(
    double avg_runtime_ms = 0,
    double gflops = 0,
    cutlass::Status status = cutlass::Status::kSuccess,
    cudaError_t error = cudaSuccess)
  :
    avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true)
  {}
};


/// Command line options parsing
struct Options
{
  std::string command_name;
  bool help;
  cutlass::gemm::GemmCoord problem_size;
  float alpha;
  float beta;
  int split_k_factor;
  int avail_sms;
  bool reference_check;
  int iterations;

  cutlass::HostTensor<ElementA, LayoutA> tensor_a;
  cutlass::HostTensor<ElementB, LayoutB> tensor_b;
  cutlass::HostTensor<ElementC, LayoutC> tensor_c;
  cutlass::HostTensor<ElementC, LayoutC> tensor_d;
  cutlass::HostTensor<ElementC, LayoutC> tensor_ref_d;

  Options(std::string command_name) :
    command_name(command_name),
    help(false),
    problem_size({2048, 2048, 2048}),
    alpha(1.0f),
    beta(0.0f),
    split_k_factor(1),
    avail_sms(-1),              // Number of device SMs to use is unlimited
    reference_check(true),
    iterations(10000)
  {}

  bool valid() const
  {
    return true;
  }

  void parse(int argc, char const **args)
  {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
    }

    cmd.get_cmd_line_argument("m", problem_size.m());
    cmd.get_cmd_line_argument("n", problem_size.n());
    cmd.get_cmd_line_argument("k", problem_size.k());
    cmd.get_cmd_line_argument("alpha", alpha);
    cmd.get_cmd_line_argument("beta", beta);
    cmd.get_cmd_line_argument("split", split_k_factor);
    cmd.get_cmd_line_argument("iterations", iterations);
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const
  {
    out
      << "Performs a GEMM computation.\n"
      << "\n"
      << "Options:\n"
      << "\n"
      << "  --help                      If specified, displays this usage statement.\n\n"
      << "  --m=<int>                   GEMM M dimension\n"
      << "  --n=<int>                   GEMM N dimension\n"
      << "  --k=<int>                   GEMM K dimension\n"
      << "  --alpha=<f32>               Epilogue scalar alpha\n"
      << "  --beta=<f32>                Epilogue scalar beta\n\n"
      << "  --split=<int>               Split-K factor to emulate\n\n"
      << "  --iterations=<int>          Number of profiling iterations to perform.\n\n";

    out
      << "\n\nExamples:\n\n"
      << "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";

    return out;
  }

  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const
  {
    // Two flops per multiply-add
    return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s;
  }
};


/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options
typename DeviceGemmBasic::Arguments args_from_options(
    const DeviceGemmBasic &device_gemm,
    const Options &options,
    cutlass::HostTensor<ElementA, LayoutA> &tensor_a,
    cutlass::HostTensor<ElementB, LayoutB> &tensor_b,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_c,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_d)
{
  return typename DeviceGemmBasic::Arguments(
    cutlass::gemm::GemmUniversalMode::kGemm,  // universal mode
    options.problem_size,                     // problem_size
    options.split_k_factor,                   // batch count / splitk slices
    {                                         // epilogue parameters
      ElementAccumulator(options.alpha),
      ElementAccumulator(options.beta)
    },
    tensor_a.device_data(),                   // ptr_A
    tensor_b.device_data(),                   // ptr_B
    tensor_c.device_data(),                   // ptr_C
    tensor_d.device_data(),                   // ptr_D
    options.problem_size.mk().product(),      // batch_stride_A
    options.problem_size.nk().product(),      // batch_stride_B
    options.problem_size.mn().product(),      // batch_stride_C
    options.problem_size.mn().product(),      // batch_stride_D
    tensor_a.layout().stride(0),              // stride_a
    tensor_b.layout().stride(0),              // stride_b
    tensor_c.layout().stride(0),              // stride_c
    tensor_d.layout().stride(0));             // stride_d
}

/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options
typename DeviceGemmStreamK::Arguments args_from_options(
    const DeviceGemmStreamK &device_gemm,
    const Options &options,
    cutlass::HostTensor<ElementA, LayoutA> &tensor_a,
    cutlass::HostTensor<ElementB, LayoutB> &tensor_b,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_c,
    cutlass::HostTensor<ElementC, LayoutC> &tensor_d)
{
  return typename DeviceGemmStreamK::Arguments(
    cutlass::gemm::GemmUniversalMode::kGemm,  // universal mode
    options.problem_size,                     // problem_size
    options.split_k_factor,                   // batch count / splitk slices
    {                                         // epilogue parameters
      ElementAccumulator(options.alpha),
      ElementAccumulator(options.beta)
    },
    tensor_a.device_data(),                   // ptr_A
    tensor_b.device_data(),                   // ptr_B
    tensor_c.device_data(),                   // ptr_C
    tensor_d.device_data(),                   // ptr_D
    options.problem_size.mk().product(),      // batch_stride_A
    options.problem_size.nk().product(),      // batch_stride_B
    options.problem_size.mn().product(),      // batch_stride_C
    options.problem_size.mn().product(),      // batch_stride_D
    tensor_a.layout().stride(0),              // stride_a
    tensor_b.layout().stride(0),              // stride_b
    tensor_c.layout().stride(0),              // stride_c
    tensor_d.layout().stride(0),              // stride_d
    options.avail_sms);                       // avail_sms
}


/// Execute a given example GEMM computation
template <typename DeviceGemmT>
Result run(std::string description, Options &options)
{
  // Display test description
  std::cout << std::endl << description << std::endl;

  // Zero-initialize test output matrix D
  cutlass::reference::host::TensorFill(options.tensor_d.host_view());
  options.tensor_d.sync_device();

  // Instantiate CUTLASS kernel depending on templates
  DeviceGemmT device_gemm;

  // Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT
  auto arguments = args_from_options(device_gemm, options, options.tensor_a, options.tensor_b, options.tensor_c, options.tensor_d);

  // Using the arguments, query for extra workspace required for matrix multiplication computation
  size_t workspace_size = DeviceGemmT::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check whether the problem size is supported
  CUTLASS_CHECK(device_gemm.can_implement(arguments));

  // Initialize CUTLASS kernel with arguments and workspace pointer
  CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get()));

  // Correctness / Warmup iteration
  CUTLASS_CHECK(device_gemm());

  // Copy output data from CUTLASS and reference kernel to host for comparison
  options.tensor_d.sync_host();

  // Check whether the outputs from the CUTLASS kernel and the reference kernel are equal
  Result result;
  result.passed = cutlass::reference::host::TensorEquals(
    options.tensor_d.host_view(),
    options.tensor_ref_d.host_view());

  std::cout << "  Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;

  // Run profiling loop
  if (options.iterations > 0)
  {
    GpuTimer timer;
    timer.start();
    for (int iter = 0; iter < options.iterations; ++iter) {
      CUTLASS_CHECK(device_gemm());
    }
    timer.stop();

    // Compute average runtime and GFLOPs.
    float elapsed_ms = timer.elapsed_millis();
    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);

    std::cout << "  Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
    std::cout << "  GFLOPs: " << result.gflops << std::endl;
  }

  if (!result.passed) {
    exit(-1);
  }

  return result;
}


/// Program entrypoint
int main(int argc, const char **argv)
{
  // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
  if (!(__CUDACC_VER_MAJOR__ >= 11)) {
    std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;

    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
    return 0;
  }

  // Current device must have compute capability at least 80
  cudaDeviceProp props;
  int current_device_id;
  CUDA_CHECK(cudaGetDevice(&current_device_id));
  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
  if (!((props.major * 10 + props.minor) >= 80))
  {
    std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
              << std::endl;

    // Returning zero so this test passes on older architectures. Its actions are no-op.
    return 0;
  }

  // Parse commandline options
  Options options("ampere_streamk_gemm");
  options.parse(argc, argv);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  std::cout <<
    options.iterations << " timing iterations of " <<
    options.problem_size.m() << " x " <<
    options.problem_size.n() << " x " <<
    options.problem_size.k() << " matrix-matrix multiply" << std::endl;

  if (!options.valid()) {
    std::cerr << "Invalid problem." << std::endl;
    return -1;
  }


  //
  // Initialize GEMM datasets
  //

  // Initialize tensors using CUTLASS helper functions
  options.tensor_a.resize(options.problem_size.mk());       // <- Create matrix A with dimensions M x K
  options.tensor_b.resize(options.problem_size.kn());       // <- Create matrix B with dimensions K x N
  options.tensor_c.resize(options.problem_size.mn());       // <- Create matrix C with dimensions M x N
  options.tensor_d.resize(options.problem_size.mn());       // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
  options.tensor_ref_d.resize(options.problem_size.mn());   // <- Create matrix D with dimensions M x N used to store output from reference kernel

  // Fill matrix A on host with uniform-random data in [-2, 2]
  cutlass::reference::host::TensorFillRandomUniform(
      options.tensor_a.host_view(),
      1,
      ElementA(2),
      ElementA(-2),
      0);

  // Fill matrix B on host with uniform-random data in [-2, 2]
  cutlass::reference::host::TensorFillRandomUniform(
      options.tensor_b.host_view(),
      1,
      ElementB(2),
      ElementB(-2),
      0);

  // Fill matrix C on host with uniform-random data in [-2, 2]
  cutlass::reference::host::TensorFillRandomUniform(
      options.tensor_c.host_view(),
      1,
      ElementC(2),
      ElementC(-2),
      0);


  //
  // Compute reference output
  //

  // Copy data from host to GPU
  options.tensor_a.sync_device();
  options.tensor_b.sync_device();
  options.tensor_c.sync_device();

  // Zero-initialize reference output matrix D
  cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view());
  options.tensor_ref_d.sync_device();

  // Create instantiation for device reference gemm kernel
  DeviceGemmReference gemm_reference;

  // Launch device reference gemm kernel
  gemm_reference(
    options.problem_size,
    ElementAccumulator(options.alpha),
    options.tensor_a.device_ref(),
    options.tensor_b.device_ref(),
    ElementAccumulator(options.beta),
    options.tensor_c.device_ref(),
    options.tensor_ref_d.device_ref());

  // Wait for kernels to finish
  CUDA_CHECK(cudaDeviceSynchronize());

  // Copy output data from reference kernel to host for comparison
  options.tensor_ref_d.sync_host();


  //
  // Evaluate CUTLASS kernels
  //

  // Test default operation
  if (options.split_k_factor == 1)
  {
    // Compare basic data-parallel version versus StreamK version using default load-balancing heuristics
    Result basic_dp        = run<DeviceGemmBasic>("Basic data-parallel GEMM", options);
    Result streamk_default = run<DeviceGemmStreamK>("StreamK GEMM with default load-balancing", options);

    printf("  Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms));

    // Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1
    options.avail_sms = 1;        // Set loadbalancing width to 1 SM (no load balancing)
    Result streamk_dp = run<DeviceGemmStreamK>("StreamK emulating basic data-parallel GEMM", options);
    options.avail_sms = -1;       // Reset loadbalancing width to unspecified SMs (i.e., the number of device SMs)

    printf("  Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms));

    options.split_k_factor++;     // Increment splitting factor for next evaluation
  }

  // Show that StreamK can emulate "Split-K" with a tile-splitting factor
  Result basic_splitk = run<DeviceGemmBasic>(
    std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
    options);

  Result streamk_splitk = run<DeviceGemmStreamK>(
    std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor),
    options);

  printf("  Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms));

  return 0;
}
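Editor's note: the throughput figures quoted in the example's header comment follow directly from the gflops() helper in the listing above (two flops per multiply-accumulate). A quick stand-alone check, added here for illustration and not part of the commit:

// Stand-alone check of the GFLOP/s numbers quoted in the example banner
// (2 flops per multiply-accumulate; not part of the commit).
#include <cstdio>

int main() {
  double m = 2048, n = 2048, k = 2048;
  double flops = 2.0 * m * n * k;              // ~1.718e10 flops per GEMM

  double dp_ms      = 0.112633;                // basic data-parallel runtime (ms)
  double streamk_ms = 0.0941929;               // stream-K default runtime (ms)

  printf("data-parallel: %.0f GFLOP/s\n", flops / 1.0e9 / (dp_ms / 1000.0));      // ~152530
  printf("stream-K     : %.0f GFLOP/s\n", flops / 1.0e9 / (streamk_ms / 1000.0)); // ~182390
  printf("speedup      : %.3f\n", dp_ms / streamk_ms);                            // ~1.196
  return 0;
}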
@@ -124,6 +124,7 @@ foreach(EXAMPLE
   43_ell_block_sparse_gemm
   45_dual_gemm
   46_depthwise_simt_conv2dfprop
+  47_ampere_gemm_universal_streamk
 )
 
   add_subdirectory(${EXAMPLE})
@@ -2,6 +2,9 @@
 
 #include "cuda_runtime.h"
 
+/**
+ * Panic wrapper for unwinding CUTLASS errors
+ */
 #define CUTLASS_CHECK(status)                                                 \
   {                                                                           \
     cutlass::Status error = status;                                           \

@@ -12,6 +15,10 @@
     }                                                                         \
   }
 
+
+/**
+ * Panic wrapper for unwinding CUDA runtime errors
+ */
 #define CUDA_CHECK(status)                                                    \
   {                                                                           \
     cudaError_t error = status;                                               \

@@ -21,3 +28,50 @@
       exit(EXIT_FAILURE);                                                     \
     }                                                                         \
   }
 
+
+/**
+ * GPU timer for recording the elapsed time across kernel(s) launched in GPU stream
+ */
+struct GpuTimer
+{
+  cudaStream_t _stream_id;
+  cudaEvent_t _start;
+  cudaEvent_t _stop;
+
+  /// Constructor
+  GpuTimer() : _stream_id(0)
+  {
+    CUDA_CHECK(cudaEventCreate(&_start));
+    CUDA_CHECK(cudaEventCreate(&_stop));
+  }
+
+  /// Destructor
+  ~GpuTimer()
+  {
+    CUDA_CHECK(cudaEventDestroy(_start));
+    CUDA_CHECK(cudaEventDestroy(_stop));
+  }
+
+  /// Start the timer for a given stream (defaults to the default stream)
+  void start(cudaStream_t stream_id = 0)
+  {
+    _stream_id = stream_id;
+    CUDA_CHECK(cudaEventRecord(_start, _stream_id));
+  }
+
+  /// Stop the timer
+  void stop()
+  {
+    CUDA_CHECK(cudaEventRecord(_stop, _stream_id));
+  }
+
+  /// Return the elapsed time (in milliseconds)
+  float elapsed_millis()
+  {
+    float elapsed = 0.0;
+    CUDA_CHECK(cudaEventSynchronize(_stop));
+    CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop));
+    return elapsed;
+  }
+};
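Editor's note: the GpuTimer added above wraps the usual cudaEvent_t timing pattern that the example's run() loop relies on. A minimal usage sketch, assuming the helper.h from this commit is on the include path; my_kernel is a placeholder, not part of the commit:

// Minimal usage sketch for the GpuTimer added to helper.h above.
// "helper.h" and CUDA_CHECK come from the commit; my_kernel is a placeholder.
#include <cstdio>
#include "helper.h"

__global__ void my_kernel() { /* placeholder work */ }

int main() {
  int iterations = 100;

  GpuTimer timer;
  timer.start();                       // records _start on the default stream
  for (int i = 0; i < iterations; ++i) {
    my_kernel<<<1, 32>>>();
  }
  timer.stop();                        // records _stop; elapsed_millis() syncs on it

  printf("Avg runtime: %f ms\n", timer.elapsed_millis() / float(iterations));
  return 0;
}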
@@ -57,15 +57,18 @@ public:
 
 protected:
 
-  /// Load flag, as a strong operation (int specialization)
+  /// Load flag, as a strong acquire operation (int specialization)
   CUTLASS_DEVICE
-  static int ld_strong(int *ptr)
+  static int ld_acquire(int *ptr)
   {
     int state = 0;
 
 #if (__CUDA_ARCH__ >= 700)
     /// SM70 and newer use memory consistency qualifiers
-    asm volatile ("ld.global.relaxed.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
+
+    // Acquire pattern using acquire modifier
+    asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
+
 #else
     asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
 #endif // (__CUDA_ARCH__ >= 700)

@@ -73,18 +76,6 @@ protected:
     return state;
   }
 
-  /// Store flag, as a strong operation (int specialization)
-  CUTLASS_DEVICE
-  static void st_strong(int *ptr, int val)
-  {
-#if (__CUDA_ARCH__ >= 700)
-    /// SM70 and newer use memory consistency qualifiers
-    asm volatile ("st.global.relaxed.gpu.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
-#else
-    asm volatile ("st.cg.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
-#endif // (__CUDA_ARCH__ >= 700)
-  }
-
   /// Reduce into flag, with release pattern (int specialization)
   CUTLASS_DEVICE

@@ -93,7 +84,12 @@ protected:
 #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
 #if (__CUDA_ARCH__ >= 700)
     /// SM70 and newer use memory consistency qualifiers
-    asm volatile ("red.release.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val));
+
+    // Release pattern using acq_rel fence + relaxed modifier. (The fence also releases data
+    // that was weakly-written by other threads prior to the last syncthreads)
+    asm volatile ("fence.acq_rel.gpu;\n");
+    asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val));
+
 #else
     __threadfence();
     atomicAdd(ptr, val);

@@ -115,7 +111,7 @@ public:
     {
       // Spin-loop
       #pragma unroll 1
-      while(ld_strong(flag_ptr) < count) {}
+      while(ld_acquire(flag_ptr) < count) {}
     }
 
     __syncthreads();

@@ -133,9 +129,8 @@ public:
     {
       // Spin-loop
       #pragma unroll 1
-      while(ld_strong(flag_ptr) != val) {}
+      while(ld_acquire(flag_ptr) != val) {}
     }
 
     __syncthreads();
 #endif
   }

@@ -166,7 +161,8 @@ public:
 
     __syncthreads();
 
-    if (thread_idx == 0) {
+    if (thread_idx == 0)
+    {
       red_release(flag_ptr, 1);
     }
 #endif
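Editor's note: the barrier changes above replace relaxed load/store flag accesses with an explicit acquire/release protocol: waiters spin on ld.global.acquire, and the signaling block publishes its earlier (possibly weak) writes with a fence.acq_rel before a relaxed red.add. A self-contained sketch of that pairing, reusing the exact PTX from the diff; the kernel and variable names are mine, not CUTLASS's:

// Sketch of the acquire/release flag protocol the diff switches to (SM70+ PTX
// qualifiers copied from the barrier code above; names below are placeholders).
#include <cstdio>

__device__ int flag = 0;
__device__ int payload = 0;

__device__ int ld_acquire(int *ptr) {
  int state = 0;
#if (__CUDA_ARCH__ >= 700)
  asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#else
  asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#endif
  return state;
}

__device__ void red_release(int *ptr, int val) {
#if (__CUDA_ARCH__ >= 700)
  // fence.acq_rel makes earlier (even weak) writes visible before the increment lands
  asm volatile ("fence.acq_rel.gpu;\n");
  asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#else
  __threadfence();
  atomicAdd(ptr, val);
#endif
}

__global__ void producer() {
  payload = 42;             // ordinary (weak) write...
  red_release(&flag, 1);    // ...published by the release increment
}

__global__ void consumer() {
  while (ld_acquire(&flag) < 1) {}    // spin until the flag becomes visible
  printf("payload = %d\n", payload);  // guaranteed to observe 42
}

int main() {
  producer<<<1, 1>>>();
  consumer<<<1, 1>>>();
  cudaDeviceSynchronize();
  return 0;
}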
@@ -124,7 +124,7 @@ public:
 
     GemmUniversalMode mode;
     GemmCoord problem_size;
-    int batch_count;
+    int batch_count;              // Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor
 
     typename EpilogueOutputOp::Params epilogue;
 

@@ -148,7 +148,7 @@ public:
     typename LayoutC::Stride::LongIndex ldc;
     typename LayoutC::Stride::LongIndex ldd;
 
-    int sm_limit;                 /// Carvout override: when the above are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance
+    int avail_sms;                /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
 
 
     //

@@ -159,15 +159,18 @@ public:
     Arguments():
       mode(GemmUniversalMode::kGemm),
       batch_count(1),
-      ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr),
-      sm_limit(-1)
+      ptr_A(nullptr),
+      ptr_B(nullptr),
+      ptr_C(nullptr),
+      ptr_D(nullptr),
+      avail_sms(-1)
     {}
 
     /// Constructor
     Arguments(
       GemmUniversalMode mode,
       GemmCoord problem_size,
-      int batch_count,
+      int batch_split,            /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K)
       typename EpilogueOutputOp::Params epilogue,
       void const * ptr_A,
       void const * ptr_B,

@@ -181,15 +184,15 @@ public:
       typename LayoutB::Stride stride_b,
       typename LayoutC::Stride stride_c,
       typename LayoutC::Stride stride_d,
-      int sm_limit = -1           /// Carvout override: when the above are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance
+      int avail_sms = -1          /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
     ):
       mode(mode),
       problem_size(problem_size),
-      batch_count(batch_count),
+      batch_count(batch_split),
       epilogue(epilogue),
       ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
       batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
-      stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), sm_limit(sm_limit)
+      stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), avail_sms(avail_sms)
     {
       CUTLASS_TRACE_HOST("GemmUniversalStreamk::Arguments::Arguments() - problem_size: " << problem_size);
     }

@@ -198,7 +201,7 @@ public:
     Arguments(
       GemmUniversalMode mode,
       GemmCoord problem_size,
-      int batch_count,
+      int batch_split,            /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K)
       typename EpilogueOutputOp::Params epilogue,
       void const * ptr_A,
       void const * ptr_B,

@@ -212,15 +215,15 @@ public:
       typename LayoutB::Stride::LongIndex ldb,
       typename LayoutC::Stride::LongIndex ldc,
       typename LayoutC::Stride::LongIndex ldd,
-      int sm_limit = -1           /// Carvout override: when the above are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance
+      int avail_sms = -1          /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
     ):
       mode(mode),
       problem_size(problem_size),
-      batch_count(batch_count),
+      batch_count(batch_split),
       epilogue(epilogue),
       ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D),
       batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
-      lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), sm_limit(sm_limit)
+      lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), avail_sms(avail_sms)
     {
       stride_a = make_Coord(lda);
       stride_b = make_Coord(ldb);
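Editor's note: the renamed avail_sms field and the batch_split constructor argument above are the two knobs the example's args_from_options() feeds from its command line. A toy helper written only to restate the documented semantics of those comments (not CUTLASS logic):

// Toy restatement of the documented batch_split / avail_sms semantics
// (illustration only; not CUTLASS code).
#include <cstdio>

const char *describe(int batch_split, int avail_sms) {
  if (batch_split > 1)  return "emulated Split-K with a fixed tile-splitting factor";
  if (avail_sms == 1)   return "classic data-parallel schedule (no load balancing)";
  if (avail_sms == -1)  return "Stream-K balanced across every SM on the device";
  return "Stream-K balanced across a capped number of SMs";
}

int main() {
  printf("split=1, sms=-1 -> %s\n", describe(1, -1));
  printf("split=1, sms=1  -> %s\n", describe(1, 1));
  printf("split=2, sms=-1 -> %s\n", describe(2, -1));
  printf("split=1, sms=54 -> %s\n", describe(1, 54));
  return 0;
}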
@ -254,29 +257,36 @@ public:
|
|||||||
// Data members
|
// Data members
|
||||||
//
|
//
|
||||||
|
|
||||||
ThreadblockSwizzle block_mapping;
|
void * ptr_A;
|
||||||
|
void * ptr_B;
|
||||||
|
|
||||||
typename Mma::IteratorA::Params params_A;
|
typename Mma::IteratorA::Params params_A;
|
||||||
typename Mma::IteratorB::Params params_B;
|
typename Mma::IteratorB::Params params_B;
|
||||||
typename Epilogue::OutputTileIterator::Params params_C;
|
|
||||||
typename Epilogue::OutputTileIterator::Params params_D;
|
|
||||||
typename EpilogueOutputOp::Params output_op;
|
|
||||||
|
|
||||||
GemmUniversalMode mode;
|
|
||||||
|
|
||||||
void * ptr_A;
|
|
||||||
void * ptr_B;
|
|
||||||
void * ptr_C;
|
|
||||||
void * ptr_D;
|
|
||||||
|
|
||||||
int64_t batch_stride_A;
|
int64_t batch_stride_A;
|
||||||
int64_t batch_stride_B;
|
int64_t batch_stride_B;
|
||||||
int64_t batch_stride_C;
|
|
||||||
int64_t batch_stride_D;
|
GemmUniversalMode mode;
|
||||||
|
|
||||||
|
ThreadblockSwizzle block_mapping;
|
||||||
|
|
||||||
|
bool quick_dp;
|
||||||
|
|
||||||
void *barrier_workspace;
|
void *barrier_workspace;
|
||||||
void *partials_workspace;
|
void *partials_workspace;
|
||||||
|
|
||||||
|
typename EpilogueOutputOp::Params output_op;
|
||||||
|
|
||||||
|
void * ptr_D;
|
||||||
|
void * ptr_C;
|
||||||
|
|
||||||
|
typename Epilogue::OutputTileIterator::Params params_D;
|
||||||
|
typename Epilogue::OutputTileIterator::Params params_C;
|
||||||
|
|
||||||
|
int64_t batch_stride_D;
|
||||||
|
int64_t batch_stride_C;
|
||||||
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -295,7 +305,7 @@ public:
|
|||||||
{
|
{
|
||||||
// For atomic reduction, each SK-block needs a synchronization flag. For parallel reduction,
|
// For atomic reduction, each SK-block needs a synchronization flag. For parallel reduction,
|
||||||
// each reduction block needs its own synchronization flag.
|
// each reduction block needs its own synchronization flag.
|
||||||
int sk_blocks = block_mapping.sk_regions * block_mapping.sk_blocks_per_region;
|
int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region();
|
||||||
int num_flags = fast_max(sk_blocks, block_mapping.reduction_blocks);
|
int num_flags = fast_max(sk_blocks, block_mapping.reduction_blocks);
|
||||||
|
|
||||||
return cacheline_align_up(sizeof(typename Barrier::T) * num_flags);
|
return cacheline_align_up(sizeof(typename Barrier::T) * num_flags);
|
||||||
@ -304,7 +314,7 @@ public:
|
|||||||
/// Get the workspace size needed for intermediate partial sums
|
/// Get the workspace size needed for intermediate partial sums
|
||||||
size_t get_partials_workspace_size() const
|
size_t get_partials_workspace_size() const
|
||||||
{
|
{
|
||||||
int sk_blocks = block_mapping.sk_regions * block_mapping.sk_blocks_per_region;
|
int sk_blocks = block_mapping.sk_regions() * block_mapping.sk_blocks_per_region();
|
||||||
return cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks);
|
return cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -343,9 +353,9 @@ public:
|
|||||||
partials_workspace(nullptr)
|
partials_workspace(nullptr)
|
||||||
{
|
{
|
||||||
// Number of SMs to make available for StreamK decomposition
|
// Number of SMs to make available for StreamK decomposition
|
||||||
int avail_sms = (args.sm_limit == -1) ?
|
int avail_sms = (args.avail_sms == -1) ?
|
||||||
device_sms :
|
device_sms :
|
||||||
fast_min(args.sm_limit, device_sms);
|
fast_min(args.avail_sms, device_sms);
|
||||||
|
|
||||||
// Initialize the block mapping structure
|
// Initialize the block mapping structure
|
||||||
block_mapping = ThreadblockSwizzle(
|
block_mapping = ThreadblockSwizzle(
|
||||||
@ -355,7 +365,15 @@ public:
|
|||||||
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
|
||||||
args.batch_count,
|
args.batch_count,
|
||||||
sm_occupancy,
|
sm_occupancy,
|
||||||
|
device_sms,
|
||||||
avail_sms);
|
avail_sms);
|
||||||
|
|
||||||
|
quick_dp =
|
||||||
|
(block_mapping.sk_waves == 0) &&
|
||||||
|
(mode == GemmUniversalMode::kGemm) &&
|
||||||
|
!block_mapping.cohort_raster &&
|
||||||
|
!EpilogueOutputOp(output_op).is_source_needed();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -426,7 +444,7 @@ public:
|
|||||||
/// Returns the GEMM volume in thread block tiles
|
/// Returns the GEMM volume in thread block tiles
|
||||||
cutlass::gemm::GemmCoord get_tiled_shape() const
|
cutlass::gemm::GemmCoord get_tiled_shape() const
|
||||||
{
|
{
|
||||||
return block_mapping.tiled_shape;
|
return block_mapping.tiled_shape();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -533,9 +551,6 @@ protected:
|
|||||||
/// ID of each thread within a warp
|
/// ID of each thread within a warp
|
||||||
int lane_idx;
|
int lane_idx;
|
||||||
|
|
||||||
/// Block index
|
|
||||||
int block_idx;
|
|
||||||
|
|
||||||
/// Threadblock scoped epilogue
|
/// Threadblock scoped epilogue
|
||||||
Epilogue epilogue;
|
Epilogue epilogue;
|
||||||
|
|
||||||
@ -640,16 +655,18 @@ protected:
|
|||||||
|
|
||||||
/// Iterator for fetching tile fragments from A
|
/// Iterator for fetching tile fragments from A
|
||||||
CUTLASS_DEVICE
|
CUTLASS_DEVICE
|
||||||
typename Mma::IteratorA init_iterator_A(TileWorkDesc &tile_work)
|
typename Mma::IteratorA init_iterator_A(
|
||||||
|
TileWorkDesc &tile_work,
|
||||||
|
GemmUniversalMode mode)
|
||||||
{
|
{
|
||||||
// The input A matrix
|
// The input A matrix
|
||||||
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
|
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
|
||||||
|
|
||||||
// Update input pointers based on batched/array mode
|
// Update input pointers based on batched/array mode
|
||||||
if (params.mode == GemmUniversalMode::kBatched) {
|
if (mode == GemmUniversalMode::kBatched) {
|
||||||
ptr_A += tile_work.tiled_coord.k() * params.batch_stride_A;
|
ptr_A += tile_work.tiled_coord.k() * params.batch_stride_A;
|
||||||
}
|
}
|
||||||
if (params.mode == GemmUniversalMode::kArray) {
|
if (mode == GemmUniversalMode::kArray) {
|
||||||
ptr_A = static_cast<ElementA * const *>(params.ptr_A)[tile_work.tiled_coord.k()];
|
ptr_A = static_cast<ElementA * const *>(params.ptr_A)[tile_work.tiled_coord.k()];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -667,16 +684,18 @@ protected:
|
|||||||
|
|
||||||
/// Iterator for fetching tile fragments from B
|
/// Iterator for fetching tile fragments from B
|
||||||
CUTLASS_DEVICE
|
CUTLASS_DEVICE
|
||||||
typename Mma::IteratorB init_iterator_B(TileWorkDesc &tile_work)
|
typename Mma::IteratorB init_iterator_B(
|
||||||
|
TileWorkDesc &tile_work,
|
||||||
|
GemmUniversalMode mode)
|
||||||
{
|
{
|
||||||
// The input B matrix
|
// The input B matrix
|
||||||
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
|
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
|
||||||
|
|
||||||
// Update input pointers based on batched/array mode
|
// Update input pointers based on batched/array mode
|
||||||
if (params.mode == GemmUniversalMode::kBatched) {
|
if (mode == GemmUniversalMode::kBatched) {
|
||||||
ptr_B += tile_work.tiled_coord.k() * params.batch_stride_B;
|
ptr_B += tile_work.tiled_coord.k() * params.batch_stride_B;
|
||||||
}
|
}
|
||||||
if (params.mode == GemmUniversalMode::kArray) {
|
if (mode == GemmUniversalMode::kArray) {
|
||||||
ptr_B = static_cast<ElementB * const *>(params.ptr_B)[tile_work.tiled_coord.k()];
|
ptr_B = static_cast<ElementB * const *>(params.ptr_B)[tile_work.tiled_coord.k()];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -700,10 +719,10 @@ protected:
|
|||||||
tile_work.tile_idx = tile_idx;
|
tile_work.tile_idx = tile_idx;
|
||||||
|
|
||||||
// The first global-scoped MAC-iteration this threadblock will perform for this tile
|
// The first global-scoped MAC-iteration this threadblock will perform for this tile
|
||||||
tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile;
|
tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile();
|
||||||
|
|
||||||
// The number of MAC-iterations this threadblock will perform for this tile
|
// The number of MAC-iterations this threadblock will perform for this tile
|
||||||
tile_work.k_iters_remaining = params.block_mapping.iters_per_tile;
|
tile_work.k_iters_remaining = params.block_mapping.iters_per_tile();
|
||||||
|
|
||||||
// The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile
|
// The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile
|
||||||
tile_work.k_begin = 0;
|
tile_work.k_begin = 0;
|
||||||
@ -727,7 +746,7 @@ protected:
|
|||||||
tile_work.tile_idx = tile_idx;
|
tile_work.tile_idx = tile_idx;
|
||||||
|
|
||||||
// The first global-scoped MAC-iteration for this tile
|
// The first global-scoped MAC-iteration for this tile
|
||||||
int tile_iter_begin = tile_idx * params.block_mapping.iters_per_tile;
|
int tile_iter_begin = tile_idx * params.block_mapping.iters_per_tile();
|
||||||
|
|
||||||
// The first global-scoped MAC-iteration this threadblock will perform for this tile
|
// The first global-scoped MAC-iteration this threadblock will perform for this tile
|
||||||
tile_work.iter_begin = max(block_iter_begin, tile_iter_begin);
|
tile_work.iter_begin = max(block_iter_begin, tile_iter_begin);
|
||||||
@ -756,7 +775,10 @@ protected:
|
|||||||
|
|
||||||
/// Share accumulators with peers
|
/// Share accumulators with peers
|
||||||
CUTLASS_DEVICE
|
CUTLASS_DEVICE
|
||||||
void share_accumulators(AccumulatorTile const &accumulator_tile, int first_block_idx)
|
void share_accumulators(
|
||||||
|
AccumulatorTile const &accumulator_tile,
|
||||||
|
int block_idx,
|
||||||
|
int first_block_idx)
|
||||||
{
|
{
|
||||||
AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
|
AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
|
||||||
|
|
||||||
@ -795,6 +817,7 @@ protected:
|
|||||||
CUTLASS_DEVICE
|
CUTLASS_DEVICE
|
||||||
void acquire_accumulators(
|
void acquire_accumulators(
|
||||||
AccumulatorTile &accumulator_tile,
|
AccumulatorTile &accumulator_tile,
|
||||||
|
int block_idx,
|
||||||
int first_block_idx)
|
int first_block_idx)
|
||||||
{
|
{
|
||||||
AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
|
AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
|
||||||
@@ -868,8 +891,8 @@ protected:
reduce_tile_idx = reduce_idx / Epilogue::kAccumulatorFragments;
reduce_fragment_idx = reduce_idx % Epilogue::kAccumulatorFragments;

-int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile;
-int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile - 1;
+int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile();
+int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile() - 1;

peer_idx_begin = params.block_mapping.get_sk_block_idx(iter_tile_first);
peer_idx_last = params.block_mapping.get_sk_block_idx(iter_tile_last);
@@ -895,16 +918,6 @@ protected:
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);

-// Update pointers for batched/array mode(s)
-if (params.mode == GemmUniversalMode::kBatched) {
-ptr_C += tiled_coord.k() * params.batch_stride_C;
-ptr_D += tiled_coord.k() * params.batch_stride_D;
-}
-if (params.mode == GemmUniversalMode::kArray) {
-ptr_C = static_cast<ElementC * const *>(params.ptr_C)[tiled_coord.k()];
-ptr_D = static_cast<ElementC * const *>(params.ptr_D)[tiled_coord.k()];
-}
-
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
@@ -936,12 +949,13 @@ protected:
CUTLASS_DEVICE
void process_tile(
TileWorkDesc tile_work,
+int block_idx,
int dp_start_block_idx,
int block_iter_begin)
{
// Initialize input iterators
-typename Mma::IteratorA iterator_A = init_iterator_A(tile_work);
-typename Mma::IteratorB iterator_B = init_iterator_B(tile_work);
+typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode);
+typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode);

// Initialize accumulators
AccumulatorTile accumulator_tile;
@@ -968,7 +982,7 @@ protected:

if (!tile_work.tile_finished(params)) {
// Non "finishing" SK blocks must share their partial accumulator sums through global scratch workspace
-share_accumulators(accumulator_tile, first_block_idx);
+share_accumulators(accumulator_tile, block_idx, first_block_idx);
}
else
{
@@ -976,7 +990,7 @@ protected:
if (!tile_work.tile_started())
{
// A "finishing" SK block must first aggregate its accumulator partial sums with those shared by peer threadblocks
-acquire_accumulators(accumulator_tile, first_block_idx);
+acquire_accumulators(accumulator_tile, block_idx, first_block_idx);
}

do_epilogue(tile_work, accumulator_tile);
@@ -1008,11 +1022,12 @@ protected:
// Initialize block's iteration range
int tile_idx, block_iter_begin, block_iters_remaining;

-int sk_padding_start_block_idx = params.block_mapping.sk_regions * params.block_mapping.sk_blocks_per_region;
+int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region();
int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms;
int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks;
int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks;

+int block_idx = params.block_mapping.get_block_idx();
if (block_idx < sk_padding_start_block_idx)
{
// This is a SK block
@@ -1044,8 +1059,9 @@ protected:
}

block_iter_begin = 0;
-block_iters_remaining = params.block_mapping.iters_per_tile * tile_allottment;
+block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment;
}

else if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed) &&
(block_idx < grid_padding_start_block_idx))
{
@@ -1072,8 +1088,8 @@ protected:

// DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1)
if ((tile_idx < params.block_mapping.sk_tiles) ||
-(tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape.m()) ||
-(tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape.n()))
+(tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) ||
+(tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n()))
{
break;
}
@@ -1084,7 +1100,7 @@ protected:
}

// Perform this block's share of work for this tile
-process_tile(tile_work, dp_start_block_idx, block_iter_begin);
+process_tile(tile_work, block_idx, dp_start_block_idx, block_iter_begin);

// Update remaining work for this block
block_iters_remaining -= tile_work.k_iters_remaining;
@@ -1110,6 +1126,64 @@ protected:

}


+/// Executes one DP-only GEMM
+CUTLASS_DEVICE
+void gemm_dp()
+{
+int block_idx = blockIdx.x;
+int tile_idx = block_idx;
+
+TileWorkDesc tile_work;
+tile_work.tile_idx = tile_idx;
+tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile();
+tile_work.k_iters_remaining = params.block_mapping.iters_per_tile();
+tile_work.k_begin = 0;
+tile_work.k_end = params.block_mapping.problem_size.k();
+tile_work.tiled_coord = params.block_mapping.get_tile_offset_row_major(tile_work.tile_idx);
+
+// Initialize input iterators
+typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode);
+typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode);
+
+// Initialize accumulators
+AccumulatorTile accumulator_tile;
+accumulator_tile.clear();
+
+// Perform this tile's range of multiply-accumulate (MAC) iterations
+Mma mma(
+shared_storage.main_loop,
+thread_idx,
+warp_idx,
+lane_idx);
+
+mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);
+
+ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
+
+// Location of this tile in item-coords
+MatrixCoord threadblock_item_begin(
+tile_work.tiled_coord.m() * Mma::Shape::kM,
+tile_work.tiled_coord.n() * Mma::Shape::kN
+);
+
+// Tile iterator writing to destination tensor.
+typename Epilogue::OutputTileIterator iterator_D(
+params.params_D,
+ptr_D,
+params.block_mapping.problem_size.mn(),
+thread_idx,
+threadblock_item_begin);
+
+// Execute the epilogue operator to update the destination tensor.
+epilogue(
+EpilogueOutputOp(params.output_op),
+iterator_D,
+accumulator_tile);
+}


public:

//
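The gemm_dp() path added above assumes a purely data-parallel launch: blockIdx.x is the output-tile index, every block runs all iters_per_tile() MAC iterations of its own tile, and the tile coordinate comes from a row-major raster. As a rough illustration only (plain host C++, hypothetical tile counts, not CUTLASS code), that mapping behaves like:

// Minimal sketch of the row-major tile mapping the DP-only path relies on:
// tile_idx == blockIdx.x, one output tile per block. Tile counts are made up.
#include <cstdio>

int main() {
  int tiled_m = 4, tiled_n = 6;  // hypothetical tiled output extents (tiles along M and N)
  for (int tile_idx = 0; tile_idx < tiled_m * tiled_n; ++tile_idx) {
    int m = tile_idx / tiled_n;  // row-major raster, as in get_tile_offset_row_major()
    int n = tile_idx % tiled_n;
    std::printf("block %2d -> tile (%d,%d)\n", tile_idx, m, n);
  }
  return 0;
}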
@@ -1138,7 +1212,6 @@ public:
thread_idx(threadIdx.x),
warp_idx(__shfl_sync(0xffffffff, threadIdx.x / 32, 0)), // broadcast the warp_id computed by lane 0 to ensure dependent code
lane_idx(threadIdx.x % 32),
-block_idx(params.block_mapping.get_block_idx()),
epilogue(
shared_storage.epilogue,
thread_idx,
@@ -1151,7 +1224,17 @@ public:
CUTLASS_DEVICE
void operator()()
{
-// Do the GEMM
+#if (__CUDACC_VER_MAJOR__ > 10)
+if (params.quick_dp)
+{
+// Simple (low-bootstrap latency) GEMM code path for data-parallel only. (kBatched and kArray
+// modes will only be launched using a data-parallel configurations)
+gemm_dp();
+return;
+}
+#endif
+
+// Generic SK code path
gemm();

}
@@ -115,28 +115,21 @@ struct ThreadblockSwizzleStreamK {
// Member state
//


/// The 3D value-extents of the GEMM computation volume (m,n,k)
GemmCoord problem_size;

-/// The 2D tile-extents of the output matrix (m,n)
-GemmCoord tiled_shape;
+/// Div/mod accelerators
+FastDivmod div_mod_tiled_shape_m;
+FastDivmod div_mod_tiled_shape_n;
+FastDivmod div_mod_tiled_cohort_shape_n;
+FastDivmod div_mod_iters_per_tile;

-/// Number of iterations per output tile
-int iters_per_tile;
+/// Whether to perform cohort CTA rasterization
+bool cohort_raster;

-/// Number of reduction blocks in the grid
-int reduction_blocks;
+// Whether to pad and remap block indices
+bool remap_block_indices;

-int dp_blocks; /// Number of data-parallel thread blocks in the grid
-int dp_first_wave_tiles; /// Number of output tiles each CTA in the first DP wave will produce
-
-int sk_tiles;
-int sk_regions;
-int sk_blocks_per_region;
-int sk_big_blocks_per_region;
-int sk_iters_per_region;
-int sk_iters_per_normal_block; /// Number of iterations for normal SK-blocks
-int sk_waves; /// Number of SK waves in the grid

/// CTA occupancy per SM
int sm_occupancy;
@@ -144,21 +137,26 @@ struct ThreadblockSwizzleStreamK {
/// Number of SMs for dispatch heuristics to load-balance using Stream-K CTAs (wave size)
int avail_sms;

-/// Whether to perform cohort CTA rasterization
-bool cohort_raster;
+int dp_blocks; /// Number of data-parallel thread blocks in the grid
+int dp_first_wave_tiles; /// Number of output tiles each CTA in the first DP wave will produce
+
+/// Number of reduction blocks in the grid
+int reduction_blocks;
+
+int sk_waves;
+int sk_tiles;
+int sk_big_blocks_per_region;
+int sk_iters_per_region;

/// Div/mod accelerators
-struct
-{
-FastDivmod tiled_shape_m;
-FastDivmod tiled_shape_n;
-FastDivmod tiled_cohort_shape_n;
-FastDivmod iters_per_tile;
-FastDivmod sk_iters_per_normal_block;
-FastDivmod sk_iters_per_big_block;
-FastDivmod sk_iters_per_region;
-FastDivmod sk_blocks_per_region;
-} div_mod;
+FastDivmod div_mod_sk_iters_per_normal_block;
+FastDivmod div_mod_sk_iters_per_big_block;
+FastDivmod div_mod_sk_iters_per_region;
+FastDivmod div_mod_sk_regions; //!! used in block map
+FastDivmod div_mod_sk_blocks_per_region; //!! used in block map
+
+/// The batch count
+int batch_count;


//
@@ -169,6 +167,43 @@ struct ThreadblockSwizzleStreamK {
CUTLASS_HOST_DEVICE
ThreadblockSwizzleStreamK() {}

+/// Returns the GEMM volume in thread block tiles
+CUTLASS_HOST_DEVICE
+GemmCoord tiled_shape() const
+{
+return GemmCoord(
+static_cast<int>(div_mod_tiled_shape_m),
+static_cast<int>(div_mod_tiled_shape_n),
+batch_count);
+}
+
+/// Number of iterations per output tile
+CUTLASS_HOST_DEVICE
+int iters_per_tile() const
+{
+return static_cast<int>(div_mod_iters_per_tile);
+}
+
+/// Number of iterations for normal SK-blocks
+CUTLASS_HOST_DEVICE
+int sk_iters_per_normal_block() const
+{
+return static_cast<int>(div_mod_sk_iters_per_normal_block);
+}
+
+/// Number of SK regions
+CUTLASS_HOST_DEVICE
+int sk_regions() const
+{
+return static_cast<int>(div_mod_sk_regions);
+}
+
+/// Number of SK blocks per region (splitting factor)
+CUTLASS_HOST_DEVICE
+int sk_blocks_per_region() const
+{
+return static_cast<int>(div_mod_sk_blocks_per_region);
+}


//
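The accessors added above replace the plain int members that were removed from the struct: each fast-divmod member carries both the division state and the original quantity, which the accessor recovers with a cast. A rough, standalone sketch of that idiom (this SimpleDivmod is hypothetical and is not CUTLASS's FastDivmod):

// Sketch of the "divmod member doubles as the stored value" idiom.
#include <cassert>

struct SimpleDivmod {
  int divisor = 1;
  SimpleDivmod() = default;
  explicit SimpleDivmod(int d) : divisor(d) {}
  operator int() const { return divisor; }                   // recover the stored value, as in iters_per_tile()
  void operator()(int &quo, int &rem, int dividend) const {  // quotient/remainder split
    quo = dividend / divisor;
    rem = dividend % divisor;
  }
};

int main() {
  SimpleDivmod div_mod_iters_per_tile(128);
  int iters_per_tile = static_cast<int>(div_mod_iters_per_tile);
  int tile, iter_in_tile;
  div_mod_iters_per_tile(tile, iter_in_tile, 300);
  assert(iters_per_tile == 128 && tile == 2 && iter_in_tile == 44);
  return 0;
}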
@@ -179,26 +214,27 @@ struct ThreadblockSwizzleStreamK {
void Print()
{
#ifndef __CUDA_ARCH__
-int tiles = tiled_shape.m() * tiled_shape.n();
+auto tiles = tiled_shape().mn().product();
std::cout <<
"problem_size: (" << problem_size.m() << "," << problem_size.n() << ")" <<
-", reduction_blocks: " << reduction_blocks <<
-", dp_blocks: " << dp_blocks <<
-", sk_blocks_per_region: " << sk_blocks_per_region <<
-", sk_regions: " << sk_regions <<
-", sk_waves: " << sk_waves <<
-", sk_iters_per_normal_block: " << sk_iters_per_normal_block <<
-", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
-", dp_first_wave_tiles: " << dp_first_wave_tiles <<
-", tiled_shape: (" << tiled_shape.m() << "," << tiled_shape.n() << ")" <<
+", tiled_shape: (" << tiled_shape().m() << "," << tiled_shape().n() << ")" <<
", tiles: " << tiles <<
-", iters_per_tile: " << iters_per_tile <<
", dp_tiles: " << tiles - sk_tiles <<
", sk_tiles: " << sk_tiles <<
-", avail_sms: " << avail_sms <<
+", iters_per_tile: " << iters_per_tile() <<
+", reduction_blocks: " << reduction_blocks <<
+", dp_blocks: " << dp_blocks <<
+", dp_waves: " << dp_blocks / avail_sms <<
+", dp_first_wave_tiles: " << dp_first_wave_tiles <<
+", sk_blocks_per_region: " << sk_blocks_per_region() <<
+", sk_regions: " << sk_regions() <<
+", sk_waves: " << sk_waves <<
+", sk_iters_per_normal_block: " << sk_iters_per_normal_block() <<
+", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
+", remap_block_indices: " << remap_block_indices <<
+", cohort_raster: " << cohort_raster <<
", sm_occupancy: " << sm_occupancy <<
", avail_sms: " << avail_sms <<
-", cohort_raster: " << cohort_raster <<
", num_blocks: " << get_num_blocks() <<
"\n\n";
#endif
@@ -368,30 +404,37 @@ struct ThreadblockSwizzleStreamK {
GemmUniversalMode const mode_,
GemmCoord const problem_size_,
GemmCoord const tile_size_,
-int const batch_count_, /// Batch count (when mode_ == GemmUniversalMode::kBatched) or split-K-override splitting factor (when mode_ == GemmUniversalMode::kGemm)
+int const batch_split_, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K)
int const sm_occupancy_,
-int const avail_sms_)
+int const device_sms_,
+int const avail_sms_) /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling)
:
problem_size(problem_size_),
-tiled_shape(
-(problem_size.m() + tile_size_.m() - 1) / tile_size_.m(),
-(problem_size.n() + tile_size_.n() - 1) / tile_size_.n(),
-(mode_ == GemmUniversalMode::kBatched) ? batch_count_ : 1),
-iters_per_tile((problem_size.k() + tile_size_.k() - 1) / tile_size_.k()),
+batch_count((mode_ == GemmUniversalMode::kBatched) ? batch_split_ : 1),
reduction_blocks(0),
dp_blocks(0),
dp_first_wave_tiles(1), // Default: one tile per DP-block in the first wave of DP blocks
sk_tiles(0),
-sk_regions(1), // Default: a single region of iteration space (across all SK tiles)
-sk_blocks_per_region(0),
sk_big_blocks_per_region(0),
sk_iters_per_region(0),
-sk_iters_per_normal_block(0),
sk_waves(0),
sm_occupancy(sm_occupancy_),
+remap_block_indices(false),
avail_sms(fast_max(1, avail_sms_)),
cohort_raster(false)
{
+int gpu_occupancy = device_sms_ * sm_occupancy;
+int iters_per_tile = (problem_size.k() + tile_size_.k() - 1) / tile_size_.k();
+int sk_iters_per_normal_block = 0;
+
+int sk_regions = 1; // Default: a single region of iteration space (across all SK tiles)
+int sk_blocks_per_region = 0;
+
+GemmCoord tiled_shape(
+(problem_size.m() + tile_size_.m() - 1) / tile_size_.m(),
+(problem_size.n() + tile_size_.n() - 1) / tile_size_.n(),
+batch_count);
+
size_t problem_bytes =
(sizeof(typename GemmKernel::ElementC) * problem_size.m() * problem_size.n()) +
(sizeof(typename GemmKernel::ElementA) * problem_size.m() * problem_size.k()) +
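For orientation, the ceiling divisions above turn the problem and threadblock-tile extents into grid quantities: tiles along M and N, and MAC iterations along K per output tile. A standalone sketch with made-up sizes (not values taken from this commit):

// Illustration of the ceiling divisions used in the constructor, with hypothetical extents.
#include <cstdio>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  int m = 1000, n = 512, k = 4096;          // hypothetical GEMM extents
  int tile_m = 128, tile_n = 128, tile_k = 32;  // hypothetical threadblock tile

  int tiled_m = ceil_div(m, tile_m);        // 8 (the last tile along M is partial)
  int tiled_n = ceil_div(n, tile_n);        // 4
  int iters_per_tile = ceil_div(k, tile_k); // 128 MAC iterations per output tile
  std::printf("tiled_shape = (%d,%d), iters_per_tile = %d\n", tiled_m, tiled_n, iters_per_tile);
  return 0;
}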
@@ -401,7 +444,6 @@ struct ThreadblockSwizzleStreamK {

float flops_per_byte = float(problem_flops) / float(problem_bytes);

-int gpu_occupancy = avail_sms * sm_occupancy;
int output_tiles = tiled_shape.m() * tiled_shape.n();
int waves = (output_tiles + avail_sms - 1) / avail_sms;
float dp_efficiency = float(output_tiles) / float(waves * avail_sms);
@@ -414,14 +456,15 @@ struct ThreadblockSwizzleStreamK {
int dp_tiles = output_tiles; // Number of data-parallel tiles
int sk_blocks = 0; // Number of thread blocks to produce the remaining SK tiles

-// kGemm mode allows for SK load balancing
+// Only kGemm mode allows for SK load balancing
if (mode_ == GemmUniversalMode::kGemm)
{
-if (batch_count_ > 1)
+int split_factor = batch_split_;
+if (split_factor > 1)
{
// Split-K override
dp_tiles = 0;
-sk_blocks = output_tiles * batch_count_;
+sk_blocks = output_tiles * split_factor;
}
else if ((kReductionStrategy != kNone) && // Load-balancing strategy statically enabled
(avail_sms > 1)) // Plurality of SMs to load balance across
@@ -462,24 +505,39 @@ struct ThreadblockSwizzleStreamK {
sk_big_blocks_per_region = sk_big_blocks / sk_regions;
sk_iters_per_region = sk_iters / sk_regions;

-div_mod.sk_iters_per_normal_block = FastDivmod(sk_iters_per_normal_block);
-div_mod.sk_iters_per_big_block = FastDivmod(sk_iters_per_normal_block + 1);
-div_mod.sk_iters_per_region = FastDivmod(sk_iters_per_region);
-div_mod.sk_blocks_per_region = FastDivmod(sk_blocks_per_region);
-
-// Separate reduction heuristic
-if ((kReductionStrategy == kMixed) &&
-(sk_blocks > 2 * sk_tiles)) // Use a separate reduction wave whenever we would have more than three
-// peers working on an SK tile. (This occurs when the ratio of SK-blocks
-// to SK-tiles > 2, as a single tile may be covered by four SK-blocks,
+// Use a separate reduction wave when all of:
+// - Non-atomic reduction stratgy
+// - The number of SK waves won't fully occupy the GPU (Otherwise we don't have
+// a strong-scaling case for more parallel reduction)
+// - More than three peers working on an SK tile. (This occurs when the ratio of
+// SK-blocks to SK-tiles > 2, as a single tile may be covered by four SK-blocks,
// e.g.:[partial-block | block | block | partial-block] ). With three or
// less peers, the two non-finishing SK-blocks are not expexted to contend.
+if ((kReductionStrategy == kMixed) &&
+(sk_waves < sm_occupancy) &&
+(sk_blocks > 2 * sk_tiles))
{
-// Launch a reduction block every accumulator fragment in each SK-tile
+// Launch a reduction block for every accumulator fragment in each SK-tile
static const int kAccumulatorFragments = GemmKernel::Epilogue::kAccumulatorFragments;
reduction_blocks = sk_tiles * kAccumulatorFragments;

}

+// When we have a multi-occupancy kernel and at least two waves of active blocks (where
+// at least one wave is SK blocks), we need to (1) dispatch at least four waves, and (2)
+// remap the block indices so that we can reliably spread the SK blocks evenly across the
+// device's first SM occupancy valence. Also see get_num_blocks() and get_block_idx().
+remap_block_indices = (
+(sm_occupancy > 1) &&
+(device_sms_ == avail_sms) &&
+(get_num_active_blocks() > avail_sms * 2));
+
+// Initialize fast div/mod members related to SK
+div_mod_sk_iters_per_normal_block = FastDivmod(sk_iters_per_normal_block);
+div_mod_sk_iters_per_big_block = FastDivmod(sk_iters_per_normal_block + 1);
+div_mod_sk_iters_per_region = FastDivmod(sk_iters_per_region);
+div_mod_sk_regions = FastDivmod(sk_regions);
+div_mod_sk_blocks_per_region = FastDivmod(sk_blocks_per_region);
}

//
@@ -491,7 +549,7 @@ struct ThreadblockSwizzleStreamK {
cutlass::gemm::GemmCoord tiled_cohort_shape(
(tiled_shape.m() + kCohortCtasM - 1) / kCohortCtasM,
(tiled_shape.n() + kCohortCtasN - 1) / kCohortCtasN,
-batch_count_);
+tiled_shape.k());
int cohort_blocks = (tiled_cohort_shape.m() * tiled_cohort_shape.n()) * kCtasPerCohort;
float cohort_efficiency = float(dp_blocks) / float(cohort_blocks);

@@ -511,11 +569,12 @@ struct ThreadblockSwizzleStreamK {
{
sk_in_range = false;
}

}

// Decide if we're going to be doing cohort raster
if (sk_in_range &&
-(dp_blocks >= gpu_occupancy) &&
+(dp_blocks >= gpu_occupancy * 2) &&
(cohort_efficiency > 0.85f))
{
cohort_raster = true;
@@ -537,92 +596,54 @@ struct ThreadblockSwizzleStreamK {
}

// Setup fast-div/mod for device-side usage
-div_mod.tiled_shape_m = FastDivmod(tiled_shape.m());
-div_mod.tiled_shape_n = FastDivmod(tiled_shape.n());
-div_mod.tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
-div_mod.iters_per_tile = FastDivmod(iters_per_tile);
+div_mod_tiled_shape_m = FastDivmod(tiled_shape.m());
+div_mod_tiled_shape_n = FastDivmod(tiled_shape.n());
+div_mod_tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
+div_mod_iters_per_tile = FastDivmod(iters_per_tile);

}

-/// Constructor: *ImplicitGemm* Conv2d problem size: conv_operator(NPQK, NHWC, KRSC)
-template <typename GemmKernel>
-ThreadblockSwizzleStreamK(
-KernelTraits<GemmKernel> kernel_traits_,
-GemmUniversalMode mode_,
-cutlass::conv::Operator conv_operator,
-cutlass::conv::Conv2dProblemSize const &problem_size_,
-GemmCoord tile_size_,
-int batch_count_,
-int sm_occupancy_,
-int avail_sms_, /// When the below are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance
-int dp_tiles_ = -1, /// Dispatch override: number of output tiles to assign to independent, data-parallel CTAs
-int sk_blocks_ = -1) /// Dispatch override: number of Stream-K CTAs for cooperatively processing the remaining output tiles
-:
-ThreadblockSwizzleStreamK(
-kernel_traits_,
-mode_,
-cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size_),
-tile_size_,
-batch_count_,
-sm_occupancy_,
-avail_sms_,
-dp_tiles_,
-sk_blocks_)
-{}
-
-/// Constructor: *ImplicitGemm* Conv3d problem size: conv_operator(NZPQK, NDHWC, KTRSC)
-template <typename GemmKernel>
-ThreadblockSwizzleStreamK(
-KernelTraits<GemmKernel> kernel_traits_,
-GemmUniversalMode mode_,
-cutlass::conv::Operator conv_operator,
-cutlass::conv::Conv3dProblemSize const &problem_size_,
-GemmCoord tile_size_,
-int batch_count_,
-int sm_occupancy_,
-int avail_sms_, /// When the below are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance
-int dp_tiles_ = -1, /// Dispatch override: number of output tiles to assign to independent, data-parallel CTAs
-int sk_blocks_ = -1) /// Dispatch override: number of Stream-K CTAs for cooperatively processing the remaining output tiles
-:
-ThreadblockSwizzleStreamK(
-kernel_traits_,
-mode_,
-cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size_),
-tile_size_,
-batch_count_,
-sm_occupancy_,
-avail_sms_,
-dp_tiles_,
-sk_blocks_)
-{}
+/// Number of blocks performing useful work
+int get_num_active_blocks() const
+{
+return (sk_waves * avail_sms) + dp_blocks + reduction_blocks;
+}

/// Obtains number of threadblocks per GEMM
int get_num_blocks() const
{
-int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;
-
-if (work_blocks <= avail_sms * 2)
+int active_blocks = get_num_active_blocks();
+if (remap_block_indices)
{
-return work_blocks;
+// Add padding blocks if we are performing remapping in order to dispatch a grid of at least four waves
+return fast_max(active_blocks, avail_sms * 4);
}

-return fast_max(work_blocks, avail_sms * 4);
+return active_blocks;
}

/// Obtains grid extents in CTAs
dim3 get_grid_dims() const
{
-return dim3(get_num_blocks(), 1, tiled_shape.k());
+return dim3(get_num_blocks(), 1, batch_count);
}

+// Guards needed for PyCUTLASS library generation
+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
+
//
// Device-side interface
//

+/// Proves to the compiler that val is warp-uniform
+CUTLASS_DEVICE
+int uniform(int val) const
+{
+return __shfl_sync(0xffffffff, val, 0);
+}
+
/// Obtains number of threadblocks per GEMM
CUTLASS_DEVICE
int device_num_blocks() const
@@ -634,9 +655,16 @@ struct ThreadblockSwizzleStreamK {
CUTLASS_DEVICE
int get_sk_tile_idx(int iter) const
{
-return div_mod.iters_per_tile.div(iter);
+int tile_idx = div_mod_iters_per_tile.div(iter);
+return uniform(tile_idx);
}

+/// Obtains the batch index
+CUTLASS_DEVICE
+int get_batch_idx() const
+{
+return RematerializeBlockIdxZ();
+}
+
/// Obtains the calling threadblock's tiled coordinates for the given tile index
CUTLASS_DEVICE
@@ -644,12 +672,21 @@ struct ThreadblockSwizzleStreamK {
{
int m, n;

+// row-major raster
+div_mod_tiled_shape_n(m, n, tile_idx);
+
+if (tiled_shape().m() < tiled_shape().n())
+{
+// column-major raster
+div_mod_tiled_shape_m(n, m, tile_idx);
+}
+
if (cohort_raster)
{
// tiled cohort raster
int cohort_tile_idx = tile_idx / kCtasPerCohort;
int cohort_grid_m, cohort_grid_n;
-div_mod.tiled_cohort_shape_n(cohort_grid_m, cohort_grid_n, cohort_tile_idx);
+div_mod_tiled_cohort_shape_n(cohort_grid_m, cohort_grid_n, cohort_tile_idx);

int block_idx_cohort = tile_idx % kCtasPerCohort;
int block_cohort_m = block_idx_cohort / kCohortCtasN;
@@ -658,44 +695,46 @@ struct ThreadblockSwizzleStreamK {
m = (cohort_grid_m * kCohortCtasM) + block_cohort_m;
n = (cohort_grid_n * kCohortCtasN) + block_cohort_n;
}
-else if (tiled_shape.m() < tiled_shape.n())
-{
-// column-major raster
-div_mod.tiled_shape_m(n, m, tile_idx);
+return GemmCoord(m, n, get_batch_idx());
}
-else
+
+/// Obtains the calling threadblock's tiled coordinates for the given tile index (row-major rastorization)
+CUTLASS_DEVICE
+GemmCoord get_tile_offset_row_major(int tile_idx) const
{
// row-major raster
-div_mod.tiled_shape_n(m, n, tile_idx);
+int m, n;
+div_mod_tiled_shape_n(m, n, tile_idx);
+return GemmCoord(m, n, get_batch_idx());
}

-int block_idx_k = RematerializeBlockIdxZ();
-return GemmCoord{m, n, block_idx_k};
-}

/// Obtains calling threadblock's linear threadblock index
CUTLASS_DEVICE
int get_block_idx() const
{
+int block_idx = RematerializeBlockIdxX();
+
// Remap the block indices for the first two waves of thread blocks if
// we have multi-occupancy and the grid constitutes four or more waves
-int block_idx = RematerializeBlockIdxX();
-int num_blocks = device_num_blocks();
+if (remap_block_indices && (block_idx < avail_sms * 2))
+{
int dest_sm = block_idx / 2;
int dest_wave = block_idx % 2;
int remapped_block_idx = dest_sm + (dest_wave * avail_sms);
-
-if ((sm_occupancy > 1) &&
-(num_blocks >= avail_sms * 4) &&
-(block_idx < avail_sms * 2))
-{
block_idx = remapped_block_idx;
}

-// Block-index is blockIdx.x for DP blocks
-return block_idx;
+// Remap block indices to interleave SK regions to limit intra-region waiting
+if (block_idx < sk_regions() * sk_blocks_per_region())
+{
+int block_in_region;
+int region;
+div_mod_sk_regions(block_in_region, region, block_idx);
+block_idx = (region * sk_blocks_per_region()) + block_in_region;
+}
+
+return uniform(block_idx);
}

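The first remapping above spreads the first two waves of blocks so that consecutive block-index pairs land on the same SM slot (dest_sm = block_idx / 2, dest_wave = block_idx % 2). The small sketch below mirrors only that arithmetic, with a made-up SM count, to show where the first blocks end up; it is not kernel code.

// Illustration of the first-two-waves remapping arithmetic, with a hypothetical avail_sms.
#include <cstdio>

int main() {
  int avail_sms = 4;  // hypothetical
  for (int block_idx = 0; block_idx < avail_sms * 2; ++block_idx) {
    int dest_sm = block_idx / 2;
    int dest_wave = block_idx % 2;
    int remapped = dest_sm + (dest_wave * avail_sms);
    std::printf("block %d -> %d\n", block_idx, remapped);
  }
  // Prints 0->0, 1->4, 2->1, 3->5, 4->2, 5->6, 6->3, 7->7: each adjacent pair of
  // block indices is split across the two waves that co-reside on an SM.
  return 0;
}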
@@ -705,19 +744,21 @@ struct ThreadblockSwizzleStreamK {
{
int region_idx;
int iter_in_region;
-div_mod.sk_iters_per_region(region_idx, iter_in_region, iter);
+div_mod_sk_iters_per_region(region_idx, iter_in_region, iter);

-int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block) + sk_big_blocks_per_region; // number of iterations in the region's big blocks
+int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block()) + sk_big_blocks_per_region; // number of iterations in the region's big blocks
int normal_block_iters = iter_in_region - big_block_iters; // number of iterations in the region's normal bocks

-int big_block_idx_in_region = div_mod.sk_iters_per_big_block.div(iter_in_region);
-int normal_block_idx_in_region = sk_big_blocks_per_region + div_mod.sk_iters_per_normal_block.div(normal_block_iters);
+int big_block_idx_in_region = div_mod_sk_iters_per_big_block.div(iter_in_region);
+int normal_block_idx_in_region = sk_big_blocks_per_region + div_mod_sk_iters_per_normal_block.div(normal_block_iters);

int block_idx_in_region = (big_block_idx_in_region < sk_big_blocks_per_region) ?
big_block_idx_in_region :
normal_block_idx_in_region;

-return (sk_blocks_per_region * region_idx) + block_idx_in_region;
+int owning_block_idx = (sk_blocks_per_region() * region_idx) + block_idx_in_region;
+
+return owning_block_idx;
}

/// Obtains iteration extends for the given SK block index
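Within a region, the first sk_big_blocks_per_region blocks take one extra MAC iteration so the region's iteration count is covered exactly; get_sk_block_idx() above uses that split to find which SK block owns a given iteration. The sketch below is a hedged, standalone illustration with made-up values, assuming (for illustration only) that the normal-block size and the big-block count come from a simple quotient and remainder:

// Made-up example of splitting one region's iterations between "big" (+1) and normal SK blocks.
#include <cstdio>

int main() {
  int sk_iters_per_region = 19;   // hypothetical
  int sk_blocks_per_region = 4;   // hypothetical
  int sk_iters_per_normal_block = sk_iters_per_region / sk_blocks_per_region;  // 4
  int sk_big_blocks_per_region = sk_iters_per_region % sk_blocks_per_region;   // 3 blocks take 4+1 iterations

  for (int b = 0; b < sk_blocks_per_region; ++b) {
    int iters = sk_iters_per_normal_block + (b < sk_big_blocks_per_region ? 1 : 0);
    std::printf("block %d covers %d iterations\n", b, iters);  // 5, 5, 5, 4 -> totals 19
  }
  return 0;
}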
@@ -729,12 +770,12 @@ struct ThreadblockSwizzleStreamK {
{
int region_idx;
int block_idx_in_region;
-div_mod.sk_blocks_per_region(region_idx, block_idx_in_region, sk_block_idx);
+div_mod_sk_blocks_per_region(region_idx, block_idx_in_region, sk_block_idx);

-block_iter_begin = (region_idx * sk_iters_per_region) + (block_idx_in_region * sk_iters_per_normal_block);
+block_iter_begin = (region_idx * sk_iters_per_region) + (block_idx_in_region * sk_iters_per_normal_block());

// Adjust extents for the first "num_big_blocks" blocks that get one extra iteration
-int block_iters = sk_iters_per_normal_block;
+int block_iters = sk_iters_per_normal_block();
if (block_idx_in_region < sk_big_blocks_per_region) {
// This is a +1 iteration block
block_iter_begin += block_idx_in_region;
@@ -756,10 +797,12 @@ struct ThreadblockSwizzleStreamK {
return block_idx;
}

-int iter = tile_idx * iters_per_tile;
+int iter = tile_idx * iters_per_tile();
return get_sk_block_idx(iter);
}

+#endif // defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__)
+
};

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -50,9 +50,9 @@ def EpilogueAlignment(max_alignment, tile, epilogue_steps = 8):
#
def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \
-swizzling_functor = SwizzlingFunctor.Identity8):
+# swizzling_functor = SwizzlingFunctor.Identity8):
# Use StreamK decomposition for basic GEMMs
-# swizzling_functor = SwizzlingFunctor.StreamK):
+swizzling_functor = SwizzlingFunctor.StreamK):

if complex_transforms is None:
complex_transforms = [(ComplexTransform.none, ComplexTransform.none),]
@@ -4600,6 +4600,7 @@ if __name__ == "__main__":
parser.add_argument('--selected-kernel-list', type=str, default=None, required=False,
help='Specify the output log file containing all enabled kernels in this build')
parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels")
+parser.add_argument("--disable-full-archs-compilation", action="store_true", required=False, help="Disable compilation for every archs in --architectures")

args = parser.parse_args()