Add grouped b2b GEMM (#970)

Jack Kosaian 2023-06-05 17:16:57 -04:00 committed by GitHub
parent fde824af21
commit 87349d3496
15 changed files with 1644 additions and 107 deletions
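
At a high level, this commit extends the two_tensor_op_fusion example (example 13) so that a group of independent back-to-back (B2b) GEMMs can be run by a single fused, persistent kernel launch. The plain-C++ sketch below is editorial and not part of the commit; it only illustrates the per-problem math that the new test harness verifies (row-major operands, per-channel scale/bias after GEMM0, ReLU after both GEMMs, alpha1 = beta1 = 1 as in the example defaults). The names B2bProblem, reference_b2b, and reference_grouped_b2b are illustrative.

#include <algorithm>
#include <vector>

// One problem in the group: D1 = ReLU(D0 * B1 + bias1), where
// D0 = ReLU(scale0 (per column) * (A0 * B0) + bias0 (per column)).
struct B2bProblem {
  int m, n0, k0, n1;                                        // GEMM0: m x n0 x k0, GEMM1: m x n1 x n0
  std::vector<float> A0, B0, scale0, bias0, B1, bias1, D1;  // row-major, sizes implied by m/n0/k0/n1
};

inline void reference_b2b(B2bProblem &p) {
  std::vector<float> D0(static_cast<size_t>(p.m) * p.n0);
  for (int i = 0; i < p.m; ++i)
    for (int j = 0; j < p.n0; ++j) {
      float acc = 0.f;
      for (int k = 0; k < p.k0; ++k) acc += p.A0[i * p.k0 + k] * p.B0[k * p.n0 + j];
      D0[i * p.n0 + j] = std::max(0.f, p.scale0[j] * acc + p.bias0[j]);  // per-channel epilogue + ReLU
    }
  p.D1.assign(static_cast<size_t>(p.m) * p.n1, 0.f);
  for (int i = 0; i < p.m; ++i)
    for (int j = 0; j < p.n1; ++j) {
      float acc = 0.f;
      for (int k = 0; k < p.n0; ++k) acc += D0[i * p.n0 + k] * p.B1[k * p.n1 + j];
      p.D1[i * p.n1 + j] = std::max(0.f, acc + p.bias1[j]);              // bias + ReLU
    }
}

// The "group" is simply a set of independent problems; the fused kernel runs them in one launch.
inline void reference_grouped_b2b(std::vector<B2bProblem> &group) {
  for (auto &p : group) reference_b2b(p);
}
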

View File

@ -64,6 +64,7 @@ endforeach()
foreach(FUSION_GEMM_EXAMPLE
fused_two_gemms_f16_sm75_rf
fused_two_gemms_f16_sm75_shmem
fused_two_gemms_grouped_f16_sm80_rf
fused_two_gemms_f16_sm80_rf
fused_two_gemms_f16_sm80_shmem
fused_two_gemms_s8_sm75_rf

View File

@ -0,0 +1,450 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Containers for running grouped back-to-back GEMMs
*/
#pragma once
#include <iostream>
#include <fstream>
#include <sstream>
#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h"
#include "helper.h"
#define CHECK_GT(val1, val2) \
if((val1) <= (val2)) \
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
#define CHECK_TRUE(val) \
if(!(val)) \
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
////////////////////////////////////////////////////////////////////////////////
template <typename B2bGemm_>
struct B2bFusedGroupedGemmRun
{
using B2bGemm = B2bGemm_;
using ElementAccumulator = typename B2bGemm::ElementAccumulator;
using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute;
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_Scale;
cutlass::Distribution::Kind init_Bias;
uint64_t seed;
//
// Methods
//
B2bFusedGroupedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_),
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else {
std::cerr << "Not implemented\n";
return false;
}
return true;
}
/// Executes one test
bool run(
std::vector<cutlass::gemm::GemmCoord> problem_sizes_0,
std::vector<cutlass::gemm::GemmCoord> problem_sizes_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true,
int warm_ups = 1,
int runs = 100) {
using HostTensorA = cutlass::HostTensor<typename B2bGemm::ElementA, typename B2bGemm::LayoutA>;
using HostTensorB = cutlass::HostTensor<typename B2bGemm::ElementB, typename B2bGemm::LayoutB>;
using HostTensorC = cutlass::HostTensor<typename B2bGemm::ElementC, typename B2bGemm::LayoutC>;
using HostTensorScale = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
using HostTensorZ = cutlass::HostTensor<ElementAccumulator, typename B2bGemm::LayoutC>;
using HostTensorBias = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
int problem_count = (int)problem_sizes_0.size();
std::vector<HostTensorA> host_tensor_A0(problem_count);
std::vector<HostTensorB> host_tensor_B0(problem_count);
std::vector<HostTensorC> host_tensor_C0(problem_count);
std::vector<HostTensorScale> host_tensor_Scale0(problem_count);
std::vector<HostTensorScale> host_tensor_Bias0(problem_count);
std::vector<HostTensorB> host_tensor_B1(problem_count);
std::vector<HostTensorC> host_tensor_C1(problem_count);
std::vector<HostTensorBias> host_tensor_Bias1(problem_count);
std::vector<HostTensorC> host_tensor_D1(problem_count);
std::vector<HostTensorZ> host_tensor_Z(problem_count);
std::vector<HostTensorC> host_tensor_ref_D0(problem_count);
std::vector<HostTensorC> host_tensor_ref_D1(problem_count);
std::vector<typename HostTensorA::TensorRef> ref_A0(problem_count);
std::vector<typename HostTensorB::TensorRef> ref_B0(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_C0(problem_count);
std::vector<typename HostTensorScale::TensorRef> ref_Scale0(problem_count);
std::vector<typename HostTensorScale::TensorRef> ref_Bias0(problem_count);
std::vector<typename HostTensorB::TensorRef> ref_B1(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_C1(problem_count);
std::vector<typename HostTensorBias::TensorRef> ref_Bias1(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_D1(problem_count);
std::vector<typename HostTensorZ::TensorRef> ref_Z(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_ref_D0(problem_count);
std::vector<typename HostTensorC::TensorRef> ref_ref_D1(problem_count);
for (int i = 0; i < problem_count; ++i) {
//
// Allocate the GEMM workspace
//
auto problem_size_0 = problem_sizes_0[i];
auto problem_size_1 = problem_sizes_1[i];
host_tensor_A0.at(i) = HostTensorA(problem_size_0.mk());
host_tensor_B0.at(i) = HostTensorB(problem_size_0.kn());
host_tensor_C0.at(i) = HostTensorC(problem_size_0.mn());
if (alpha0 == ElementCompute(0)) //per-channel scale
host_tensor_Scale0.at(i) = HostTensorScale(typename HostTensorZ::Layout::TensorCoord{1, problem_size_0.n()});
host_tensor_Bias0.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_0.n()});
host_tensor_Z.at(i) = HostTensorZ(problem_size_0.mn());
host_tensor_ref_D0.at(i) = HostTensorC(problem_size_0.mn());
host_tensor_B1.at(i) = HostTensorB(problem_size_1.kn());
host_tensor_C1.at(i) = HostTensorC(problem_size_1.mn());
host_tensor_Bias1.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_1.n()});
host_tensor_D1.at(i) = HostTensorC(problem_size_1.mn());
host_tensor_ref_D1.at(i) = HostTensorC(problem_size_1.mn());
CHECK_TRUE(initialize_tensor(host_tensor_A0.at(i).host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(host_tensor_B0.at(i).host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(host_tensor_C0.at(i).host_view(), init_C, seed + 2017));
if (alpha0 == ElementCompute(0)) //per-channel scale
CHECK_TRUE(initialize_tensor(host_tensor_Scale0.at(i).host_view(), init_Scale, seed + 2014));
CHECK_TRUE(initialize_tensor(host_tensor_Bias0.at(i).host_view(), init_Bias, seed + 2013));
CHECK_TRUE(initialize_tensor(host_tensor_B1.at(i).host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(host_tensor_C1.at(i).host_view(), init_C, seed + 2015));
CHECK_TRUE(initialize_tensor(host_tensor_Bias1.at(i).host_view(), init_Bias, seed + 2012));
cutlass::reference::host::TensorFill(
host_tensor_D1.at(i).host_view());
cutlass::reference::host::TensorFill(
host_tensor_ref_D0.at(i).host_view());
cutlass::reference::host::TensorFill(
host_tensor_ref_D1.at(i).host_view());
host_tensor_A0.at(i).sync_device();
host_tensor_B0.at(i).sync_device();
host_tensor_C0.at(i).sync_device();
if (alpha0 == ElementCompute(0)) //per-channel scale
host_tensor_Scale0.at(i).sync_device();
host_tensor_Bias0.at(i).sync_device();
host_tensor_B1.at(i).sync_device();
host_tensor_C1.at(i).sync_device();
host_tensor_Bias1.at(i).sync_device();
host_tensor_D1.at(i).sync_device();
host_tensor_ref_D0.at(i).sync_device();
host_tensor_ref_D1.at(i).sync_device();
ref_A0.at(i) = (host_tensor_A0.at(i).device_ref());
ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());
ref_C0.at(i) = (host_tensor_C0.at(i).device_ref());
if (alpha0 == ElementCompute(0)) //per-channel scale
ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref());
ref_Bias0.at(i) = (host_tensor_Bias0.at(i).device_ref());
ref_B1.at(i) = (host_tensor_B1.at(i).device_ref());
ref_C1.at(i) = {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)};
ref_Bias1.at(i) = (host_tensor_Bias1.at(i).device_ref());
ref_D1.at(i) = (host_tensor_D1.at(i).device_ref());
ref_Z.at(i) = (host_tensor_Z.at(i).device_ref());
ref_ref_D0.at(i) = (host_tensor_ref_D0.at(i).device_ref());
ref_ref_D1.at(i) = (host_tensor_ref_D1.at(i).device_ref());
}
//
// Initialize the GEMM operator
//
cutlass::DeviceAllocation<typename HostTensorA::TensorRef> device_ref_A0(problem_count);
device_ref_A0.copy_from_host(ref_A0.data());
cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B0(problem_count);
device_ref_B0.copy_from_host(ref_B0.data());
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C0(problem_count);
device_ref_C0.copy_from_host(ref_C0.data());
cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Scale0(problem_count);
device_ref_Scale0.copy_from_host(ref_Scale0.data());
cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Bias0(problem_count);
device_ref_Bias0.copy_from_host(ref_Bias0.data());
cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B1(problem_count);
device_ref_B1.copy_from_host(ref_B1.data());
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C1(problem_count);
device_ref_C1.copy_from_host(ref_C1.data());
cutlass::DeviceAllocation<typename HostTensorBias::TensorRef> device_ref_Bias1(problem_count);
device_ref_Bias1.copy_from_host(ref_Bias1.data());
cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_D1(problem_count);
device_ref_D1.copy_from_host(ref_D1.data());
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_0(problem_count);
device_problem_sizes_0.copy_from_host(problem_sizes_0.data());
cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_1(problem_count);
device_problem_sizes_1.copy_from_host(problem_sizes_1.data());
B2bGemm b2b_gemm_op;
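// B2bGemm::sufficient() uses an occupancy query to report how many threadblocks of this
// kernel the active device can keep resident (zero if the kernel cannot run at all). The
// count is forwarded below through Arguments as the persistent threadblock count.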
int threadblock_count = B2bGemm::sufficient(problem_sizes_1.data(), problem_count);
if (!threadblock_count) {
std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl;
return false;
}
typename B2bGemm::Arguments arguments{
problem_count,
device_problem_sizes_0.get(),
device_problem_sizes_1.get(),
device_ref_A0.get(),
device_ref_B0.get(),
device_ref_C0.get(),
device_ref_Scale0.get(),
device_ref_Bias0.get(),
device_ref_B1.get(),
device_ref_C1.get(),
device_ref_D1.get(),
{alpha0, beta0},
{alpha1, beta1},
threadblock_count
};
cutlass::Status status = b2b_gemm_op.can_implement(arguments);
if(status != cutlass::Status::kSuccess) {
std::cout << "Problem sizes not supported.\n"
<< "Requirments:\n"
<< " problem_size_0.M = problem_size_1.M\n"
<< " problem_size_0.N = problem_size_1.K\n"
<< " ThreadblockShape0::kN = problem_size_0.N\n"
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
}
status = b2b_gemm_op.initialize(arguments);
CUTLASS_CHECK(status);
for(int i = 0; i < warm_ups; i++) {
status = b2b_gemm_op();
CUTLASS_CHECK(status);
}
//
// Run the GEMM
//
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaEventRecord(start);
for(int i = 0; i < runs; i++) {
status = b2b_gemm_op();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop);
cudaDeviceSynchronize();
float gemmTime;
cudaEventElapsedTime(&gemmTime, start, stop);
std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
for (int i = 0; i < problem_count; ++i) {
host_tensor_D1.at(i).sync_host();
//
// Verify
//
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator>
reference_gemm_0;
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator>
reference_gemm_1;
auto problem_size_0 = problem_sizes_0[i];
auto problem_size_1 = problem_sizes_1[i];
reference_gemm_0(
problem_size_0,
ElementAccumulator(1), //intermediate alpha=1
ref_A0.at(i),
ref_B0.at(i),
ElementAccumulator(0), //beta = 0
ref_Z.at(i),
ref_Z.at(i),
ElementAccumulator(0)
);
cutlass::reference::device::TensorScaleBiasGemm<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutC
> (
problem_size_0,
ref_Z.at(i),
ref_ref_D0.at(i),
alpha0,
ref_Scale0.at(i),
ref_Bias0.at(i)
);
if(relu) {
cutlass::reference::device::TensorReLu(host_tensor_ref_D0.at(i).device_view());
}
reference_gemm_1(
problem_size_1,
alpha1,
ref_ref_D0.at(i),
ref_B1.at(i),
beta1,
{host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)},
ref_ref_D1.at(i)
);
if(relu) {
cutlass::reference::device::TensorReLu(host_tensor_ref_D1.at(i).device_view());
}
cudaDeviceSynchronize();
host_tensor_ref_D0.at(i).sync_host();
host_tensor_ref_D1.at(i).sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D0.at(i).host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_D1.at(i).host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D1.at(i).host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
host_tensor_ref_D1.at(i).host_view(),
host_tensor_D1.at(i).host_view());
CHECK_TRUE(passed);
if (!passed)
{
std::stringstream fname;
fname << "error_B2bGemm_device_fused.txt";
std::cerr << "Check failed for GEMM " << i << " in the group." << std::endl;
std::cerr << "Dumping results in " << fname.str() << "\n";
std::ofstream file(fname.str());
file
<< "GEMM " << i << " in group\n"
<< "A0 =\n" << host_tensor_A0.at(i).host_view()
<< "\nB0 =\n" << host_tensor_B0.at(i).host_view()
<< "\nC0 =\n" << host_tensor_C0.at(i).host_view()
<< "\nScale0:\n" << host_tensor_Scale0.at(i).host_view() << "\n"
<< "\nBias0:\n" << host_tensor_Bias0.at(i).host_view() << "\n"
<< "\nB1 =\n" << host_tensor_B1.at(i).host_view()
<< "\nC1 =\n" << host_tensor_C1.at(i).host_view()
<< "\nBias1:\n" << host_tensor_Bias1.at(i).host_view() << "\n"
<< "\n\nReference =\n" << host_tensor_ref_D1.at(i).host_view()
<< "\nComputed =\n" << host_tensor_D1.at(i).host_view();
return false;
}
}
return true;
}
};
////////////////////////////////////////////////////////////////////////////////

View File

@ -185,96 +185,7 @@ class B2bGemm {
SmemAccumulator
>::B2bGemmKernel;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size_0;
GemmCoord problem_size_1;
TensorRef<ElementA const, LayoutA> ref_A0;
TensorRef<ElementB const, LayoutB> ref_B0;
TensorRef<ElementC const, LayoutC> ref_C0;
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0;
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0;
TensorRef<ElementB const, LayoutB> ref_B1;
TensorRef<ElementC const, LayoutC> ref_C1;
TensorRef<ElementC, LayoutC> ref_D1;
int64_t batch_stride_A0;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C1;
int64_t batch_stride_D1;
int64_t batch_stride_Bias0;
int64_t batch_stride_Scale0;
typename EpilogueOutputOp0::Params epilogue0;
typename EpilogueOutputOp1::Params epilogue1;
int batch_count;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments(): mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {
}
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmUniversalMode mode_,
GemmCoord problem_size_0_,
GemmCoord problem_size_1_,
TensorRef<ElementA const, LayoutA> ref_A0_,
TensorRef<ElementB const, LayoutB> ref_B0_,
TensorRef<ElementC const, LayoutC> ref_C0_,
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0_,
TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0_,
TensorRef<ElementB const, LayoutB> ref_B1_,
TensorRef<ElementC const, LayoutC> ref_C1_,
TensorRef<ElementC, LayoutC> ref_D1_,
int64_t batch_stride_A0_,
int64_t batch_stride_B0_,
int64_t batch_stride_B1_,
int64_t batch_stride_C1_,
int64_t batch_stride_D1_,
int64_t batch_stride_Bias0_,
int64_t batch_stride_Scale0_,
typename EpilogueOutputOp0::Params epilogue0_ =
typename EpilogueOutputOp0::Params(),
typename EpilogueOutputOp1::Params epilogue1_ =
typename EpilogueOutputOp1::Params(),
int batch_count_ = 1
):
mode(mode_),
problem_size_0(problem_size_0_),
problem_size_1(problem_size_1_),
ref_A0(ref_A0_),
ref_B0(ref_B0_),
ref_C0(ref_C0_),
ref_Scale0(ref_Scale0_),
ref_Bias0(ref_Bias0_),
ref_B1(ref_B1_),
ref_C1(ref_C1_),
ref_D1(ref_D1_),
batch_stride_A0(batch_stride_A0_),
batch_stride_B0(batch_stride_B0_),
batch_stride_B1(batch_stride_B1_),
batch_stride_C1(batch_stride_C1_),
batch_stride_D1(batch_stride_D1_),
batch_stride_Bias0(batch_stride_Bias0_),
batch_stride_Scale0(batch_stride_Scale0_),
epilogue0(epilogue0_),
epilogue1(epilogue1_),
batch_count(batch_count_) {
}
};
using Arguments = typename B2bGemmKernel::Arguments;
private:

View File

@ -0,0 +1,297 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Example of running grouped back-to-back GEMMs when intermediate results are RF resident
*/
#include <iostream>
#include <vector>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/base_grouped.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "device/b2b_gemm.h"
#include "kernel/default_b2b_gemm.h"
#include "threadblock/grouped_threadblock_swizzle.h"
#include "b2b_grouped_gemm_run.h"
#include "test_run.h"
////////////////////////////////////////////////////////////////////////////////
std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_0;
std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_1;
// Constraints:
// 1. Warp shape N must equal thread block shape N
// 2. Problem size N must equal thread block shape N
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
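// Editorial sketch (not part of the commit): the first constraint above can be verified at
// compile time against the tile shapes chosen here; the second is a runtime property that
// randomize_problems() below satisfies by construction (N0 = ThreadblockShape0::kN,
// N1 = ThreadblockShape1::kN).
static_assert(WarpShape0::kN == ThreadblockShape0::kN,
              "Warp shape N of GEMM0 must equal threadblock shape N of GEMM0");
static_assert(WarpShape1::kN == ThreadblockShape1::kN,
              "Warp shape N of GEMM1 must equal threadblock shape N of GEMM1");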
// Command line options parsing
struct Options {
bool help;
bool error;
bool reference_check;
int alignment = 8;
std::vector<cutlass::gemm::GemmCoord> problem_sizes0;
std::vector<cutlass::gemm::GemmCoord> problem_sizes1;
int problem_count;
bool verbose;
//
// Methods
//
Options():
help(false),
error(false),
reference_check(true),
problem_count(15),
verbose(false)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("problems", problem_count, 15);
cmd.get_cmd_line_argument("reference-check", reference_check, true);
cmd.get_cmd_line_argument("verbose", verbose, false);
randomize_problems(cmd);
}
void randomize_problems(cutlass::CommandLine &cmd) {
//
// For now, randomly choose the problem sizes.
//
int cmd_line_m = -1;
int cmd_line_k = -1;
cmd.get_cmd_line_argument("m", cmd_line_m);
cmd.get_cmd_line_argument("k", cmd_line_k);
problem_sizes0.reserve(problem_count);
problem_sizes1.reserve(problem_count);
for (int i = 0; i < problem_count; ++i) {
int m = cmd_line_m;
int k = cmd_line_k;
if (m < 1) {
m = alignment * ((rand() % 256) + 1);
}
if (k < 1) {
k = alignment * ((rand() % 256) + 1);
}
cutlass::gemm::GemmCoord problem0(m, ThreadblockShape0::kN, k);
cutlass::gemm::GemmCoord problem1(m, ThreadblockShape1::kN, ThreadblockShape0::kN);
problem_sizes0.push_back(problem0);
problem_sizes1.push_back(problem1);
}
if (verbose) {
print_problem_sizes();
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n"
<< " This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs are\n"
<< " run in a single kernel. Each indivdual problem in the group is subject to the same constraints that non-grouped\n"
<< " back-to-back GEMMs are subject to.s"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --problems=<int> Number of individual GEMM problems (default: --problems=15)\n"
<< " --m=<int> Sets the M dimension of both GEMMs for all groups. Otherwise, it is selected randomly\n"
<< " --k=<int> Sets the K dimension of the first GEMM for all groups. Otherwise, it is selected randomly\n"
<< " --verbose=<bool> If true, prints problem sizes.\n";
out << "\n\nExamples:\n\n"
<< "# Runs a grouped B2b GEMM with 10 random problem sizes\n"
<< "$ ./examples/13_two_tensor_op_fusion/13_fused_two_gemms_grouped_f16_sm80_rf --groups=10\n\n";
return out;
}
void print_problem_sizes() {
std::cout << std::endl;
std::cout << "Executing " << problem_count << " independent back-to-back GEMMs in a group" << std::endl;
for (int i = 0; i < problem_count; ++i) {
cutlass::gemm::GemmCoord problem0 = problem_sizes0.at(i);
cutlass::gemm::GemmCoord problem1 = problem_sizes1.at(i);
std::cout << "Problem " << i
<< "\t\tGEMM0: " << problem0.m() << 'x' << problem0.n() << 'x' << problem0.k()
<< "\t\tGEMM1: " << problem1.m() << 'x' << problem1.n() << 'x' << problem1.k()
<< std::endl;
}
}
};
bool run_fused_grouped_gemm_f16_sm80_rf_res() {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
InstructionShape::kM * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;
using GroupedThreadblockSwizzle = cutlass::gemm::threadblock::B2bGemmGroupedThreadblockSwizzle<
ThreadblockShape0,
cutlass::layout::RowMajor // LayoutC
>;
const int kAlignment = 128 / cutlass::sizeof_bits<ElementOutput>::value;
const int kStages = 3;
using B2bGemmKernel = cutlass::gemm::kernel::DefaultB2bGemm<
cutlass::half_t,
cutlass::layout::RowMajor,
kAlignment,
cutlass::half_t,
cutlass::layout::ColumnMajor,
kAlignment,
cutlass::half_t,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
GroupedThreadblockSwizzle,
kStages,
cutlass::arch::OpMultiplyAdd
>::B2bGemmKernel;
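// BaseGrouped is CUTLASS's generic device-level runner for grouped kernels: it consumes the
// GroupedArguments/GroupedParams exported by the kernel and launches the persistent
// GroupedKernel with a grid sized by the requested threadblock count.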
using B2bGemm = cutlass::gemm::device::BaseGrouped<B2bGemmKernel>;
B2bFusedGroupedGemmRun<B2bGemm> fusedGemm;
std::cout << "Running Fused back-to-back FP16 TN Grouped GEMMs with RF residency...\n";
bool passed = fusedGemm.run(gemm_f16_sm80_problem_sizes_0, gemm_f16_sm80_problem_sizes_1, alpha0, beta0, alpha1, beta1);
if(passed)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
return passed;
}
int main(int argc, char const **args) {
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}
gemm_f16_sm80_problem_sizes_0 = options.problem_sizes0;
gemm_f16_sm80_problem_sizes_1 = options.problem_sizes1;
std::vector<bool (*)()>funcs = {
&run_fused_grouped_gemm_f16_sm80_rf_res
};
return testRun(80, funcs, "grouped gemm f16 RF residency");
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -40,12 +40,60 @@
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "kernel/b2b_gemm_grouped_problem_visitor.h"
#include "threadblock/grouped_threadblock_swizzle.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
namespace detail {
/// Utility struct for returning the type of the problem visitor used by the swizzling function,
/// if it is a grouped swizzling function, or a default visitor. This is used only for defining
/// the parameters of the problem visitor used in GroupedParams.
template <
typename B2bMma_,
typename ThreadblockSwizzle_,
typename Enable = void
>
struct ProblemVisitorOrDefault;
/// Return a generic problem visitor for GEMM problems
template <
typename B2bMma_,
typename ThreadblockSwizzle_
>
struct ProblemVisitorOrDefault<B2bMma_,
ThreadblockSwizzle_,
typename platform::enable_if<
! cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
>::type> {
using value = B2bGemmGroupedProblemVisitor<typename B2bMma_::Shape,
GroupScheduleMode::kDeviceOnly,
128,
128,
platform::is_same<typename B2bMma_::LayoutC,
cutlass::layout::ColumnMajor>::value>;
};
/// Return the problem visitor specified by the swizzling function
template <
typename B2bMma_,
typename ThreadblockSwizzle_
>
struct ProblemVisitorOrDefault<B2bMma_,
ThreadblockSwizzle_,
typename platform::enable_if<
cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
>::type> {
using value = typename ThreadblockSwizzle_::ProblemVisitor;
};
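// Usage sketch (editorial, not part of the commit): GroupedParams below resolves its
// problem-visitor type through this trait, so a grouped swizzle supplies its own visitor
// while any other swizzle falls back to the default B2bGemmGroupedProblemVisitor:
//
//   using ProblemVisitor = typename detail::ProblemVisitorOrDefault<B2bMma, ThreadblockSwizzle>::value;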
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
@ -72,10 +120,169 @@ struct B2bGemm {
using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
/// Data types needed for higher-level containers. In some cases, a single type must be exposed
/// despite the B2b GEMM using two GEMMs under the hood. In such cases, we select the values from
/// the second GEMM (other than for ElementA/ElementB)
using ElementA = typename B2bMma::IteratorA0::Element;
using LayoutA = typename B2bMma::IteratorA0::Layout;
using ElementB = typename B2bMma::IteratorB0::Element;
using LayoutB = typename B2bMma::IteratorB0::Layout;
static ComplexTransform const kTransformA = B2bMma::kTransformA;
static ComplexTransform const kTransformB = B2bMma::kTransformB;
using Operator = typename B2bMma::Operator0;
using OperatorClass = typename Operator::OperatorClass;
using ThreadblockShape = typename B2bMma::Shape0;
using WarpShape = typename Operator::Shape;
using InstructionShape = typename Operator::InstructionShape;
using ArchTag = typename B2bMma::ArchTag;
static int const kStages = B2bMma::kStages;
static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements;
static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
using Mma = B2bMma;
using EpilogueOutputOp = OutputOp1;
/// Warp count (concept: GemmShape)
using WarpCount0 = typename B2bMma::WarpCount0;
static int const kThreadCount = 32 * WarpCount0::kCount;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size_0;
GemmCoord problem_size_1;
typename B2bMma::IteratorA0::TensorRef ref_A0;
typename B2bMma::IteratorB0::TensorRef ref_B0;
typename Epilogue::OutputTileIterator::TensorRef ref_C0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
typename B2bMma::IteratorB1::TensorRef ref_B1;
typename Epilogue::OutputTileIterator::TensorRef ref_C1;
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
int64_t batch_stride_A0;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C1;
int64_t batch_stride_D1;
int64_t batch_stride_Bias0;
int64_t batch_stride_Scale0;
typename OutputOp0::Params epilogue0;
typename OutputOp1::Params epilogue1;
int batch_count;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments() : mode(GemmUniversalMode::kGemm), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {}
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmUniversalMode mode_,
GemmCoord problem_size_0_,
GemmCoord problem_size_1_,
typename B2bMma::IteratorA0::TensorRef ref_A0_,
typename B2bMma::IteratorB0::TensorRef ref_B0_,
typename Epilogue::OutputTileIterator::TensorRef ref_C0_,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_,
typename B2bMma::IteratorB1::TensorRef ref_B1_,
typename Epilogue::OutputTileIterator::TensorRef ref_C1_,
typename Epilogue::OutputTileIterator::TensorRef ref_D1_,
int64_t batch_stride_A0_,
int64_t batch_stride_B0_,
int64_t batch_stride_B1_,
int64_t batch_stride_C1_,
int64_t batch_stride_D1_,
int64_t batch_stride_Bias0_,
int64_t batch_stride_Scale0_,
typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
int batch_count_ = 1
):
mode(mode_),
problem_size_0(problem_size_0_),
problem_size_1(problem_size_1_),
ref_A0(ref_A0_),
ref_B0(ref_B0_),
ref_C0(ref_C0_),
ref_Scale0(ref_Scale0_),
ref_Bias0(ref_Bias0_),
ref_B1(ref_B1_),
ref_C1(ref_C1_),
ref_D1(ref_D1_),
batch_stride_A0(batch_stride_A0_),
batch_stride_B0(batch_stride_B0_),
batch_stride_B1(batch_stride_B1_),
batch_stride_C1(batch_stride_C1_),
batch_stride_D1(batch_stride_D1_),
batch_stride_Bias0(batch_stride_Bias0_),
batch_stride_Scale0(batch_stride_Scale0_),
epilogue0(epilogue0_),
epilogue1(epilogue1_),
batch_count(batch_count_) {
}
};
// Arguments structure for grouped B2B problems
struct GroupedArguments {
GemmCoord* problem_size_0;
GemmCoord* problem_size_1;
typename B2bMma::IteratorA0::TensorRef* ref_A0;
typename B2bMma::IteratorB0::TensorRef* ref_B0;
typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
typename B2bMma::IteratorB1::TensorRef* ref_B1;
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
// Epilogue params remain constant across all problems in the group. Thus,
// the parameter here is not a pointer.
typename OutputOp0::Params epilogue0;
typename OutputOp1::Params epilogue1;
int problem_count;
int threadblock_count;
GemmCoord* host_problem_sizes;
CUTLASS_HOST_DEVICE
GroupedArguments(
int problem_count,
GemmCoord* problem_size_0_,
GemmCoord* problem_size_1_,
typename B2bMma::IteratorA0::TensorRef* ref_A0_,
typename B2bMma::IteratorB0::TensorRef* ref_B0_,
typename Epilogue::OutputTileIterator::TensorRef* ref_C0_,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_,
typename B2bMma::IteratorB1::TensorRef* ref_B1_,
typename Epilogue::OutputTileIterator::TensorRef* ref_C1_,
typename Epilogue::OutputTileIterator::TensorRef* ref_D1_,
typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
int threadblock_count = 0
) : problem_size_0(problem_size_0_), problem_size_1(problem_size_1_),
ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_),
ref_Scale0(ref_Scale0_), ref_Bias0(ref_Bias0_), ref_B1(ref_B1_),
ref_C1(ref_C1_), ref_D1(ref_D1_), epilogue0(epilogue0_), epilogue1(epilogue1_),
problem_count(problem_count),
threadblock_count(threadblock_count)
{}
};
/// Parameters structure
struct Params {
cutlass::gemm::GemmUniversalMode mode;
@ -149,7 +356,7 @@ struct B2bGemm {
problem_size_0(problem_size_0),
problem_size_1(problem_size_1),
grid_tiled_shape(grid_tiled_shape),
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)),
params_A0(ref_A0.layout()),
ref_A0(ref_A0),
params_B0(ref_B0.layout()),
@ -185,6 +392,81 @@ struct B2bGemm {
}
};
struct GroupedParams {
cutlass::gemm::GemmCoord* problem_size_0;
cutlass::gemm::GemmCoord* problem_size_1;
cutlass::gemm::GemmCoord* grid_tiled_shape;
typename B2bMma::IteratorA0::TensorRef* ref_A0;
typename B2bMma::IteratorB0::TensorRef* ref_B0;
typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
typename B2bMma::IteratorB1::TensorRef* ref_B1;
typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
typename Epilogue::OutputTileIterator::TensorRef* ref_D1;
// Epilogue params remain constant across all problems in the group. Thus,
// the parameter here is not a pointer.
typename OutputOp0::Params output_op_0;
typename OutputOp1::Params output_op_1;
using ProblemVisitor = typename detail::ProblemVisitorOrDefault<B2bMma, ThreadblockSwizzle>::value;
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
int* workspace;
CUTLASS_HOST_DEVICE
GroupedParams() {}
CUTLASS_HOST_DEVICE
GroupedParams(
GroupedArguments const &args,
void *workspace = nullptr,
int tile_count = 0
) :
problem_size_0(args.problem_size_0), problem_size_1(args.problem_size_1),
ref_A0(args.ref_A0), ref_B0(args.ref_B0), ref_C0(args.ref_C0),
ref_Scale0(args.ref_Scale0), ref_Bias0(args.ref_Bias0), ref_B1(args.ref_B1), ref_C1(args.ref_C1), ref_D1(args.ref_D1),
output_op_0(args.epilogue0), output_op_1(args.epilogue1),
problem_visitor(args.problem_size_0, args.problem_size_1, args.problem_count, workspace, tile_count),
threadblock_count(args.threadblock_count),
workspace(reinterpret_cast<int*>(workspace)) {}
CUTLASS_HOST_DEVICE
void transpose() {
// Only row-major outputs are currently supported, so no transpose is performed
}
/// Returns non-grouped parameters to be used as input to the kernel-level
/// operator for the problem indicated by problem_visitor.
CUTLASS_HOST_DEVICE
Params to_single_params(const ProblemVisitor& problem_visitor) const {
GemmCoord problem_size0 = problem_visitor.problem_size0();
GemmCoord problem_size1 = problem_visitor.problem_size1();
int32_t idx = problem_visitor.problem_index();
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1);
return Params(
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size0,
problem_size1,
grid_shape,
ref_A0[idx],
ref_B0[idx],
ref_C0[idx],
ref_Scale0[idx],
ref_Bias0[idx],
ref_B1[idx],
ref_C1[idx],
ref_D1[idx],
0, 0, 0, 0, 0, 0, 0, // Batched B2B GEMMs within the grouped kernel are currently unsupported
output_op_0,
output_op_1,
workspace
);
}
};
/// Shared memory storage structure
union SharedStorage {
typename B2bMma::B2bMmaSharedStorage main_loop;
@ -266,9 +548,13 @@ struct B2bGemm {
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
run_with_swizzle(params, shared_storage, threadblock_swizzle);
}
/// Executes one GEMM with an externally-provided swizzling function
CUTLASS_DEVICE
void run_with_swizzle(Params const &params, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) {
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
@ -391,14 +677,17 @@ struct B2bGemm {
)
);
//
// Main loop
//
OutputOp0 output_op_0(params.output_op_0);
if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle>::value) {
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
}
// Construct thread-scoped matrix multiply
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n());

View File

@ -0,0 +1,157 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Scheduler for grouped B2b GEMMs
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape,
GroupScheduleMode GroupScheduleMode_,
int PrefetchTileCount,
int ThreadCount,
bool Transposed = false>
struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor<
detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
ThreadblockShape,
GroupScheduleMode_,
PrefetchTileCount,
ThreadCount> {
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
using Base = GroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
using BaseParams = typename Base::Params;
using SharedStorage = typename Base::SharedStorage;
static bool const kTransposed = Transposed;
cutlass::gemm::GemmCoord const *problem_sizes0;
cutlass::gemm::GemmCoord const *problem_sizes1;
struct Params {
cutlass::gemm::GemmCoord const *problem_sizes0;
cutlass::gemm::GemmCoord const *problem_sizes1;
int32_t problem_count;
void const *workspace;
int32_t tile_count;
//
// Methods
//
/// Ctor
CUTLASS_HOST_DEVICE
Params(): problem_sizes0(nullptr), problem_sizes1(nullptr),
problem_count(0), workspace(nullptr), tile_count(0) { }
/// Ctor
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmCoord const *problem_sizes0,
cutlass::gemm::GemmCoord const *problem_sizes1,
int32_t problem_count,
void const *workspace = nullptr,
int32_t tile_count = 0
):
problem_sizes0(problem_sizes0),
problem_sizes1(problem_sizes1),
problem_count(problem_count),
workspace(workspace),
tile_count(tile_count)
{}
/// Convert the B2b-GEMM-specific parameters to those used by the base class
CUTLASS_HOST_DEVICE
BaseParams to_base() const {
return BaseParams(// Set problem_sizes as problem_sizes0 because these determine
// shape of the grid used in the non-grouped B2b GEMM
problem_sizes0,
problem_count,
workspace,
tile_count);
}
};
//
// Methods
//
CUTLASS_DEVICE
B2bGemmGroupedProblemVisitor(
Params const &params_,
SharedStorage &shared_storage_,
int32_t block_idx
): Base (
params_.to_base(),
shared_storage_, block_idx),
problem_sizes0(params_.problem_sizes0),
problem_sizes1(params_.problem_sizes1)
{}
/// Returns the problem size 0 for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size0() const {
GemmCoord problem = problem_sizes0[this->problem_idx];
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
/// Returns the problem size 1 for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size1() const {
GemmCoord problem = problem_sizes1[this->problem_idx];
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -63,7 +63,9 @@
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "kernel/b2b_gemm.h"
#include "kernel/grouped.h"
#include "threadblock/default_b2b_mma.h"
#include "threadblock/grouped_threadblock_swizzle.h"
////////////////////////////////////////////////////////////////////////////////
@ -73,6 +75,9 @@ namespace kernel {
////////////////////////////////////////////////////////////////////////////////
template <typename T>
using IsGroupedSwizzle = cutlass::gemm::threadblock::detail::IsGroupedSwizzle<T>;
template <
/// Element type for A matrix operand
typename ElementA_,
@ -117,7 +122,9 @@ template <
/// Operation performed by GEMM
typename Operator,
/// Stage accumulator in shared memory
bool SmemAccumulator = false
bool SmemAccumulator = false,
/// Whether or not the operation is grouped
typename Enable = void
>
struct DefaultB2bGemm;
@ -166,7 +173,7 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator> {
Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
@ -186,6 +193,71 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
/// Partial specialization for Ampere Architecture with grouped operation
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape0,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape1,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape0,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape1,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp0,
/// Epilogue output operator
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator, false, typename platform::enable_if<IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
using Epilogue =
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using UnderlyingB2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
using B2bGemmKernel = kernel::GroupedKernel<UnderlyingB2bGemmKernel>;
};
////////////////////////////////////////////////////////////////////////////////
@ -242,7 +314,9 @@ struct DefaultB2bGemm<
EpilogueOutputOp1,
ThreadblockSwizzle,
2,
Operator
Operator,
false,
typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type
> {
/// Define the threadblock-scoped matrix multiply-accumulate
@ -324,7 +398,8 @@ struct DefaultB2bGemm<
arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages, Operator> {
ThreadblockSwizzle, Stages,
Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -393,7 +468,8 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, 2, Operator> {
ThreadblockSwizzle, 2, Operator, false,
typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;

View File

@ -0,0 +1,168 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief High-level interface for running a grouped version of a CUTLASS kernel
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// High-level interface for running a grouped version of a CUTLASS kernel
template <
  typename BaseKernel_   ///! Kernel-scoped matrix multiply-accumulate
>
struct GroupedKernel {
public:

  using BaseKernel = BaseKernel_;
  using Epilogue = typename BaseKernel::Epilogue;

  /// Types that need to be exported to work properly with device::BaseGrouped
  using ElementA = typename BaseKernel::ElementA;
  using LayoutA = typename BaseKernel::LayoutA;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  static ComplexTransform const kTransformA = BaseKernel::kTransformA;
  static int const kAlignmentA = BaseKernel::kAlignmentA;

  using ElementB = typename BaseKernel::ElementB;
  using LayoutB = typename BaseKernel::LayoutB;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  static ComplexTransform const kTransformB = BaseKernel::kTransformB;
  static int const kAlignmentB = BaseKernel::kAlignmentB;

  using ElementC = typename BaseKernel::ElementC;
  using LayoutC = typename BaseKernel::LayoutC;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  static int const kAlignmentC = BaseKernel::kAlignmentC;

  using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;

  using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
  using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle;

  using Operator = typename BaseKernel::Operator;
  using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
  using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
  using MathOperator = typename WarpMmaOperator::MathOperator;
  using OperatorClass = typename WarpMmaOperator::OperatorClass;
  using ArchTag = typename WarpMmaOperator::ArchTag;

  using ThreadblockShape = typename BaseKernel::Mma::Shape;
  using WarpShape = typename BaseKernel::WarpShape;
  using InstructionShape = typename BaseKernel::InstructionShape;
  static int const kStages = BaseKernel::Mma::kStages;

  using Mma = typename BaseKernel::Mma;

  using Arguments = typename BaseKernel::GroupedArguments;
  using Params = typename BaseKernel::GroupedParams;
  using ProblemVisitor = typename ThreadblockSwizzle::ProblemVisitor;

  static int const kThreadCount = BaseKernel::kThreadCount;

  /// Shared memory storage structure
  struct SharedStorage {
    typename BaseKernel::SharedStorage kernel;

    // ProblemVisitor shared storage can't be overlapped with others
    typename ProblemVisitor::SharedStorage problem_visitor;
  };

public:

  //
  // Methods
  //

  CUTLASS_DEVICE
  GroupedKernel() { }

  /// Determines whether kernel satisfies alignment
  static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
    return Status::kSuccess;
  }

  static Status can_implement(Arguments const &args) {
    return Status::kSuccess;
  }

  /// Executes a kernel-level GEMM in a loop
  CUTLASS_DEVICE
  void operator()(Params &params, SharedStorage &shared_storage) {

    ThreadblockSwizzle swizzle(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);

    if (ProblemVisitor::kTransposed) {
      params.transpose();
    }

    BaseKernel mma;

    // Outer 'persistent' loop to iterate over tiles
    while (swizzle.problem_visitor.next_tile()) {

      typename BaseKernel::Params mma_params = params.to_single_params(swizzle.problem_visitor);
      mma.run_with_swizzle(mma_params, shared_storage.kernel, swizzle);

      // Next tile
      swizzle.problem_visitor.advance(gridDim.x);
    }
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
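For orientation, the sketch below (the wrapper name and shared-memory sizing are assumptions for illustration, not part of this commit) shows how a GroupedKernel is typically driven: a __global__ entry point materializes the per-block shared storage and hands control to operator(), whose persistent loop then walks tiles until the problem visitor is exhausted.

// Hypothetical launch wrapper; GroupedKernelT is any instantiation of GroupedKernel.
template <typename GroupedKernelT>
__global__ void GroupedKernelEntry(typename GroupedKernelT::Params params) {
  // Dynamic shared memory holds both the base kernel's storage and the problem visitor's.
  extern __shared__ int smem[];
  typename GroupedKernelT::SharedStorage &shared_storage =
      *reinterpret_cast<typename GroupedKernelT::SharedStorage *>(smem);

  GroupedKernelT op;
  op(params, shared_storage);
}

A host-side launcher would request sizeof(typename GroupedKernelT::SharedStorage) bytes of dynamic shared memory and size the grid to fill the GPU once, since each block persists across many tiles rather than owning a single one.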

View File

@ -119,8 +119,10 @@ public:
using Shape0 = Shape0_;
///< Iterates over tiles of A operand in global memory
using IteratorA0 = IteratorA0_;
using IteratorA = IteratorA0;
///< Iterates over tiles of B operand in global memory
using IteratorB0 = IteratorB0_;
using IteratorB = IteratorB0;
///< Policy describing tuning details
using Policy0 = Policy0_;
@ -139,6 +141,10 @@ public:
using IteratorB1 = IteratorB1_;
///< Policy describing tuning details
using Policy1 = Policy1_;
///< Export Policy0 as the threadblock-level Mma's policy
using Policy = Policy0;
using Shape = Shape0;
using SmemIteratorB1 = SmemIteratorB1_;
@ -188,6 +194,10 @@ public:
/// Complex transform on B operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Complex transform exports needed by higher-level kernels
static ComplexTransform const kTransformA = kTransformA0;
static ComplexTransform const kTransformB = kTransformB0;
/// Internal structure exposed for introspection.
struct Detail {
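The hunks above, and the analogous ones in the three files that follow, export unified aliases (IteratorA, IteratorB, Policy, Shape, kTransformA, kTransformB) so a back-to-back Mma presents the same surface as a single-GEMM Mma. A minimal sketch of a hypothetical consumer (not part of the diff; shown only to illustrate the contract) relies solely on those unified names:

// Generic kernel-level code compiles against the unified interface; for the B2b
// Mma above, these resolve to the 0-/1-suffixed members it exports.
template <typename Mma>
struct MmaInterfaceCheck {
  using Shape = typename Mma::Shape;
  using IteratorA = typename Mma::IteratorA;
  using IteratorB = typename Mma::IteratorB;
  using Policy = typename Mma::Policy;
  static cutlass::ComplexTransform const kTransformA = Mma::kTransformA;
  static cutlass::ComplexTransform const kTransformB = Mma::kTransformB;
};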

View File

@ -121,8 +121,10 @@ public:
using Shape0 = Shape0_;
///< Iterates over tiles of A operand in global memory
using IteratorA0 = IteratorA0_;
using IteratorA = IteratorA0;
///< Iterates over tiles of B operand in global memory
using IteratorB0 = IteratorB0_;
using IteratorB = IteratorB0;
///< Iterates over tiles of the scale and bias vectors in global memory
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
///< Policy describing tuning details
@ -141,6 +143,10 @@ public:
///< Policy describing tuning details
using Policy1 = Policy1_;
///< Export Policy0 as the threadblock-level Mma's policy
using Policy = Policy0;
using Shape = Shape0;
using SmemIteratorB1 = SmemIteratorB1_;
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
@ -194,6 +200,10 @@ public:
/// Complex transform on B operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Complex transform exports needed by higher-level kernels
static ComplexTransform const kTransformA = kTransformA0;
static ComplexTransform const kTransformB = kTransformB0;
/// Internal structure exposed for introspection.
struct Detail {

View File

@ -126,7 +126,9 @@ public:
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA0;
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB0;
using Policy0 = Policy0_; ///< Policy describing tuning details
using SmemIteratorA0 = SmemIteratorA0_;
@ -139,6 +141,8 @@ public:
FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
using Policy1 = Policy1_; ///< Policy describing tuning details
using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy
using Shape = Shape1;
using SmemIteratorB1 = SmemIteratorB1_;
@ -195,6 +199,10 @@ public:
/// Complex transform on B1 operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Complex transform exports needed by higher-level kernels
static ComplexTransform const kTransformA = kTransformA0;
static ComplexTransform const kTransformB = kTransformB0;
/// statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");

View File

@ -128,7 +128,9 @@ public:
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA0;
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB0;
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory
using Policy0 = Policy0_; ///< Policy0 describing tuning details
@ -141,6 +143,8 @@ public:
using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
using Policy1 = Policy1_; ///< Policy1 describing tuning details
using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy
using Shape = Shape1;
using SmemIteratorB1 = SmemIteratorB1_;
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory
@ -192,6 +196,10 @@ public:
/// Complex transform on B1 operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Complex transform exports needed by higher-level kernels
static ComplexTransform const kTransformA = kTransformA0;
static ComplexTransform const kTransformB = kTransformB0;
/// statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");

View File

@ -0,0 +1,153 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements several threadblock-swizzling functions for grouped kernels
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "kernel/b2b_gemm_grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
struct GroupedThreadblockSwizzleBase {};

/// Helper for determining if a swizzling function is specialized for grouped operation
template <typename ThreadblockSwizzle>
struct IsGroupedSwizzle {
  static bool const value = cutlass::platform::is_base_of<GroupedThreadblockSwizzleBase, ThreadblockSwizzle>::value;
};

} // namespace detail

/// Swizzling function for grouped kernels
template <typename ProblemVisitor_>
struct GroupedThreadblockSwizzle : detail::GroupedThreadblockSwizzleBase {

  using ProblemVisitor = ProblemVisitor_;
  ProblemVisitor problem_visitor;

  CUTLASS_HOST_DEVICE
  GroupedThreadblockSwizzle(typename ProblemVisitor::Params& params,
                            typename ProblemVisitor::SharedStorage& shared_storage,
                            int block_idx) : problem_visitor(params, shared_storage, block_idx) {}

  /// Obtains the threadblock offset (in units of threadblock-scoped tiles)
  CUTLASS_DEVICE
  GemmCoord get_tile_offset(int /*log_tile*/) const {
    GemmCoord problem_size = problem_visitor.problem_size();
    int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
    GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);

    return GemmCoord(int(threadblock_idx / grid_shape.n()),
                     int(threadblock_idx % grid_shape.n()),
                     0);
  }

  /// Dummy method to satisfy API for threadblock swizzling functions
  CUTLASS_HOST_DEVICE
  static int get_log_tile(GemmCoord /*tiled_shape*/) {
    return 0;
  }
};

template <
  typename ThreadblockShape,
  typename LayoutC,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
  int PrefetchTileCount = 128,
  int ThreadCount = PrefetchTileCount>
struct GemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle<
                                         cutlass::gemm::kernel::GemmGroupedProblemVisitor<
                                           ThreadblockShape,
                                           GroupScheduleMode_,
                                           PrefetchTileCount,
                                           ThreadCount,
                                           platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value
                                         >
                                       > {

  using Base = GroupedThreadblockSwizzle<cutlass::gemm::kernel::GemmGroupedProblemVisitor<
                                           ThreadblockShape,
                                           GroupScheduleMode_,
                                           PrefetchTileCount,
                                           ThreadCount,
                                           platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value>>;

  CUTLASS_HOST_DEVICE
  GemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params,
                                typename Base::ProblemVisitor::SharedStorage& shared_storage,
                                int block_idx) : Base(params, shared_storage, block_idx) {}
};

template <
  typename ThreadblockShape,
  typename LayoutC,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
  int PrefetchTileCount = 128,
  int ThreadCount = PrefetchTileCount>
struct B2bGemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle<
                                            cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<
                                              ThreadblockShape,
                                              GroupScheduleMode_,
                                              PrefetchTileCount,
                                              ThreadCount,
                                              platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value
                                            >
                                          > {

  using Base = GroupedThreadblockSwizzle<cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<
                                           ThreadblockShape,
                                           GroupScheduleMode_,
                                           PrefetchTileCount,
                                           ThreadCount,
                                           platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value>>;

  CUTLASS_HOST_DEVICE
  B2bGemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params,
                                   typename Base::ProblemVisitor::SharedStorage& shared_storage,
                                   int block_idx) : Base(params, shared_storage, block_idx) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
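As a usage sketch (the tile size and layout below are illustrative assumptions, and the header above is presumed to be included), a grouped B2B kernel selects the B2bGemm variant and can confirm the selection with the detection trait; the tile offset then falls out of row-major arithmetic over the per-problem grid shape.

// Illustrative only: a grouped swizzle for a 128x128x32 threadblock tile with
// row-major output, verified to be recognized as grouped.
using GroupedSwizzle = cutlass::gemm::threadblock::B2bGemmGroupedThreadblockSwizzle<
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::layout::RowMajor>;

static_assert(cutlass::gemm::threadblock::detail::IsGroupedSwizzle<GroupedSwizzle>::value,
              "Expected a grouped threadblock swizzle");

// get_tile_offset() linearizes row-major over the problem's tile grid: with a
// grid_shape of (m, n) = (4, 3) tiles, threadblock_idx 7 maps to tile (7 / 3, 7 % 3) = (2, 1).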

View File

@ -290,7 +290,6 @@ public:
int available_sm_count=-1) {
// Determine the number of blocks that would be launched to fill up a single
// wave on the GPU with each SM having maximum occupancy.
cudaDeviceProp properties;
int device_idx;
cudaError_t result = cudaGetDevice(&device_idx);
if (result != cudaSuccess) {

View File

@ -114,7 +114,7 @@ struct GemmIdentityThreadblockSwizzle {
/// Calculates optimal swizzle width
CUTLASS_HOST_DEVICE
int get_log_tile(GemmCoord tiled_shape) const {
static int get_log_tile(GemmCoord tiled_shape) {
auto n = tiled_shape.n();
// Thresholds picked so that it doesn't cause too many no-op CTAs
if (N >= 8 && n >= 6)
@ -187,7 +187,7 @@ struct GemmHorizontalThreadblockSwizzle {
/// Calculates optimal swizzle width
CUTLASS_HOST_DEVICE
int get_log_tile(GemmCoord tiled_shape) const {
static int get_log_tile(GemmCoord tiled_shape) {
return 0;
}
@ -228,7 +228,7 @@ struct GemmBatchedIdentityThreadblockSwizzle {
/// Calculates optimal swizzle width
CUTLASS_HOST_DEVICE
int get_log_tile(GemmCoord tiled_shape) const {
static int get_log_tile(GemmCoord tiled_shape) {
return 0;
}
@ -284,7 +284,7 @@ struct GemmSplitKIdentityThreadblockSwizzle {
/// Calculates optimal swizzle width
CUTLASS_HOST_DEVICE
int get_log_tile(GemmCoord tiled_shape) const {
static int get_log_tile(GemmCoord tiled_shape) {
auto n = tiled_shape.n();
// Thresholds picked so that it doesn't cause too many no-op CTAs
if (N >= 8 && n >= 6)
@ -361,7 +361,7 @@ struct GemmSplitKHorizontalThreadblockSwizzle {
/// Calculates optimal swizzle width
CUTLASS_HOST_DEVICE
int get_log_tile(GemmCoord tiled_shape) const {
static int get_log_tile(GemmCoord tiled_shape) {
return 0;
}
@ -412,7 +412,7 @@ struct GemvBatchedStridedThreadblockDefaultSwizzle {
/// Calculates optimal swizzle width
CUTLASS_HOST_DEVICE
int get_log_tile(GemmCoord tiled_shape) const {
static int get_log_tile(GemmCoord tiled_shape) {
return 0;
}
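These hunks change get_log_tile from a const member function to a static method, so the swizzle policy can be queried without an object in hand; the grouped swizzle introduced above exposes the same static signature. A short sketch (the call site is assumed for illustration and is not taken from the commit):

#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

int example_log_tile() {
  // A 4x6 grid of threadblock tiles; the k extent is unused by the swizzle heuristic.
  cutlass::gemm::GemmCoord tiled_shape(4, 6, 1);

  // No swizzle instance is needed now that get_log_tile() is static.
  return cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>::get_log_tile(tiled_shape);
}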