diff --git a/examples/13_two_tensor_op_fusion/CMakeLists.txt b/examples/13_two_tensor_op_fusion/CMakeLists.txt index b97205ba..006819c0 100644 --- a/examples/13_two_tensor_op_fusion/CMakeLists.txt +++ b/examples/13_two_tensor_op_fusion/CMakeLists.txt @@ -64,6 +64,7 @@ endforeach() foreach(FUSION_GEMM_EXAMPLE fused_two_gemms_f16_sm75_rf fused_two_gemms_f16_sm75_shmem + fused_two_gemms_grouped_f16_sm80_rf fused_two_gemms_f16_sm80_rf fused_two_gemms_f16_sm80_shmem fused_two_gemms_s8_sm75_rf diff --git a/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h b/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h new file mode 100644 index 00000000..267423d4 --- /dev/null +++ b/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h @@ -0,0 +1,450 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Containers for running grouped back-to-back GEMMs +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/util/device_memory.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "reference/device/tensor_scale_bias.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct B2bFusedGroupedGemmRun +{ + + using B2bGemm = B2bGemm_; + using ElementAccumulator = typename B2bGemm::ElementAccumulator; + using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + B2bFusedGroupedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + /// Executes one test + bool run( + std::vector problem_sizes_0, + std::vector problem_sizes_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + using HostTensorA = cutlass::HostTensor; + using HostTensorB = cutlass::HostTensor; + using HostTensorC = cutlass::HostTensor; + 
using HostTensorScale = cutlass::HostTensor; + using HostTensorZ = cutlass::HostTensor; + using HostTensorBias = cutlass::HostTensor; + + int problem_count = (int)problem_sizes_0.size(); + + std::vector host_tensor_A0(problem_count); + std::vector host_tensor_B0(problem_count); + std::vector host_tensor_C0(problem_count); + std::vector host_tensor_Scale0(problem_count); + std::vector host_tensor_Bias0(problem_count); + std::vector host_tensor_B1(problem_count); + std::vector host_tensor_C1(problem_count); + std::vector host_tensor_Bias1(problem_count); + std::vector host_tensor_D1(problem_count); + std::vector host_tensor_Z(problem_count); + std::vector host_tensor_ref_D0(problem_count); + std::vector host_tensor_ref_D1(problem_count); + + std::vector ref_A0(problem_count); + std::vector ref_B0(problem_count); + std::vector ref_C0(problem_count); + std::vector ref_Scale0(problem_count); + std::vector ref_Bias0(problem_count); + std::vector ref_B1(problem_count); + std::vector ref_C1(problem_count); + std::vector ref_Bias1(problem_count); + std::vector ref_D1(problem_count); + std::vector ref_Z(problem_count); + std::vector ref_ref_D0(problem_count); + std::vector ref_ref_D1(problem_count); + + for (int i = 0; i < problem_count; ++i) { + // + // Allocate the GEMM workspace + // + + auto problem_size_0 = problem_sizes_0[i]; + auto problem_size_1 = problem_sizes_1[i]; + + host_tensor_A0.at(i) = HostTensorA(problem_size_0.mk()); + host_tensor_B0.at(i) = HostTensorB(problem_size_0.kn()); + host_tensor_C0.at(i) = HostTensorC(problem_size_0.mn()); + if (alpha0 == ElementCompute(0)) //per-channel scale + host_tensor_Scale0.at(i) = HostTensorScale(typename HostTensorZ::Layout::TensorCoord{1, problem_size_0.n()}); + host_tensor_Bias0.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_0.n()}); + host_tensor_Z.at(i) = HostTensorZ(problem_size_0.mn()); + host_tensor_ref_D0.at(i) = HostTensorC(problem_size_0.mn()); + host_tensor_B1.at(i) = HostTensorB(problem_size_1.kn()); + host_tensor_C1.at(i) = HostTensorC(problem_size_1.mn()); + host_tensor_Bias1.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_1.n()}); + host_tensor_D1.at(i) = HostTensorC(problem_size_1.mn()); + host_tensor_ref_D1.at(i) = HostTensorC(problem_size_1.mn()); + + CHECK_TRUE(initialize_tensor(host_tensor_A0.at(i).host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(host_tensor_B0.at(i).host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(host_tensor_C0.at(i).host_view(), init_C, seed + 2017)); + if (alpha0 == ElementCompute(0)) //per-channel scale + CHECK_TRUE(initialize_tensor(host_tensor_Scale0.at(i).host_view(), init_Scale, seed + 2014)); + CHECK_TRUE(initialize_tensor(host_tensor_Bias0.at(i).host_view(), init_Bias, seed + 2013)); + CHECK_TRUE(initialize_tensor(host_tensor_B1.at(i).host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(host_tensor_C1.at(i).host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(host_tensor_Bias1.at(i).host_view(), init_Bias, seed + 2012)); + + cutlass::reference::host::TensorFill( + host_tensor_D1.at(i).host_view()); + cutlass::reference::host::TensorFill( + host_tensor_ref_D0.at(i).host_view()); + cutlass::reference::host::TensorFill( + host_tensor_ref_D1.at(i).host_view()); + + host_tensor_A0.at(i).sync_device(); + host_tensor_B0.at(i).sync_device(); + host_tensor_C0.at(i).sync_device(); + if (alpha0 == ElementCompute(0)) //per-channel scale + 
host_tensor_Scale0.at(i).sync_device(); + host_tensor_Bias0.at(i).sync_device(); + host_tensor_B1.at(i).sync_device(); + host_tensor_C1.at(i).sync_device(); + host_tensor_Bias1.at(i).sync_device(); + host_tensor_D1.at(i).sync_device(); + host_tensor_ref_D0.at(i).sync_device(); + host_tensor_ref_D1.at(i).sync_device(); + + ref_A0.at(i) = (host_tensor_A0.at(i).device_ref()); + ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());; + ref_C0.at(i) = (host_tensor_C0.at(i).device_ref()); + if (alpha0 == ElementCompute(0)) //per-channel scale + ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref()); + ref_Bias0.at(i) = (host_tensor_Bias0.at(i).device_ref()); + ref_B1.at(i) = (host_tensor_B1.at(i).device_ref()); + ref_C1.at(i) = {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)}; + ref_Bias1.at(i) = (host_tensor_Bias1.at(i).device_ref()); + ref_D1.at(i) = (host_tensor_D1.at(i).device_ref()); + ref_Z.at(i) = (host_tensor_Z.at(i).device_ref()); + ref_ref_D0.at(i) = (host_tensor_ref_D0.at(i).device_ref()); + ref_ref_D1.at(i) = (host_tensor_ref_D1.at(i).device_ref()); + } + + // + // Initialize the GEMM operator + // + + cutlass::DeviceAllocation device_ref_A0(problem_count); + device_ref_A0.copy_from_host(ref_A0.data()); + cutlass::DeviceAllocation device_ref_B0(problem_count); + device_ref_B0.copy_from_host(ref_B0.data()); + cutlass::DeviceAllocation device_ref_C0(problem_count); + device_ref_C0.copy_from_host(ref_C0.data()); + cutlass::DeviceAllocation device_ref_Scale0(problem_count); + device_ref_Scale0.copy_from_host(ref_Scale0.data()); + cutlass::DeviceAllocation device_ref_Bias0(problem_count); + device_ref_Bias0.copy_from_host(ref_Bias0.data()); + cutlass::DeviceAllocation device_ref_B1(problem_count); + device_ref_B1.copy_from_host(ref_B1.data()); + cutlass::DeviceAllocation device_ref_C1(problem_count); + device_ref_C1.copy_from_host(ref_C1.data()); + cutlass::DeviceAllocation device_ref_Bias1(problem_count); + device_ref_Bias1.copy_from_host(ref_Bias1.data()); + cutlass::DeviceAllocation device_ref_D1(problem_count); + device_ref_D1.copy_from_host(ref_D1.data()); + + cutlass::DeviceAllocation device_problem_sizes_0(problem_count); + device_problem_sizes_0.copy_from_host(problem_sizes_0.data()); + cutlass::DeviceAllocation device_problem_sizes_1(problem_count); + device_problem_sizes_1.copy_from_host(problem_sizes_1.data()); + + B2bGemm b2b_gemm_op; + + int threadblock_count = B2bGemm::sufficient(problem_sizes_1.data(), problem_count); + if (!threadblock_count) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." 
<< std::endl;
+      return false;
+    }
+
+    typename B2bGemm::Arguments arguments{
+      problem_count,
+      device_problem_sizes_0.get(),
+      device_problem_sizes_1.get(),
+      device_ref_A0.get(),
+      device_ref_B0.get(),
+      device_ref_C0.get(),
+      device_ref_Scale0.get(),
+      device_ref_Bias0.get(),
+      device_ref_B1.get(),
+      device_ref_C1.get(),
+      device_ref_D1.get(),
+      {alpha0, beta0},
+      {alpha1, beta1},
+      threadblock_count
+    };
+
+    cutlass::Status status = b2b_gemm_op.can_implement(arguments);
+
+    if(status != cutlass::Status::kSuccess) {
+      std::cout << "Problem sizes not supported.\n"
+        << "Requirements:\n"
+        << "    problem_size_0.M = problem_size_1.M\n"
+        << "    problem_size_0.N = problem_size_1.K\n"
+        << "    ThreadblockShape0::kN = problem_size_0.N\n"
+        << "    ThreadblockShape1::kN = problem_size_1.N" << std::endl;
+    }
+
+    status = b2b_gemm_op.initialize(arguments);
+
+    CUTLASS_CHECK(status);
+
+    for(int i = 0; i < warm_ups; i++) {
+      status = b2b_gemm_op();
+      CUTLASS_CHECK(status);
+    }
+
+    //
+    // Run the GEMM
+    //
+
+    cudaEvent_t start, stop;
+    cudaEventCreate(&start);
+    cudaEventCreate(&stop);
+
+    cudaEventRecord(start);
+
+    for(int i = 0; i < runs; i++) {
+      status = b2b_gemm_op();
+      CUTLASS_CHECK(status);
+    }
+
+    cudaEventRecord(stop);
+    cudaDeviceSynchronize();
+    float gemmTime;
+    cudaEventElapsedTime(&gemmTime, start, stop);
+    std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
+
+    for (int i = 0; i < problem_count; ++i) {
+      host_tensor_D1.at(i).sync_host();
+
+      //
+      // Verify
+      //
+
+      cutlass::reference::device::Gemm<
+        typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
+        typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
+        ElementAccumulator, typename B2bGemm::LayoutC,
+        ElementAccumulator, ElementAccumulator>
+        reference_gemm_0;
+
+      cutlass::reference::device::Gemm<
+        typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
+        typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
+        typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
+        ElementAccumulator>
+        reference_gemm_1;
+
+      auto problem_size_0 = problem_sizes_0[i];
+      auto problem_size_1 = problem_sizes_1[i];
+
+      reference_gemm_0(
+        problem_size_0,
+        ElementAccumulator(1), //intermediate alpha=1
+        ref_A0.at(i),
+        ref_B0.at(i),
+        ElementAccumulator(0), //beta = 0
+        ref_Z.at(i),
+        ref_Z.at(i),
+        ElementAccumulator(0)
+      );
+
+      cutlass::reference::device::TensorScaleBiasGemm<
+        ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
+        ElementCompute, typename B2bGemm::LayoutC
+      > (
+        problem_size_0,
+        ref_Z.at(i),
+        ref_ref_D0.at(i),
+        alpha0,
+        ref_Scale0.at(i),
+        ref_Bias0.at(i)
+      );
+
+      if(relu) {
+        cutlass::reference::device::TensorReLu(host_tensor_ref_D0.at(i).device_view());
+      }
+
+      reference_gemm_1(
+        problem_size_1,
+        alpha1,
+        ref_ref_D0.at(i),
+        ref_B1.at(i),
+        beta1,
+        {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)},
+        ref_ref_D1.at(i)
+      );
+      if(relu) {
+        cutlass::reference::device::TensorReLu(host_tensor_ref_D1.at(i).device_view());
+      }
+      cudaDeviceSynchronize();
+      host_tensor_ref_D0.at(i).sync_host();
+      host_tensor_ref_D1.at(i).sync_host();
+
+      CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D0.at(i).host_view()), 0);
+      CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_D1.at(i).host_view()), 0);
+      CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D1.at(i).host_view()), 0);
+
+      bool passed = cutlass::reference::host::TensorEquals(
+        host_tensor_ref_D1.at(i).host_view(),
+        host_tensor_D1.at(i).host_view());
+
CHECK_TRUE(passed); + if (!passed) + { + + std::stringstream fname; + + fname << "error_B2bGemm_device_fused.txt"; + std::cerr << "Check failed for GEMM " << i << " in the group." << std::endl; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "GEMM " << i << " in group\n" + << "A0 =\n" << host_tensor_A0.at(i).host_view() + << "\nB0 =\n" << host_tensor_B0.at(i).host_view() + << "\nC0 =\n" << host_tensor_C0.at(i).host_view() + << "\nScale0:\n" << host_tensor_Scale0.at(i).host_view() << "\n" + << "\nBias0:\n" << host_tensor_Bias0.at(i).host_view() << "\n" + << "\nB1 =\n" << host_tensor_B1.at(i).host_view() + << "\nC1 =\n" << host_tensor_C1.at(i).host_view() + << "\nBias1:\n" << host_tensor_Bias1.at(i).host_view() << "\n" + << "\n\nReference =\n" << host_tensor_ref_D1.at(i).host_view() + << "\nComputed =\n" << host_tensor_D1.at(i).host_view(); + + return false; + } + } + return true; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h index 0fbc930b..0820ecec 100644 --- a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h @@ -185,96 +185,7 @@ class B2bGemm { SmemAccumulator >::B2bGemmKernel; - /// Argument structure - struct Arguments { - - // - // Data members - // - - GemmUniversalMode mode; - GemmCoord problem_size_0; - GemmCoord problem_size_1; - TensorRef ref_A0; - TensorRef ref_B0; - TensorRef ref_C0; - TensorRef ref_Scale0; - TensorRef ref_Bias0; - TensorRef ref_B1; - TensorRef ref_C1; - TensorRef ref_D1; - int64_t batch_stride_A0; - int64_t batch_stride_B0; - int64_t batch_stride_B1; - int64_t batch_stride_C1; - int64_t batch_stride_D1; - int64_t batch_stride_Bias0; - int64_t batch_stride_Scale0; - typename EpilogueOutputOp0::Params epilogue0; - typename EpilogueOutputOp1::Params epilogue1; - int batch_count; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments(): mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) { - - } - - /// Constructs an Arguments structure - CUTLASS_HOST_DEVICE - Arguments( - GemmUniversalMode mode_, - GemmCoord problem_size_0_, - GemmCoord problem_size_1_, - TensorRef ref_A0_, - TensorRef ref_B0_, - TensorRef ref_C0_, - TensorRef ref_Scale0_, - TensorRef ref_Bias0_, - TensorRef ref_B1_, - TensorRef ref_C1_, - TensorRef ref_D1_, - int64_t batch_stride_A0_, - int64_t batch_stride_B0_, - int64_t batch_stride_B1_, - int64_t batch_stride_C1_, - int64_t batch_stride_D1_, - int64_t batch_stride_Bias0_, - int64_t batch_stride_Scale0_, - typename EpilogueOutputOp0::Params epilogue0_ = - typename EpilogueOutputOp0::Params(), - typename EpilogueOutputOp1::Params epilogue1_ = - typename EpilogueOutputOp1::Params(), - int batch_count_ = 1 - ): - mode(mode_), - problem_size_0(problem_size_0_), - problem_size_1(problem_size_1_), - ref_A0(ref_A0_), - ref_B0(ref_B0_), - ref_C0(ref_C0_), - ref_Scale0(ref_Scale0_), - ref_Bias0(ref_Bias0_), - ref_B1(ref_B1_), - ref_C1(ref_C1_), - ref_D1(ref_D1_), - batch_stride_A0(batch_stride_A0_), - batch_stride_B0(batch_stride_B0_), - batch_stride_B1(batch_stride_B1_), - batch_stride_C1(batch_stride_C1_), - batch_stride_D1(batch_stride_D1_), - batch_stride_Bias0(batch_stride_Bias0_), - batch_stride_Scale0(batch_stride_Scale0_), - epilogue0(epilogue0_), - epilogue1(epilogue1_), - 
batch_count(batch_count_) { - - } - }; + using Arguments = typename B2bGemmKernel::Arguments; private: diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu new file mode 100644 index 00000000..4abaee51 --- /dev/null +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu @@ -0,0 +1,297 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Example of running grouped back-to-back GEMMs when intermediate results are RF resident +*/ + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/base_grouped.h" +#include "cutlass/gemm/device/gemm.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "device/b2b_gemm.h" +#include "kernel/default_b2b_gemm.h" +#include "threadblock/grouped_threadblock_swizzle.h" +#include "b2b_grouped_gemm_run.h" +#include "test_run.h" + +//////////////////////////////////////////////////////////////////////////////// + +std::vector gemm_f16_sm80_problem_sizes_0; +std::vector gemm_f16_sm80_problem_sizes_1; + +// Constraints: +// 1. Warp shape N must equal thread block shape N +// 2. 
Problem size N must equal thread block shape N
+using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
+using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
+using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
+using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
+
+// Command line options parsing
+struct Options {
+
+  bool help;
+  bool error;
+  bool reference_check;
+  int alignment = 8;
+
+  std::vector<cutlass::gemm::GemmCoord> problem_sizes0;
+  std::vector<cutlass::gemm::GemmCoord> problem_sizes1;
+
+  int problem_count;
+  bool verbose;
+
+  //
+  // Methods
+  //
+
+  Options():
+    help(false),
+    error(false),
+    reference_check(true),
+    problem_count(15),
+    verbose(false)
+  { }
+
+  // Parses the command line
+  void parse(int argc, char const **args) {
+    cutlass::CommandLine cmd(argc, args);
+
+    if (cmd.check_cmd_line_flag("help")) {
+      help = true;
+      return;
+    }
+
+    cmd.get_cmd_line_argument("problems", problem_count, 15);
+    cmd.get_cmd_line_argument("reference-check", reference_check, true);
+    cmd.get_cmd_line_argument("verbose", verbose, false);
+
+    randomize_problems(cmd);
+  }
+
+  void randomize_problems(cutlass::CommandLine &cmd) {
+
+    //
+    // For now, randomly choose the problem sizes.
+    //
+
+    int cmd_line_m = -1;
+    int cmd_line_k = -1;
+
+    cmd.get_cmd_line_argument("m", cmd_line_m);
+    cmd.get_cmd_line_argument("k", cmd_line_k);
+
+    problem_sizes0.reserve(problem_count);
+    problem_sizes1.reserve(problem_count);
+
+    for (int i = 0; i < problem_count; ++i) {
+
+      int m = cmd_line_m;
+      int k = cmd_line_k;
+
+      if (m < 1) {
+        m = alignment * ((rand() % 256) + 1);
+      }
+
+      if (k < 1) {
+        k = alignment * ((rand() % 256) + 1);
+      }
+
+      cutlass::gemm::GemmCoord problem0(m, ThreadblockShape0::kN, k);
+      cutlass::gemm::GemmCoord problem1(m, ThreadblockShape1::kN, ThreadblockShape0::kN);
+
+      problem_sizes0.push_back(problem0);
+      problem_sizes1.push_back(problem1);
+    }
+
+    if (verbose) {
+      print_problem_sizes();
+    }
+  }
+
+  /// Prints the usage statement.
+  std::ostream & print_usage(std::ostream &out) const {
+
+    out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n"
+      << "  This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs are\n"
+      << "  run in a single kernel. Each individual problem in the group is subject to the same constraints that non-grouped\n"
+      << "  back-to-back GEMMs are subject to.\n\n"
+      << "Options:\n\n"
+      << "  --help                    If specified, displays this usage statement.\n\n"
+      << "  --problems=<int>          Number of individual GEMM problems (default: --problems=15)\n"
+      << "  --m=<int>                 Sets the M dimension of both GEMMs for all groups. Otherwise, it is selected randomly\n"
+      << "  --k=<int>                 Sets the K dimension of the first GEMM for all groups. 
Otherwise, it is selected randomly\n"
+      << "  --verbose=<bool>          If true, prints problem sizes.\n";
+
+    out << "\n\nExamples:\n\n"
+
+      << "# Runs a grouped B2b GEMM with 10 random problem sizes\n"
+      << "$ ./examples/13_two_tensor_op_fusion/13_fused_two_gemms_grouped_f16_sm80_rf --problems=10\n\n";
+
+    return out;
+  }
+
+  void print_problem_sizes() {
+    std::cout << std::endl;
+    std::cout << "Executing " << problem_count << " independent back-to-back GEMMs in a group" << std::endl;
+    for (int i = 0; i < problem_count; ++i) {
+      cutlass::gemm::GemmCoord problem0 = problem_sizes0.at(i);
+      cutlass::gemm::GemmCoord problem1 = problem_sizes1.at(i);
+      std::cout << "Problem " << i
+                << "\t\tGEMM0: " << problem0.m() << 'x' << problem0.n() << 'x' << problem0.k()
+                << "\t\tGEMM1: " << problem1.m() << 'x' << problem1.n() << 'x' << problem1.k()
+                << std::endl;
+    }
+  }
+};
+
+bool run_fused_grouped_gemm_f16_sm80_rf_res() {
+
+  using ElementOutput = cutlass::half_t;
+  using ElementAccumulator = cutlass::half_t;
+  using ElementCompute = cutlass::half_t;
+
+  ElementCompute alpha0 = ElementCompute(1);
+  //Fused kernel has built-in bias, setting beta=0
+  ElementCompute beta0 = ElementCompute(0);
+  ElementCompute alpha1 = ElementCompute(1);
+  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
+
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+
+  using EpilogueOutputOp0 =
+    cutlass::epilogue::thread::LinearCombinationRelu<
+      ElementOutput,
+      InstructionShape::kM * InstructionShape::kN / 32,
+      ElementAccumulator,
+      ElementCompute,
+      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+    >;
+
+  using EpilogueOutputOp1 =
+    cutlass::epilogue::thread::LinearCombinationRelu<
+      ElementOutput,
+      128 / cutlass::sizeof_bits<ElementOutput>::value,
+      ElementAccumulator,
+      ElementCompute,
+      cutlass::epilogue::thread::ScaleType::NoBetaScaling
+    >;
+
+  using GroupedThreadblockSwizzle = cutlass::gemm::threadblock::B2bGemmGroupedThreadblockSwizzle<
+    ThreadblockShape0,
+    cutlass::layout::RowMajor // LayoutC
+  >;
+
+  const int kAlignment = 128 / cutlass::sizeof_bits<ElementOutput>::value;
+  const int kStages = 3;
+  using B2bGemmKernel = cutlass::gemm::kernel::DefaultB2bGemm<
+    cutlass::half_t,
+    cutlass::layout::RowMajor,
+    kAlignment,
+    cutlass::half_t,
+    cutlass::layout::ColumnMajor,
+    kAlignment,
+    cutlass::half_t,
+    cutlass::layout::RowMajor,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm80,
+    ThreadblockShape0,
+    ThreadblockShape1,
+    WarpShape0,
+    WarpShape1,
+    InstructionShape,
+    EpilogueOutputOp0,
+    EpilogueOutputOp1,
+    GroupedThreadblockSwizzle,
+    kStages,
+    cutlass::arch::OpMultiplyAdd
+  >::B2bGemmKernel;
+
+  using B2bGemm = cutlass::gemm::device::BaseGrouped<B2bGemmKernel>;
+
+  B2bFusedGroupedGemmRun<B2bGemm> fusedGemm;
+
+  std::cout << "Running Fused back-to-back FP16 TN Grouped GEMMs with RF residency...\n";
+  bool passed = fusedGemm.run(gemm_f16_sm80_problem_sizes_0, gemm_f16_sm80_problem_sizes_1, alpha0, beta0, alpha1, beta1);
+  if(passed)
+    std::cout << "Pass\n";
+  else
+    std::cout << "Fail\n";
+
+  return passed;
+}
+
+int main(int argc, char const **args) {
+
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  if (options.error) {
+    std::cerr << "Aborting execution."
<< std::endl; + return -1; + } + + gemm_f16_sm80_problem_sizes_0 = options.problem_sizes0; + gemm_f16_sm80_problem_sizes_1 = options.problem_sizes1; + + std::vectorfuncs = { + &run_fused_grouped_gemm_f16_sm80_rf_res + }; + + return testRun(80, funcs, "grouped gemm f16 RF residency"); +} + + + + +//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h index f4794c57..a6d2a8a1 100644 --- a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h @@ -40,12 +40,60 @@ #include "cutlass/matrix_coord.h" #include "cutlass/semaphore.h" +#include "kernel/b2b_gemm_grouped_problem_visitor.h" +#include "threadblock/grouped_threadblock_swizzle.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { namespace gemm { namespace kernel { +namespace detail { + +/// Utility struct for returning the type of the problem visitor used by the swizzling function, +/// if it is a grouped swizzling function, or a default visitor. This is used only for defining +/// the parameters of the problem visitor used in GroupedParams. +template < + typename B2bMma_, + typename ThreadblockSwizzle_, + typename Enable = void +> +struct ProblemVisitorOrDefault; + +/// Return a generic problem visitor for GEMM problems +template < + typename B2bMma_, + typename ThreadblockSwizzle_ +> +struct ProblemVisitorOrDefault::value + >::type> { + using value = B2bGemmGroupedProblemVisitor::value>; +}; + +/// Return the problem visitor specified by the swizzling function +template < + typename B2bMma_, + typename ThreadblockSwizzle_ +> +struct ProblemVisitorOrDefault::value + >::type> { + using value = typename ThreadblockSwizzle_::ProblemVisitor; +}; + +} // namespace detail + ///////////////////////////////////////////////////////////////////////////////////////////////// template < @@ -72,10 +120,169 @@ struct B2bGemm { using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element; + /// Data types needed for higher-level containers. In some cases, a single type must be exposed + /// despite the B2b GEMM using two GEMMs under the hood. 
In such cases, we select the values from + /// the second GEMM (other than for ElementA/ElementB) + using ElementA = typename B2bMma::IteratorA0::Element; + using LayoutA = typename B2bMma::IteratorA0::Layout; + using ElementB = typename B2bMma::IteratorB0::Element; + using LayoutB = typename B2bMma::IteratorB0::Layout; + + static ComplexTransform const kTransformA = B2bMma::kTransformA; + static ComplexTransform const kTransformB = B2bMma::kTransformB; + using Operator = typename B2bMma::Operator0; + + using OperatorClass = typename Operator::OperatorClass; + using ThreadblockShape = typename B2bMma::Shape0; + using WarpShape = typename Operator::Shape; + using InstructionShape = typename Operator::InstructionShape; + using ArchTag = typename B2bMma::ArchTag; + + static int const kStages = B2bMma::kStages; + static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements; + static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + using Mma = B2bMma; + using EpilogueOutputOp = OutputOp1; + /// Warp count (concept: GemmShape) using WarpCount0 = typename B2bMma::WarpCount0; static int const kThreadCount = 32 * WarpCount0::kCount; + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size_0; + GemmCoord problem_size_1; + typename B2bMma::IteratorA0::TensorRef ref_A0; + typename B2bMma::IteratorB0::TensorRef ref_B0; + typename Epilogue::OutputTileIterator::TensorRef ref_C0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0; + typename B2bMma::IteratorB1::TensorRef ref_B1; + typename Epilogue::OutputTileIterator::TensorRef ref_C1; + typename Epilogue::OutputTileIterator::TensorRef ref_D1; + int64_t batch_stride_A0; + int64_t batch_stride_B0; + int64_t batch_stride_B1; + int64_t batch_stride_C1; + int64_t batch_stride_D1; + int64_t batch_stride_Bias0; + int64_t batch_stride_Scale0; + typename OutputOp0::Params epilogue0; + typename OutputOp1::Params epilogue1; + int batch_count; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() : mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {} + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmUniversalMode mode_, + GemmCoord problem_size_0_, + GemmCoord problem_size_1_, + typename B2bMma::IteratorA0::TensorRef ref_A0_, + typename B2bMma::IteratorB0::TensorRef ref_B0_, + typename Epilogue::OutputTileIterator::TensorRef ref_C0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_, + typename B2bMma::IteratorB1::TensorRef ref_B1_, + typename Epilogue::OutputTileIterator::TensorRef ref_C1_, + typename Epilogue::OutputTileIterator::TensorRef ref_D1_, + int64_t batch_stride_A0_, + int64_t batch_stride_B0_, + int64_t batch_stride_B1_, + int64_t batch_stride_C1_, + int64_t batch_stride_D1_, + int64_t batch_stride_Bias0_, + int64_t batch_stride_Scale0_, + typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(), + typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(), + int batch_count_ = 1 + ): + mode(mode_), + problem_size_0(problem_size_0_), + problem_size_1(problem_size_1_), + ref_A0(ref_A0_), + ref_B0(ref_B0_), + ref_C0(ref_C0_), + ref_Scale0(ref_Scale0_), + ref_Bias0(ref_Bias0_), + 
ref_B1(ref_B1_), + ref_C1(ref_C1_), + ref_D1(ref_D1_), + batch_stride_A0(batch_stride_A0_), + batch_stride_B0(batch_stride_B0_), + batch_stride_B1(batch_stride_B1_), + batch_stride_C1(batch_stride_C1_), + batch_stride_D1(batch_stride_D1_), + batch_stride_Bias0(batch_stride_Bias0_), + batch_stride_Scale0(batch_stride_Scale0_), + epilogue0(epilogue0_), + epilogue1(epilogue1_), + batch_count(batch_count_) { + } + }; + + // Arguments structure for grouped B2B problems + struct GroupedArguments { + GemmCoord* problem_size_0; + GemmCoord* problem_size_1; + typename B2bMma::IteratorA0::TensorRef* ref_A0; + typename B2bMma::IteratorB0::TensorRef* ref_B0; + typename Epilogue::OutputTileIterator::TensorRef* ref_C0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0; + typename B2bMma::IteratorB1::TensorRef* ref_B1; + typename Epilogue::OutputTileIterator::TensorRef* ref_C1; + typename Epilogue::OutputTileIterator::TensorRef* ref_D1; + + // Epilogue params remain constant across all problmes in the group. Thus, + // the parameter here is not a pointer. + typename OutputOp0::Params epilogue0; + typename OutputOp1::Params epilogue1; + + int problem_count; + int threadblock_count; + GemmCoord* host_problem_sizes; + + CUTLASS_HOST_DEVICE + GroupedArguments( + int problem_count, + GemmCoord* problem_size_0_, + GemmCoord* problem_size_1_, + typename B2bMma::IteratorA0::TensorRef* ref_A0_, + typename B2bMma::IteratorB0::TensorRef* ref_B0_, + typename Epilogue::OutputTileIterator::TensorRef* ref_C0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_, + typename B2bMma::IteratorB1::TensorRef* ref_B1_, + typename Epilogue::OutputTileIterator::TensorRef* ref_C1_, + typename Epilogue::OutputTileIterator::TensorRef* ref_D1_, + typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(), + typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(), + int threadblock_count = 0 + ) : problem_size_0(problem_size_0_), problem_size_1(problem_size_1_), + ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_), + ref_Scale0(ref_Scale0_), ref_Bias0(ref_Bias0_), ref_B1(ref_B1_), + ref_C1(ref_C1_), ref_D1(ref_D1_), epilogue0(epilogue0_), epilogue1(epilogue1_), + problem_count(problem_count), + threadblock_count(threadblock_count) + {} + }; + /// Parameters structure struct Params { cutlass::gemm::GemmUniversalMode mode; @@ -149,7 +356,7 @@ struct B2bGemm { problem_size_0(problem_size_0), problem_size_1(problem_size_1), grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)), params_A0(ref_A0.layout()), ref_A0(ref_A0), params_B0(ref_B0.layout()), @@ -185,6 +392,81 @@ struct B2bGemm { } }; + struct GroupedParams { + cutlass::gemm::GemmCoord* problem_size_0; + cutlass::gemm::GemmCoord* problem_size_1; + cutlass::gemm::GemmCoord* grid_tiled_shape; + typename B2bMma::IteratorA0::TensorRef* ref_A0; + typename B2bMma::IteratorB0::TensorRef* ref_B0; + typename Epilogue::OutputTileIterator::TensorRef* ref_C0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0; + typename B2bMma::IteratorB1::TensorRef* ref_B1; + typename Epilogue::OutputTileIterator::TensorRef* ref_C1; + typename Epilogue::OutputTileIterator::TensorRef* ref_D1; 
+ + // Epilogue params remain constant across all problmes in the group. Thus, + // the parameter here is not a pointer. + typename OutputOp0::Params output_op_0; + typename OutputOp1::Params output_op_1; + + using ProblemVisitor = typename detail::ProblemVisitorOrDefault::value; + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int* workspace; + + CUTLASS_HOST_DEVICE + GroupedParams() {} + + CUTLASS_HOST_DEVICE + GroupedParams( + GroupedArguments const &args, + void *workspace = nullptr, + int tile_count = 0 + ) : + problem_size_0(args.problem_size_0), problem_size_1(args.problem_size_1), + ref_A0(args.ref_A0), ref_B0(args.ref_B0), ref_C0(args.ref_C0), + ref_Scale0(args.ref_Scale0), ref_Bias0(args.ref_Bias0), ref_B1(args.ref_B1), ref_C1(args.ref_C1), ref_D1(args.ref_D1), + output_op_0(args.epilogue0), output_op_1(args.epilogue1), + problem_visitor(args.problem_size_0, args.problem_size_1, args.problem_count, workspace, tile_count), + threadblock_count(args.threadblock_count), + workspace(reinterpret_cast(workspace)) {} + + CUTLASS_HOST_DEVICE + void transpose() { + // Only row-major outputs are currently supported, so no transpose is performed + } + + /// Returns non-grouped paramaters to be used as input to the kernel-level + /// operator for the problem indicated by problem_visitor. + CUTLASS_HOST_DEVICE + Params to_single_params(const ProblemVisitor& problem_visitor) const { + GemmCoord problem_size0 = problem_visitor.problem_size0(); + GemmCoord problem_size1 = problem_visitor.problem_size1(); + int32_t idx = problem_visitor.problem_index(); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1); + + return Params( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size0, + problem_size1, + grid_shape, + ref_A0[idx], + ref_B0[idx], + ref_C0[idx], + ref_Scale0[idx], + ref_Bias0[idx], + ref_B1[idx], + ref_C1[idx], + ref_D1[idx], + 0, 0, 0, 0, 0, 0, 0, // Batched B2B GEMMs within the grouped kernel are currently unsupported + output_op_0, + output_op_1, + workspace + ); + } + }; + /// Shared memory storage structure union SharedStorage { typename B2bMma::B2bMmaSharedStorage main_loop; @@ -266,9 +548,13 @@ struct B2bGemm { /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; + run_with_swizzle(params, shared_storage, threadblock_swizzle); + } + + /// Executes one GEMM with an externally-provided swizzling function + CUTLASS_DEVICE + void run_with_swizzle(Params const ¶ms, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) { cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); @@ -391,14 +677,17 @@ struct B2bGemm { ) ); - - // // Main loop // OutputOp0 output_op_0(params.output_op_0); + if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle::value) { + // Wait for all threads to finish their epilogue phases from the previous tile. 
+ __syncthreads(); + } + // Construct thread-scoped matrix multiply B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n()); diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h new file mode 100644 index 00000000..a8eafaad --- /dev/null +++ b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Scheduler for grouped B2b GEMMs +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor< + detail::GemmGroupedProblemSizeHelper, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base = GroupedProblemVisitor; + using BaseParams = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + static bool const kTransposed = Transposed; + + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + + struct Params { + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + int32_t problem_count; + void const *workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params(): problem_sizes0(nullptr), problem_sizes1(nullptr), + problem_count(0), workspace(nullptr), tile_count(0) { } + + /// Ctor + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const *problem_sizes0, + cutlass::gemm::GemmCoord const *problem_sizes1, + int32_t problem_count, + void const *workspace = nullptr, + int32_t tile_count = 0 + ): + problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) + {} + + /// Convert the B2b-GEMM-specific parameters to those used by the base class + CUTLASS_HOST_DEVICE + BaseParams to_base() const { + return BaseParams(// Set problem_sizes as problem_sizes0 because these determine + // shape of the grid used in the non-grouped B2b GEMM + problem_sizes0, + problem_count, + workspace, + tile_count); + } + + }; + + // + // Methods + // + CUTLASS_DEVICE + B2bGemmGroupedProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base ( + params_.to_base(), + shared_storage_, block_idx), + problem_sizes0(params_.problem_sizes0), + problem_sizes1(params_.problem_sizes1) + {} + + /// Returns the problem size 0 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size0() const { + GemmCoord problem = problem_sizes0[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + /// Returns the problem size 1 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size1() const { + GemmCoord problem = problem_sizes1[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h 
b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h index 3f54e1da..3136387c 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h @@ -63,7 +63,9 @@ #include "cutlass/transform/threadblock/predicated_tile_iterator.h" #include "kernel/b2b_gemm.h" +#include "kernel/grouped.h" #include "threadblock/default_b2b_mma.h" +#include "threadblock/grouped_threadblock_swizzle.h" //////////////////////////////////////////////////////////////////////////////// @@ -73,6 +75,9 @@ namespace kernel { //////////////////////////////////////////////////////////////////////////////// +template +using IsGroupedSwizzle = cutlass::gemm::threadblock::detail::IsGroupedSwizzle; + template < /// Element type for A matrix operand typename ElementA_, @@ -117,7 +122,9 @@ template < /// Operation performed by GEMM typename Operator, /// Stage accumulator in shared memory - bool SmemAccumulator = false + bool SmemAccumulator = false, + /// Whether or not the operation is grouped + typename Enable = void > struct DefaultB2bGemm; @@ -166,7 +173,7 @@ struct DefaultB2bGemm { + Operator, false, typename platform::enable_if::value>::type> { /// Define the threadblock-scoped matrix multiply-accumulate using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, @@ -186,6 +193,71 @@ struct DefaultB2bGemm; }; +/// Partial specialization for Ampere Architecture with grouped operation +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm::value>::type> { + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + EpilogueOutputOp1::kCount>::Epilogue; + 
+ /// Define the kernel-level GEMM operator. + using UnderlyingB2bGemmKernel = kernel::B2bGemm; + + using B2bGemmKernel = kernel::GroupedKernel; +}; + //////////////////////////////////////////////////////////////////////////////// @@ -242,7 +314,9 @@ struct DefaultB2bGemm< EpilogueOutputOp1, ThreadblockSwizzle, 2, - Operator + Operator, + false, + typename platform::enable_if::value>::type > { /// Define the threadblock-scoped matrix multiply-accumulate @@ -324,7 +398,8 @@ struct DefaultB2bGemm< arch::OpClassTensorOp, arch::Sm80, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, - ThreadblockSwizzle, Stages, Operator> { + ThreadblockSwizzle, Stages, + Operator, false, typename platform::enable_if::value>::type> { using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; @@ -393,7 +468,8 @@ struct DefaultB2bGemm, int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, - ThreadblockSwizzle, 2, Operator> { + ThreadblockSwizzle, 2, Operator, false, + typename platform::enable_if::value>::type> { using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; diff --git a/examples/13_two_tensor_op_fusion/kernel/grouped.h b/examples/13_two_tensor_op_fusion/kernel/grouped.h new file mode 100644 index 00000000..7b6c9504 --- /dev/null +++ b/examples/13_two_tensor_op_fusion/kernel/grouped.h @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief High-level interface for running a grouped version of a CUTLASS kernel +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// High-level interface for running a grouped version of a CUTLASS kernel +template < + typename BaseKernel_ ///! Kernel-scoped matrix multiply-accumulate +> +struct GroupedKernel { +public: + + using BaseKernel = BaseKernel_; + using Epilogue = typename BaseKernel::Epilogue; + + /// Types that need to be exported to work properly with device::BaseGrouped + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + using Mma = typename BaseKernel::Mma; + + using Arguments = typename BaseKernel::GroupedArguments; + using Params = typename BaseKernel::GroupedParams; + using ProblemVisitor = typename ThreadblockSwizzle::ProblemVisitor; + + static int const kThreadCount = BaseKernel::kThreadCount; + + /// Shared memory storage structure + struct SharedStorage { + typename BaseKernel::SharedStorage kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GroupedKernel() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + /// 
Executes a kernel-level GEMM in a loop + CUTLASS_DEVICE + void operator()(Params &params, SharedStorage &shared_storage) { + + ThreadblockSwizzle swizzle(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + if (ProblemVisitor::kTransposed) { + params.transpose(); + } + + BaseKernel mma; + + // Outer 'persistent' loop to iterate over tiles + while (swizzle.problem_visitor.next_tile()) { + + typename BaseKernel::Params mma_params = params.to_single_params(swizzle.problem_visitor); + mma.run_with_swizzle(mma_params, shared_storage.kernel, swizzle); + + // Next tile + swizzle.problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h index 3e6e05bb..a2eea528 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h @@ -119,8 +119,10 @@ public: using Shape0 = Shape0_; ///< Iterates over tiles of A operand in global memory using IteratorA0 = IteratorA0_; + using IteratorA = IteratorA0; ///< Iterates over tiles of B operand in global memory using IteratorB0 = IteratorB0_; + using IteratorB = IteratorB0; ///< Policy describing tuning details using Policy0 = Policy0_; @@ -139,6 +141,10 @@ public: using IteratorB1 = IteratorB1_; ///< Policy describing tuning details using Policy1 = Policy1_; + + ///< Export Policy0 as the threadblock-level Mma's policy + using Policy = Policy0; + using Shape = Shape0; using SmemIteratorB1 = SmemIteratorB1_; @@ -188,6 +194,10 @@ public: /// Complex transform on B operand static ComplexTransform const kTransformB1 = Operator1::kTransformB; + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + /// Internal structure exposed for introspection.
struct Detail { diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h index 55252edb..f2959bea 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h @@ -121,8 +121,10 @@ public: using Shape0 = Shape0_; ///< Iterates over tiles of A operand in global memory using IteratorA0 = IteratorA0_; + using IteratorA = IteratorA0; ///< Iterates over tiles of B operand in global memory using IteratorB0 = IteratorB0_; + using IteratorB = IteratorB0; ///< Iterates over tiles of the scale and bias vectors in global memory using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Policy describing tuning details @@ -141,6 +143,10 @@ public: ///< Policy describing tuning details using Policy1 = Policy1_; + ///< Export Policy0 as the threadblock-level Mma's policy + using Policy = Policy0; + using Shape = Shape0; + using SmemIteratorB1 = SmemIteratorB1_; using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory @@ -194,6 +200,10 @@ public: /// Complex transform on B operand static ComplexTransform const kTransformB1 = Operator1::kTransformB; + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + /// Internal structure exposed for introspection. struct Detail { diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h index 7afa503a..5240d30a 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h @@ -126,7 +126,9 @@ public: using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA0; using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB0; using Policy0 = Policy0_; ///< Policy describing tuning details using SmemIteratorA0 = SmemIteratorA0_; @@ -139,6 +141,8 @@ public: FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory using Policy1 = Policy1_; ///< Policy describing tuning details + using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy + using Shape = Shape1; using SmemIteratorB1 = SmemIteratorB1_; @@ -195,6 +199,10 @@ public: /// Complex transform on B1 operand static ComplexTransform const kTransformB1 = Operator1::kTransformB; + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h index b78892e1..50cfc207 100644 --- 
a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h @@ -128,7 +128,9 @@ public: using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA0; using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB0; using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory using Policy0 = Policy0_; ///< Policy0 describing tuning details @@ -141,6 +143,8 @@ public: using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory using Policy1 = Policy1_; ///< Policy1 describing tuning details + using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy + using Shape = Shape1; using SmemIteratorB1 = SmemIteratorB1_; using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory @@ -192,6 +196,10 @@ public: /// Complex transform on B1 operand static ComplexTransform const kTransformB1 = Operator1::kTransformB; + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); diff --git a/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h b/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h new file mode 100644 index 00000000..cb409157 --- /dev/null +++ b/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h @@ -0,0 +1,153 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements several threadblock-swizzling functions for grouped kernels +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "kernel/b2b_gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +struct GroupedThreadblockSwizzleBase {}; + +/// Helper for determining if a swizzling function is specialized for grouped operation +template <typename ThreadblockSwizzle> +struct IsGroupedSwizzle { + static bool const value = cutlass::platform::is_base_of<GroupedThreadblockSwizzleBase, ThreadblockSwizzle>::value; +}; + +} // namespace detail + +/// Swizzling function for grouped kernels +template <typename ProblemVisitor_> +struct GroupedThreadblockSwizzle : detail::GroupedThreadblockSwizzleBase { + + using ProblemVisitor = ProblemVisitor_; + ProblemVisitor problem_visitor; + + CUTLASS_HOST_DEVICE + GroupedThreadblockSwizzle(typename ProblemVisitor::Params& params, + typename ProblemVisitor::SharedStorage& shared_storage, + int block_idx) : problem_visitor(params, shared_storage, block_idx) {} + + /// Obtains the threadblock offset (in units of threadblock-scoped tiles) + CUTLASS_DEVICE + GemmCoord get_tile_offset(int /*log_tile*/) const { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + return GemmCoord(int(threadblock_idx / grid_shape.n()), + int(threadblock_idx % grid_shape.n()), + 0); + } + + /// Dummy method to satisfy API for threadblock swizzling functions + CUTLASS_HOST_DEVICE + static int get_log_tile(GemmCoord /*tiled_shape*/) { + return 0; + } +}; + +template < + typename ThreadblockShape, + typename LayoutC, + cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + int PrefetchTileCount = 128, + int ThreadCount = PrefetchTileCount> +struct GemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle< + cutlass::gemm::kernel::GemmGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value + > + > { + using Base = GroupedThreadblockSwizzle<cutlass::gemm::kernel::GemmGroupedProblemVisitor<ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount, platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value>>; + + CUTLASS_HOST_DEVICE + GemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params, + typename Base::ProblemVisitor::SharedStorage& shared_storage, + int block_idx) : Base(params, shared_storage, block_idx) {} +}; + +template < + typename ThreadblockShape, + typename LayoutC, + cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + int
PrefetchTileCount = 128, + int ThreadCount = PrefetchTileCount> +struct B2bGemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle< + cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value + > + > { + using Base = GroupedThreadblockSwizzle<cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount, platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value>>; + + CUTLASS_HOST_DEVICE + B2bGemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params, + typename Base::ProblemVisitor::SharedStorage& shared_storage, + int block_idx) : Base(params, shared_storage, block_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/include/cutlass/gemm/device/base_grouped.h b/include/cutlass/gemm/device/base_grouped.h index f3094e90..7e83b03c 100644 --- a/include/cutlass/gemm/device/base_grouped.h +++ b/include/cutlass/gemm/device/base_grouped.h @@ -290,7 +290,6 @@ public: int available_sm_count=-1) { // Determine the number of blocks that would be launched to fill up a single // wave on the GPU with each SM having maximum occupancy. - cudaDeviceProp properties; int device_idx; cudaError_t result = cudaGetDevice(&device_idx); if (result != cudaSuccess) { diff --git a/include/cutlass/gemm/threadblock/threadblock_swizzle.h b/include/cutlass/gemm/threadblock/threadblock_swizzle.h index 48c1737c..34b68988 100644 --- a/include/cutlass/gemm/threadblock/threadblock_swizzle.h +++ b/include/cutlass/gemm/threadblock/threadblock_swizzle.h @@ -114,7 +114,7 @@ struct GemmIdentityThreadblockSwizzle { /// Calculates optimal swizzle width CUTLASS_HOST_DEVICE - int get_log_tile(GemmCoord tiled_shape) const { + static int get_log_tile(GemmCoord tiled_shape) { auto n = tiled_shape.n(); // Thresholds picked so that it doesn't cause too many no-op CTAs if (N >= 8 && n >= 6) @@ -187,7 +187,7 @@ struct GemmHorizontalThreadblockSwizzle { /// Calculates optimal swizzle width CUTLASS_HOST_DEVICE - int get_log_tile(GemmCoord tiled_shape) const { + static int get_log_tile(GemmCoord tiled_shape) { return 0; } @@ -228,7 +228,7 @@ struct GemmBatchedIdentityThreadblockSwizzle { /// Calculates optimal swizzle width CUTLASS_HOST_DEVICE - int get_log_tile(GemmCoord tiled_shape) const { + static int get_log_tile(GemmCoord tiled_shape) { return 0; } @@ -284,7 +284,7 @@ struct GemmSplitKIdentityThreadblockSwizzle { /// Calculates optimal swizzle width CUTLASS_HOST_DEVICE - int get_log_tile(GemmCoord tiled_shape) const { + static int get_log_tile(GemmCoord tiled_shape) { auto n = tiled_shape.n(); // Thresholds picked so that it doesn't cause too many no-op CTAs if (N >= 8 && n >= 6) @@ -361,7 +361,7 @@ struct GemmSplitKHorizontalThreadblockSwizzle { /// Calculates optimal swizzle width CUTLASS_HOST_DEVICE - int get_log_tile(GemmCoord tiled_shape) const { + static int get_log_tile(GemmCoord tiled_shape) { return 0; } @@ -412,7 +412,7 @@ struct GemvBatchedStridedThreadblockDefaultSwizzle { /// Calculates optimal swizzle width CUTLASS_HOST_DEVICE - int get_log_tile(GemmCoord tiled_shape) const { + static int get_log_tile(GemmCoord tiled_shape) { return 0; }
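Reviewer note (illustration only, not part of the patch): kernel::GroupedKernel above launches one fixed-size grid and has every threadblock loop over tiles handed out by the problem visitor, so a single launch covers all GEMMs in the group. The standalone CUDA sketch below mirrors that persistent-tile schedule; ToyProblem and toy_grouped_kernel are invented names for this sketch, not CUTLASS APIs, and the printf stands in for the threadblock-scoped fused GEMM that run_with_swizzle performs in the real kernel.

// Toy persistent grouped-tile loop (illustrative sketch, not CUTLASS code).
#include <cstdio>
#include <cuda_runtime.h>

struct ToyProblem {
  int tiles;   // number of threadblock tiles this problem contributes
};

// Each threadblock starts at the tile whose flat index equals blockIdx.x and then
// strides by gridDim.x, mirroring problem_visitor.next_tile() / advance(gridDim.x).
__global__ void toy_grouped_kernel(ToyProblem const *problems,
                                   int problem_count,
                                   int total_tiles) {
  for (int tile_idx = blockIdx.x; tile_idx < total_tiles; tile_idx += gridDim.x) {
    // Map the flat tile index back to (problem, tile-within-problem),
    // which is the job the grouped problem visitor performs on the device.
    int remaining   = tile_idx;
    int problem_idx = 0;
    while (problem_idx < problem_count && remaining >= problems[problem_idx].tiles) {
      remaining -= problems[problem_idx].tiles;
      ++problem_idx;
    }
    if (threadIdx.x == 0) {
      printf("block %d -> problem %d, tile %d\n", blockIdx.x, problem_idx, remaining);
    }
    // A real kernel would compute one threadblock-scoped fused GEMM tile here.
  }
}

int main() {
  ToyProblem host_problems[] = {{3}, {5}, {2}};   // three "GEMMs" with different tile counts
  int const problem_count = 3;
  int const total_tiles   = 3 + 5 + 2;

  ToyProblem *device_problems = nullptr;
  cudaMalloc(&device_problems, sizeof(host_problems));
  cudaMemcpy(device_problems, host_problems, sizeof(host_problems), cudaMemcpyHostToDevice);

  // Launch fewer blocks than tiles so each block's persistent loop runs more than once.
  toy_grouped_kernel<<<4, 32>>>(device_problems, problem_count, total_tiles);
  cudaDeviceSynchronize();
  cudaFree(device_problems);
  return 0;
}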