Add grouped b2b GEMM (#970)
parent fde824af21
commit 87349d3496
@@ -64,6 +64,7 @@ endforeach()
foreach(FUSION_GEMM_EXAMPLE
  fused_two_gemms_f16_sm75_rf
  fused_two_gemms_f16_sm75_shmem
  fused_two_gemms_grouped_f16_sm80_rf
  fused_two_gemms_f16_sm80_rf
  fused_two_gemms_f16_sm80_shmem
  fused_two_gemms_s8_sm75_rf
examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h (new file, 450 lines)
@@ -0,0 +1,450 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Containers for running grouped back-to-back GEMMs
*/

#pragma once

#include <iostream>
#include <fstream>
#include <sstream>

#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"

#include "reference/device/tensor_scale_bias.h"
#include "helper.h"

#define CHECK_GT(val1, val2) \
  if ((val1) <= (val2)) \
    std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
#define CHECK_TRUE(val) \
  if (!(val)) \
    std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";

////////////////////////////////////////////////////////////////////////////////

template <typename B2bGemm_>
struct B2bFusedGroupedGemmRun
{

  using B2bGemm = B2bGemm_;
  using ElementAccumulator = typename B2bGemm::ElementAccumulator;
  using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute;

  /// Initialization
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  cutlass::Distribution::Kind init_Scale;
  cutlass::Distribution::Kind init_Bias;
  uint64_t seed;

  //
  // Methods
  //

  B2bFusedGroupedGemmRun(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_),
    init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
  bool initialize_tensor(
    cutlass::TensorView<Element, Layout> view,
    cutlass::Distribution::Kind dist_kind,
    uint64_t seed) {

    if (dist_kind == cutlass::Distribution::Uniform) {

      cutlass::reference::host::TensorFillRandomUniform(
        view, seed, 2, -2, 0);
    }
    else if (dist_kind == cutlass::Distribution::Identity) {

      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {

      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {

      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity());
    }
    else if (dist_kind == cutlass::Distribution::AllZeros) {
      cutlass::reference::host::TensorFill(view, Element(0));
    }
    else if (dist_kind == cutlass::Distribution::AllOnes) {
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      std::cerr << "Not implemented\n";
      return false;
    }

    return true;
  }

  /// Executes one test
  bool run(
    std::vector<cutlass::gemm::GemmCoord> problem_sizes_0,
    std::vector<cutlass::gemm::GemmCoord> problem_sizes_1,
    ElementCompute alpha0 = ElementCompute(1),
    ElementCompute beta0 = ElementCompute(0),
    ElementCompute alpha1 = ElementCompute(1),
    ElementCompute beta1 = ElementCompute(0),
    bool relu = true,
    int warm_ups = 1,
    int runs = 100) {

    using HostTensorA = cutlass::HostTensor<typename B2bGemm::ElementA, typename B2bGemm::LayoutA>;
    using HostTensorB = cutlass::HostTensor<typename B2bGemm::ElementB, typename B2bGemm::LayoutB>;
    using HostTensorC = cutlass::HostTensor<typename B2bGemm::ElementC, typename B2bGemm::LayoutC>;
    using HostTensorScale = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
    using HostTensorZ = cutlass::HostTensor<ElementAccumulator, typename B2bGemm::LayoutC>;
    using HostTensorBias = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;

    int problem_count = (int)problem_sizes_0.size();

    std::vector<HostTensorA> host_tensor_A0(problem_count);
    std::vector<HostTensorB> host_tensor_B0(problem_count);
    std::vector<HostTensorC> host_tensor_C0(problem_count);
    std::vector<HostTensorScale> host_tensor_Scale0(problem_count);
    std::vector<HostTensorScale> host_tensor_Bias0(problem_count);
    std::vector<HostTensorB> host_tensor_B1(problem_count);
    std::vector<HostTensorC> host_tensor_C1(problem_count);
    std::vector<HostTensorBias> host_tensor_Bias1(problem_count);
    std::vector<HostTensorC> host_tensor_D1(problem_count);
    std::vector<HostTensorZ> host_tensor_Z(problem_count);
    std::vector<HostTensorC> host_tensor_ref_D0(problem_count);
    std::vector<HostTensorC> host_tensor_ref_D1(problem_count);

    std::vector<typename HostTensorA::TensorRef> ref_A0(problem_count);
    std::vector<typename HostTensorB::TensorRef> ref_B0(problem_count);
    std::vector<typename HostTensorC::TensorRef> ref_C0(problem_count);
    std::vector<typename HostTensorScale::TensorRef> ref_Scale0(problem_count);
    std::vector<typename HostTensorScale::TensorRef> ref_Bias0(problem_count);
    std::vector<typename HostTensorB::TensorRef> ref_B1(problem_count);
    std::vector<typename HostTensorC::TensorRef> ref_C1(problem_count);
    std::vector<typename HostTensorBias::TensorRef> ref_Bias1(problem_count);
    std::vector<typename HostTensorC::TensorRef> ref_D1(problem_count);
    std::vector<typename HostTensorZ::TensorRef> ref_Z(problem_count);
    std::vector<typename HostTensorC::TensorRef> ref_ref_D0(problem_count);
    std::vector<typename HostTensorC::TensorRef> ref_ref_D1(problem_count);

    for (int i = 0; i < problem_count; ++i) {
      //
      // Allocate the GEMM workspace
      //

      auto problem_size_0 = problem_sizes_0[i];
      auto problem_size_1 = problem_sizes_1[i];

      host_tensor_A0.at(i) = HostTensorA(problem_size_0.mk());
      host_tensor_B0.at(i) = HostTensorB(problem_size_0.kn());
      host_tensor_C0.at(i) = HostTensorC(problem_size_0.mn());
      if (alpha0 == ElementCompute(0)) // per-channel scale
        host_tensor_Scale0.at(i) = HostTensorScale(typename HostTensorZ::Layout::TensorCoord{1, problem_size_0.n()});
      host_tensor_Bias0.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_0.n()});
      host_tensor_Z.at(i) = HostTensorZ(problem_size_0.mn());
      host_tensor_ref_D0.at(i) = HostTensorC(problem_size_0.mn());
      host_tensor_B1.at(i) = HostTensorB(problem_size_1.kn());
      host_tensor_C1.at(i) = HostTensorC(problem_size_1.mn());
      host_tensor_Bias1.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_1.n()});
      host_tensor_D1.at(i) = HostTensorC(problem_size_1.mn());
      host_tensor_ref_D1.at(i) = HostTensorC(problem_size_1.mn());

      CHECK_TRUE(initialize_tensor(host_tensor_A0.at(i).host_view(), init_A, seed + 2019));
      CHECK_TRUE(initialize_tensor(host_tensor_B0.at(i).host_view(), init_B, seed + 2018));
      CHECK_TRUE(initialize_tensor(host_tensor_C0.at(i).host_view(), init_C, seed + 2017));
      if (alpha0 == ElementCompute(0)) // per-channel scale
        CHECK_TRUE(initialize_tensor(host_tensor_Scale0.at(i).host_view(), init_Scale, seed + 2014));
      CHECK_TRUE(initialize_tensor(host_tensor_Bias0.at(i).host_view(), init_Bias, seed + 2013));
      CHECK_TRUE(initialize_tensor(host_tensor_B1.at(i).host_view(), init_B, seed + 2016));
      CHECK_TRUE(initialize_tensor(host_tensor_C1.at(i).host_view(), init_C, seed + 2015));
      CHECK_TRUE(initialize_tensor(host_tensor_Bias1.at(i).host_view(), init_Bias, seed + 2012));

      cutlass::reference::host::TensorFill(
        host_tensor_D1.at(i).host_view());
      cutlass::reference::host::TensorFill(
        host_tensor_ref_D0.at(i).host_view());
      cutlass::reference::host::TensorFill(
        host_tensor_ref_D1.at(i).host_view());

      host_tensor_A0.at(i).sync_device();
      host_tensor_B0.at(i).sync_device();
      host_tensor_C0.at(i).sync_device();
      if (alpha0 == ElementCompute(0)) // per-channel scale
        host_tensor_Scale0.at(i).sync_device();
      host_tensor_Bias0.at(i).sync_device();
      host_tensor_B1.at(i).sync_device();
      host_tensor_C1.at(i).sync_device();
      host_tensor_Bias1.at(i).sync_device();
      host_tensor_D1.at(i).sync_device();
      host_tensor_ref_D0.at(i).sync_device();
      host_tensor_ref_D1.at(i).sync_device();

      ref_A0.at(i) = (host_tensor_A0.at(i).device_ref());
      ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());
      ref_C0.at(i) = (host_tensor_C0.at(i).device_ref());
      if (alpha0 == ElementCompute(0)) // per-channel scale
        ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref());
      ref_Bias0.at(i) = (host_tensor_Bias0.at(i).device_ref());
      ref_B1.at(i) = (host_tensor_B1.at(i).device_ref());
      ref_C1.at(i) = {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)};
      ref_Bias1.at(i) = (host_tensor_Bias1.at(i).device_ref());
      ref_D1.at(i) = (host_tensor_D1.at(i).device_ref());
      ref_Z.at(i) = (host_tensor_Z.at(i).device_ref());
      ref_ref_D0.at(i) = (host_tensor_ref_D0.at(i).device_ref());
      ref_ref_D1.at(i) = (host_tensor_ref_D1.at(i).device_ref());
    }

    //
    // Initialize the GEMM operator
    //

    cutlass::DeviceAllocation<typename HostTensorA::TensorRef> device_ref_A0(problem_count);
    device_ref_A0.copy_from_host(ref_A0.data());
    cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B0(problem_count);
    device_ref_B0.copy_from_host(ref_B0.data());
    cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C0(problem_count);
    device_ref_C0.copy_from_host(ref_C0.data());
    cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Scale0(problem_count);
    device_ref_Scale0.copy_from_host(ref_Scale0.data());
    cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Bias0(problem_count);
    device_ref_Bias0.copy_from_host(ref_Bias0.data());
    cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B1(problem_count);
    device_ref_B1.copy_from_host(ref_B1.data());
    cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C1(problem_count);
    device_ref_C1.copy_from_host(ref_C1.data());
    cutlass::DeviceAllocation<typename HostTensorBias::TensorRef> device_ref_Bias1(problem_count);
    device_ref_Bias1.copy_from_host(ref_Bias1.data());
    cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_D1(problem_count);
    device_ref_D1.copy_from_host(ref_D1.data());

    cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_0(problem_count);
    device_problem_sizes_0.copy_from_host(problem_sizes_0.data());
    cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_1(problem_count);
    device_problem_sizes_1.copy_from_host(problem_sizes_1.data());

    B2bGemm b2b_gemm_op;

    int threadblock_count = B2bGemm::sufficient(problem_sizes_1.data(), problem_count);
    if (!threadblock_count) {
      std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl;
      return false;
    }

    typename B2bGemm::Arguments arguments{
      problem_count,
      device_problem_sizes_0.get(),
      device_problem_sizes_1.get(),
      device_ref_A0.get(),
      device_ref_B0.get(),
      device_ref_C0.get(),
      device_ref_Scale0.get(),
      device_ref_Bias0.get(),
      device_ref_B1.get(),
      device_ref_C1.get(),
      device_ref_D1.get(),
      {alpha0, beta0},
      {alpha1, beta1},
      threadblock_count
    };

    cutlass::Status status = b2b_gemm_op.can_implement(arguments);

    if (status != cutlass::Status::kSuccess) {
      std::cout << "Problem sizes not supported.\n"
        << "Requirements:\n"
        << "    problem_size_0.M = problem_size_1.M\n"
        << "    problem_size_0.N = problem_size_1.K\n"
        << "    ThreadblockShape0::kN = problem_size_0.N\n"
        << "    ThreadblockShape1::kN = problem_size_1.N" << std::endl;
    }

    status = b2b_gemm_op.initialize(arguments);

    CUTLASS_CHECK(status);

    for (int i = 0; i < warm_ups; i++) {
      status = b2b_gemm_op();
      CUTLASS_CHECK(status);
    }

    //
    // Run the GEMM
    //

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);

    for (int i = 0; i < runs; i++) {
      status = b2b_gemm_op();
      CUTLASS_CHECK(status);
    }

    cudaEventRecord(stop);
    cudaDeviceSynchronize();
    float gemmTime;
    cudaEventElapsedTime(&gemmTime, start, stop);
    std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";

    for (int i = 0; i < problem_count; ++i) {
      host_tensor_D1.at(i).sync_host();

      //
      // Verify
      //

      cutlass::reference::device::Gemm<
        typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
        typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
        ElementAccumulator, typename B2bGemm::LayoutC,
        ElementAccumulator, ElementAccumulator>
        reference_gemm_0;

      cutlass::reference::device::Gemm<
        typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
        typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
        typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
        ElementAccumulator>
        reference_gemm_1;

      auto problem_size_0 = problem_sizes_0[i];
      auto problem_size_1 = problem_sizes_1[i];

      reference_gemm_0(
        problem_size_0,
        ElementAccumulator(1), // intermediate alpha = 1
        ref_A0.at(i),
        ref_B0.at(i),
        ElementAccumulator(0), // beta = 0
        ref_Z.at(i),
        ref_Z.at(i),
        ElementAccumulator(0)
      );

      cutlass::reference::device::TensorScaleBiasGemm<
        ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
        ElementCompute, typename B2bGemm::LayoutC
      > (
        problem_size_0,
        ref_Z.at(i),
        ref_ref_D0.at(i),
        alpha0,
        ref_Scale0.at(i),
        ref_Bias0.at(i)
      );

      if (relu) {
        cutlass::reference::device::TensorReLu(host_tensor_ref_D0.at(i).device_view());
      }

      reference_gemm_1(
        problem_size_1,
        alpha1,
        ref_ref_D0.at(i),
        ref_B1.at(i),
        beta1,
        {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)},
        ref_ref_D1.at(i)
      );
      if (relu) {
        cutlass::reference::device::TensorReLu(host_tensor_ref_D1.at(i).device_view());
      }
      cudaDeviceSynchronize();
      host_tensor_ref_D0.at(i).sync_host();
      host_tensor_ref_D1.at(i).sync_host();

      CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D0.at(i).host_view()), 0);
      CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_D1.at(i).host_view()), 0);
      CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D1.at(i).host_view()), 0);

      bool passed = cutlass::reference::host::TensorEquals(
        host_tensor_ref_D1.at(i).host_view(),
        host_tensor_D1.at(i).host_view());

      CHECK_TRUE(passed);
      if (!passed)
      {

        std::stringstream fname;

        fname << "error_B2bGemm_device_fused.txt";
        std::cerr << "Check failed for GEMM " << i << " in the group." << std::endl;
        std::cerr << "Dumping results in " << fname.str() << "\n";

        std::ofstream file(fname.str());

        file
          << "GEMM " << i << " in group\n"
          << "A0 =\n" << host_tensor_A0.at(i).host_view()
          << "\nB0 =\n" << host_tensor_B0.at(i).host_view()
          << "\nC0 =\n" << host_tensor_C0.at(i).host_view()
          << "\nScale0:\n" << host_tensor_Scale0.at(i).host_view() << "\n"
          << "\nBias0:\n" << host_tensor_Bias0.at(i).host_view() << "\n"
          << "\nB1 =\n" << host_tensor_B1.at(i).host_view()
          << "\nC1 =\n" << host_tensor_C1.at(i).host_view()
          << "\nBias1:\n" << host_tensor_Bias1.at(i).host_view() << "\n"
          << "\n\nReference =\n" << host_tensor_ref_D1.at(i).host_view()
          << "\nComputed =\n" << host_tensor_D1.at(i).host_view();

        return false;
      }
    }
    return true;
  }

};

////////////////////////////////////////////////////////////////////////////////
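For orientation, a condensed sketch of how this runner is typically driven follows. It mirrors the new `fused_two_gemms_grouped_f16_sm80_rf` example added later in this commit; the `B2bGemm` template parameter stands for a `cutlass::gemm::device::BaseGrouped<B2bGemmKernel>` instantiation composed as in that example, and the concrete problem sizes below are illustrative assumptions chosen to satisfy the kernel's shape constraints (GEMM0 N = 64, GEMM1 N = 128, GEMM1 K = GEMM0 N).

```cpp
// Illustrative driver sketch; B2bGemm is assumed to be a
// cutlass::gemm::device::BaseGrouped<B2bGemmKernel> built as in the new example.
#include <vector>

#include "cutlass/gemm/gemm.h"        // cutlass::gemm::GemmCoord
#include "b2b_grouped_gemm_run.h"     // B2bFusedGroupedGemmRun (this file)

template <typename B2bGemm>
bool run_two_problem_group() {
  using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute;

  // One GemmCoord pair per problem in the group: GEMM0 is (M x 64 x K),
  // GEMM1 is (M x 128 x 64) so GEMM1 consumes GEMM0's output directly.
  std::vector<cutlass::gemm::GemmCoord> problem_sizes_0;
  std::vector<cutlass::gemm::GemmCoord> problem_sizes_1;
  problem_sizes_0.push_back(cutlass::gemm::GemmCoord(128, 64, 576));
  problem_sizes_1.push_back(cutlass::gemm::GemmCoord(128, 128, 64));
  problem_sizes_0.push_back(cutlass::gemm::GemmCoord(256, 64, 576));
  problem_sizes_1.push_back(cutlass::gemm::GemmCoord(256, 128, 64));

  B2bFusedGroupedGemmRun<B2bGemm> runner;

  // alpha0 = 1, beta0 = 0 (scale/bias are fused into GEMM0's epilogue),
  // alpha1 = 1, beta1 = 1 (beta = 1 adds the per-channel bias of GEMM1).
  return runner.run(problem_sizes_0, problem_sizes_1,
                    ElementCompute(1), ElementCompute(0),
                    ElementCompute(1), ElementCompute(1));
}
```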
@@ -185,96 +185,7 @@ class B2bGemm {
      SmemAccumulator
    >::B2bGemmKernel;

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmUniversalMode mode;
    GemmCoord problem_size_0;
    GemmCoord problem_size_1;
    TensorRef<ElementA const, LayoutA> ref_A0;
    TensorRef<ElementB const, LayoutB> ref_B0;
    TensorRef<ElementC const, LayoutC> ref_C0;
    TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0;
    TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0;
    TensorRef<ElementB const, LayoutB> ref_B1;
    TensorRef<ElementC const, LayoutC> ref_C1;
    TensorRef<ElementC, LayoutC> ref_D1;
    int64_t batch_stride_A0;
    int64_t batch_stride_B0;
    int64_t batch_stride_B1;
    int64_t batch_stride_C1;
    int64_t batch_stride_D1;
    int64_t batch_stride_Bias0;
    int64_t batch_stride_Scale0;
    typename EpilogueOutputOp0::Params epilogue0;
    typename EpilogueOutputOp1::Params epilogue1;
    int batch_count;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments(): mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {

    }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmUniversalMode mode_,
      GemmCoord problem_size_0_,
      GemmCoord problem_size_1_,
      TensorRef<ElementA const, LayoutA> ref_A0_,
      TensorRef<ElementB const, LayoutB> ref_B0_,
      TensorRef<ElementC const, LayoutC> ref_C0_,
      TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0_,
      TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0_,
      TensorRef<ElementB const, LayoutB> ref_B1_,
      TensorRef<ElementC const, LayoutC> ref_C1_,
      TensorRef<ElementC, LayoutC> ref_D1_,
      int64_t batch_stride_A0_,
      int64_t batch_stride_B0_,
      int64_t batch_stride_B1_,
      int64_t batch_stride_C1_,
      int64_t batch_stride_D1_,
      int64_t batch_stride_Bias0_,
      int64_t batch_stride_Scale0_,
      typename EpilogueOutputOp0::Params epilogue0_ =
        typename EpilogueOutputOp0::Params(),
      typename EpilogueOutputOp1::Params epilogue1_ =
        typename EpilogueOutputOp1::Params(),
      int batch_count_ = 1
    ):
      mode(mode_),
      problem_size_0(problem_size_0_),
      problem_size_1(problem_size_1_),
      ref_A0(ref_A0_),
      ref_B0(ref_B0_),
      ref_C0(ref_C0_),
      ref_Scale0(ref_Scale0_),
      ref_Bias0(ref_Bias0_),
      ref_B1(ref_B1_),
      ref_C1(ref_C1_),
      ref_D1(ref_D1_),
      batch_stride_A0(batch_stride_A0_),
      batch_stride_B0(batch_stride_B0_),
      batch_stride_B1(batch_stride_B1_),
      batch_stride_C1(batch_stride_C1_),
      batch_stride_D1(batch_stride_D1_),
      batch_stride_Bias0(batch_stride_Bias0_),
      batch_stride_Scale0(batch_stride_Scale0_),
      epilogue0(epilogue0_),
      epilogue1(epilogue1_),
      batch_count(batch_count_) {

    }
  };
  using Arguments = typename B2bGemmKernel::Arguments;

private:

@@ -0,0 +1,297 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Example of running grouped back-to-back GEMMs when intermediate results are RF resident
*/

#include <iostream>
#include <vector>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/base_grouped.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "device/b2b_gemm.h"
#include "kernel/default_b2b_gemm.h"
#include "threadblock/grouped_threadblock_swizzle.h"
#include "b2b_grouped_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_0;
std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_1;

// Constraints:
// 1. Warp shape N must equal thread block shape N
// 2. Problem size N must equal thread block shape N
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;

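To make these two constraints concrete with the shapes just defined: every GEMM0 in the group must have N = ThreadblockShape0::kN = 64, and every GEMM1 must have N = ThreadblockShape1::kN = 128 with K equal to GEMM0's N, while M and GEMM0's K remain free (and are chosen at random by `randomize_problems()` below). A hypothetical helper, not part of this commit, that builds one conforming problem pair from the free parameters might look like this:

```cpp
#include <utility>

// Illustrative helper (assumption, not part of this commit): returns a
// {GEMM0, GEMM1} pair that satisfies the constraints for the shapes above.
std::pair<cutlass::gemm::GemmCoord, cutlass::gemm::GemmCoord>
make_b2b_problem_pair(int m, int k0) {
  // GEMM0 is (m x 64 x k0): N is pinned to ThreadblockShape0::kN.
  cutlass::gemm::GemmCoord problem0(m, ThreadblockShape0::kN, k0);
  // GEMM1 is (m x 128 x 64): N is pinned to ThreadblockShape1::kN and
  // K must equal GEMM0's N so the fused output feeds straight into GEMM1.
  cutlass::gemm::GemmCoord problem1(m, ThreadblockShape1::kN, ThreadblockShape0::kN);
  return std::make_pair(problem0, problem1);
}
```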
// Command line options parsing
struct Options {

  bool help;
  bool error;
  bool reference_check;
  int alignment = 8;

  std::vector<cutlass::gemm::GemmCoord> problem_sizes0;
  std::vector<cutlass::gemm::GemmCoord> problem_sizes1;

  int problem_count;
  bool verbose;

  //
  // Methods
  //

  Options():
    help(false),
    error(false),
    reference_check(true),
    problem_count(15),
    verbose(false)
  { }

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
      return;
    }

    cmd.get_cmd_line_argument("problems", problem_count, 15);
    cmd.get_cmd_line_argument("reference-check", reference_check, true);
    cmd.get_cmd_line_argument("verbose", verbose, false);

    randomize_problems(cmd);
  }

  void randomize_problems(cutlass::CommandLine &cmd) {

    //
    // For now, randomly choose the problem sizes.
    //

    int cmd_line_m = -1;
    int cmd_line_k = -1;

    cmd.get_cmd_line_argument("m", cmd_line_m);
    cmd.get_cmd_line_argument("k", cmd_line_k);

    problem_sizes0.reserve(problem_count);
    problem_sizes1.reserve(problem_count);

    for (int i = 0; i < problem_count; ++i) {

      int m = cmd_line_m;
      int k = cmd_line_k;

      if (m < 1) {
        m = alignment * ((rand() % 256) + 1);
      }

      if (k < 1) {
        k = alignment * ((rand() % 256) + 1);
      }

      cutlass::gemm::GemmCoord problem0(m, ThreadblockShape0::kN, k);
      cutlass::gemm::GemmCoord problem1(m, ThreadblockShape1::kN, ThreadblockShape0::kN);

      problem_sizes0.push_back(problem0);
      problem_sizes1.push_back(problem1);
    }

    if (verbose) {
      print_problem_sizes();
    }
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n"
      << "  This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs is\n"
      << "  run in a single kernel. Each individual problem in the group is subject to the same constraints that non-grouped\n"
      << "  back-to-back GEMMs are subject to.\n\n"
      << "Options:\n\n"
      << "  --help                      If specified, displays this usage statement.\n\n"
      << "  --problems=<int>            Number of individual GEMM problems (default: --problems=15)\n"
      << "  --m=<int>                   Sets the M dimension of both GEMMs for all groups. Otherwise, it is selected randomly\n"
      << "  --k=<int>                   Sets the K dimension of the first GEMM for all groups. Otherwise, it is selected randomly\n"
      << "  --verbose=<bool>            If true, prints problem sizes.\n";

    out << "\n\nExamples:\n\n"

      << "# Runs a grouped B2b GEMM with 10 random problem sizes\n"
      << "$ ./examples/13_two_tensor_op_fusion/13_fused_two_gemms_grouped_f16_sm80_rf --problems=10\n\n";

    return out;
  }

  void print_problem_sizes() {
    std::cout << std::endl;
    std::cout << "Executing " << problem_count << " independent back-to-back GEMMs in a group" << std::endl;
    for (int i = 0; i < problem_count; ++i) {
      cutlass::gemm::GemmCoord problem0 = problem_sizes0.at(i);
      cutlass::gemm::GemmCoord problem1 = problem_sizes1.at(i);
      std::cout << "Problem " << i
                << "\t\tGEMM0: " << problem0.m() << 'x' << problem0.n() << 'x' << problem0.k()
                << "\t\tGEMM1: " << problem1.m() << 'x' << problem1.n() << 'x' << problem1.k()
                << std::endl;
    }
  }
};

bool run_fused_grouped_gemm_f16_sm80_rf_res() {

  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(1);
  // Fused kernel has built-in bias, setting beta=0
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); // beta=1 for bias

  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  using EpilogueOutputOp0 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput,
      InstructionShape::kM * InstructionShape::kN / 32,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput,
      128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >;

  using GroupedThreadblockSwizzle = cutlass::gemm::threadblock::B2bGemmGroupedThreadblockSwizzle<
    ThreadblockShape0,
    cutlass::layout::RowMajor // LayoutC
  >;

  const int kAlignment = 128 / cutlass::sizeof_bits<ElementOutput>::value;
  const int kStages = 3;
  using B2bGemmKernel = cutlass::gemm::kernel::DefaultB2bGemm<
    cutlass::half_t,
    cutlass::layout::RowMajor,
    kAlignment,
    cutlass::half_t,
    cutlass::layout::ColumnMajor,
    kAlignment,
    cutlass::half_t,
    cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape0,
    ThreadblockShape1,
    WarpShape0,
    WarpShape1,
    InstructionShape,
    EpilogueOutputOp0,
    EpilogueOutputOp1,
    GroupedThreadblockSwizzle,
    kStages,
    cutlass::arch::OpMultiplyAdd
  >::B2bGemmKernel;

  using B2bGemm = cutlass::gemm::device::BaseGrouped<B2bGemmKernel>;

  B2bFusedGroupedGemmRun<B2bGemm> fusedGemm;

  std::cout << "Running Fused back-to-back FP16 TN Grouped GEMMs with RF residency...\n";
  bool passed = fusedGemm.run(gemm_f16_sm80_problem_sizes_0, gemm_f16_sm80_problem_sizes_1, alpha0, beta0, alpha1, beta1);
  if (passed)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return passed;
}

int main(int argc, char const **args) {

  //
  // Parse options
  //

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  if (options.error) {
    std::cerr << "Aborting execution." << std::endl;
    return -1;
  }

  gemm_f16_sm80_problem_sizes_0 = options.problem_sizes0;
  gemm_f16_sm80_problem_sizes_1 = options.problem_sizes1;

  std::vector<bool (*)()> funcs = {
    &run_fused_grouped_gemm_f16_sm80_rf_res
  };

  return testRun(80, funcs, "grouped gemm f16 RF residency");
}

////////////////////////////////////////////////////////////////////////////////
@@ -40,12 +40,60 @@
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"

#include "kernel/b2b_gemm_grouped_problem_visitor.h"
#include "threadblock/grouped_threadblock_swizzle.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

namespace detail {

/// Utility struct for returning the type of the problem visitor used by the swizzling function,
/// if it is a grouped swizzling function, or a default visitor. This is used only for defining
/// the parameters of the problem visitor used in GroupedParams.
template <
  typename B2bMma_,
  typename ThreadblockSwizzle_,
  typename Enable = void
>
struct ProblemVisitorOrDefault;

/// Return a generic problem visitor for GEMM problems
template <
  typename B2bMma_,
  typename ThreadblockSwizzle_
>
struct ProblemVisitorOrDefault<B2bMma_,
                               ThreadblockSwizzle_,
                               typename platform::enable_if<
                                 ! cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
                               >::type> {
  using value = B2bGemmGroupedProblemVisitor<typename B2bMma_::Shape,
                                             GroupScheduleMode::kDeviceOnly,
                                             128,
                                             128,
                                             platform::is_same<typename B2bMma_::LayoutC,
                                                               cutlass::layout::ColumnMajor>::value>;
};

/// Return the problem visitor specified by the swizzling function
template <
  typename B2bMma_,
  typename ThreadblockSwizzle_
>
struct ProblemVisitorOrDefault<B2bMma_,
                               ThreadblockSwizzle_,
                               typename platform::enable_if<
                                 cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
                               >::type> {
  using value = typename ThreadblockSwizzle_::ProblemVisitor;
};

} // namespace detail

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
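The `ProblemVisitorOrDefault` helper above is a standard `enable_if`-based partial-specialization dispatch: the `IsGroupedSwizzle` trait decides which specialization is viable, and each specialization exposes the chosen type as `value`. A minimal, self-contained illustration of the same pattern (using `std::enable_if` instead of CUTLASS's `platform::enable_if`, with placeholder types that are not part of this commit) looks like this:

```cpp
#include <type_traits>

// Stand-ins for the real swizzle types and the IsGroupedSwizzle trait.
struct RegularSwizzle { };
struct GroupedSwizzle { using ProblemVisitor = int; /* placeholder visitor type */ };

template <typename T> struct IsGroupedSwizzle                 : std::false_type { };
template <>           struct IsGroupedSwizzle<GroupedSwizzle> : std::true_type  { };

struct DefaultVisitor { };

// Primary template: which specialization applies depends on the trait.
template <typename Swizzle, typename Enable = void>
struct ProblemVisitorOrDefault;

// Non-grouped swizzle: fall back to a default visitor.
template <typename Swizzle>
struct ProblemVisitorOrDefault<Swizzle,
    typename std::enable_if<!IsGroupedSwizzle<Swizzle>::value>::type> {
  using value = DefaultVisitor;
};

// Grouped swizzle: use the visitor the swizzle itself provides.
template <typename Swizzle>
struct ProblemVisitorOrDefault<Swizzle,
    typename std::enable_if<IsGroupedSwizzle<Swizzle>::value>::type> {
  using value = typename Swizzle::ProblemVisitor;
};

static_assert(std::is_same<ProblemVisitorOrDefault<GroupedSwizzle>::value, int>::value,
              "grouped swizzle supplies its own visitor");
static_assert(std::is_same<ProblemVisitorOrDefault<RegularSwizzle>::value, DefaultVisitor>::value,
              "non-grouped swizzle falls back to the default visitor");
```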
@@ -72,10 +120,169 @@ struct B2bGemm {

  using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;

  /// Data types needed for higher-level containers. In some cases, a single type must be exposed
  /// despite the B2b GEMM using two GEMMs under the hood. In such cases, we select the values from
  /// the second GEMM (other than for ElementA/ElementB)
  using ElementA = typename B2bMma::IteratorA0::Element;
  using LayoutA = typename B2bMma::IteratorA0::Layout;
  using ElementB = typename B2bMma::IteratorB0::Element;
  using LayoutB = typename B2bMma::IteratorB0::Layout;

  static ComplexTransform const kTransformA = B2bMma::kTransformA;
  static ComplexTransform const kTransformB = B2bMma::kTransformB;
  using Operator = typename B2bMma::Operator0;

  using OperatorClass = typename Operator::OperatorClass;
  using ThreadblockShape = typename B2bMma::Shape0;
  using WarpShape = typename Operator::Shape;
  using InstructionShape = typename Operator::InstructionShape;
  using ArchTag = typename B2bMma::ArchTag;

  static int const kStages = B2bMma::kStages;
  static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements;
  static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements;
  static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

  using Mma = B2bMma;
  using EpilogueOutputOp = OutputOp1;

  /// Warp count (concept: GemmShape)
  using WarpCount0 = typename B2bMma::WarpCount0;
  static int const kThreadCount = 32 * WarpCount0::kCount;

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmUniversalMode mode;
    GemmCoord problem_size_0;
    GemmCoord problem_size_1;
    typename B2bMma::IteratorA0::TensorRef ref_A0;
    typename B2bMma::IteratorB0::TensorRef ref_B0;
    typename Epilogue::OutputTileIterator::TensorRef ref_C0;
    typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
    typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
    typename B2bMma::IteratorB1::TensorRef ref_B1;
    typename Epilogue::OutputTileIterator::TensorRef ref_C1;
    typename Epilogue::OutputTileIterator::TensorRef ref_D1;
    int64_t batch_stride_A0;
    int64_t batch_stride_B0;
    int64_t batch_stride_B1;
    int64_t batch_stride_C1;
    int64_t batch_stride_D1;
    int64_t batch_stride_Bias0;
    int64_t batch_stride_Scale0;
    typename OutputOp0::Params epilogue0;
    typename OutputOp1::Params epilogue1;
    int batch_count;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments() : mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {}

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmUniversalMode mode_,
      GemmCoord problem_size_0_,
      GemmCoord problem_size_1_,
      typename B2bMma::IteratorA0::TensorRef ref_A0_,
      typename B2bMma::IteratorB0::TensorRef ref_B0_,
      typename Epilogue::OutputTileIterator::TensorRef ref_C0_,
      typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_,
      typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_,
      typename B2bMma::IteratorB1::TensorRef ref_B1_,
      typename Epilogue::OutputTileIterator::TensorRef ref_C1_,
      typename Epilogue::OutputTileIterator::TensorRef ref_D1_,
      int64_t batch_stride_A0_,
      int64_t batch_stride_B0_,
      int64_t batch_stride_B1_,
      int64_t batch_stride_C1_,
      int64_t batch_stride_D1_,
      int64_t batch_stride_Bias0_,
      int64_t batch_stride_Scale0_,
      typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
      typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
      int batch_count_ = 1
    ):
      mode(mode_),
      problem_size_0(problem_size_0_),
      problem_size_1(problem_size_1_),
      ref_A0(ref_A0_),
      ref_B0(ref_B0_),
      ref_C0(ref_C0_),
      ref_Scale0(ref_Scale0_),
      ref_Bias0(ref_Bias0_),
      ref_B1(ref_B1_),
      ref_C1(ref_C1_),
      ref_D1(ref_D1_),
      batch_stride_A0(batch_stride_A0_),
      batch_stride_B0(batch_stride_B0_),
      batch_stride_B1(batch_stride_B1_),
      batch_stride_C1(batch_stride_C1_),
      batch_stride_D1(batch_stride_D1_),
      batch_stride_Bias0(batch_stride_Bias0_),
      batch_stride_Scale0(batch_stride_Scale0_),
      epilogue0(epilogue0_),
      epilogue1(epilogue1_),
      batch_count(batch_count_) {
    }
  };

  // Arguments structure for grouped B2B problems
  struct GroupedArguments {
    GemmCoord* problem_size_0;
    GemmCoord* problem_size_1;
    typename B2bMma::IteratorA0::TensorRef* ref_A0;
    typename B2bMma::IteratorB0::TensorRef* ref_B0;
    typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
    typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
    typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
    typename B2bMma::IteratorB1::TensorRef* ref_B1;
    typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
    typename Epilogue::OutputTileIterator::TensorRef* ref_D1;

    // Epilogue params remain constant across all problems in the group. Thus,
    // the parameter here is not a pointer.
    typename OutputOp0::Params epilogue0;
    typename OutputOp1::Params epilogue1;

    int problem_count;
    int threadblock_count;
    GemmCoord* host_problem_sizes;

    CUTLASS_HOST_DEVICE
    GroupedArguments(
      int problem_count,
      GemmCoord* problem_size_0_,
      GemmCoord* problem_size_1_,
      typename B2bMma::IteratorA0::TensorRef* ref_A0_,
      typename B2bMma::IteratorB0::TensorRef* ref_B0_,
      typename Epilogue::OutputTileIterator::TensorRef* ref_C0_,
      typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_,
      typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_,
      typename B2bMma::IteratorB1::TensorRef* ref_B1_,
      typename Epilogue::OutputTileIterator::TensorRef* ref_C1_,
      typename Epilogue::OutputTileIterator::TensorRef* ref_D1_,
      typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
      typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
      int threadblock_count = 0
    ) : problem_size_0(problem_size_0_), problem_size_1(problem_size_1_),
        ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_),
        ref_Scale0(ref_Scale0_), ref_Bias0(ref_Bias0_), ref_B1(ref_B1_),
        ref_C1(ref_C1_), ref_D1(ref_D1_), epilogue0(epilogue0_), epilogue1(epilogue1_),
        problem_count(problem_count),
        threadblock_count(threadblock_count)
    {}
  };

  /// Parameters structure
  struct Params {
    cutlass::gemm::GemmUniversalMode mode;
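Unlike `Arguments`, every tensor reference in `GroupedArguments` is a pointer to an array with one `TensorRef` per problem, and those arrays are expected to live in device memory. A hedged host-side sketch of how such arrays can be staged (using `cutlass::DeviceAllocation::copy_from_host`, exactly as the new runner in this commit does for the device-level grouped arguments) is shown below; the `Ref*` template parameters and the helper name are placeholders, not part of this commit.

```cpp
#include <vector>
#include "cutlass/util/device_memory.h"   // cutlass::DeviceAllocation

// Illustrative only: stage per-problem TensorRef arrays in device memory.
// GroupedArguments then stores the raw device pointers returned by get().
template <typename RefA, typename RefB, typename RefD>
void stage_grouped_refs(std::vector<RefA> const &host_ref_A0,
                        std::vector<RefB> const &host_ref_B0,
                        std::vector<RefD> const &host_ref_D1,
                        cutlass::DeviceAllocation<RefA> &device_ref_A0,
                        cutlass::DeviceAllocation<RefB> &device_ref_B0,
                        cutlass::DeviceAllocation<RefD> &device_ref_D1) {
  // Each vector holds one reference per problem in the group.
  device_ref_A0.copy_from_host(host_ref_A0.data());
  device_ref_B0.copy_from_host(host_ref_B0.data());
  device_ref_D1.copy_from_host(host_ref_D1.data());
}
```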
@@ -149,7 +356,7 @@ struct B2bGemm {
       problem_size_0(problem_size_0),
       problem_size_1(problem_size_1),
       grid_tiled_shape(grid_tiled_shape),
-      swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
+      swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)),
       params_A0(ref_A0.layout()),
       ref_A0(ref_A0),
       params_B0(ref_B0.layout()),
@@ -185,6 +392,81 @@ struct B2bGemm {
    }
  };

  struct GroupedParams {
    cutlass::gemm::GemmCoord* problem_size_0;
    cutlass::gemm::GemmCoord* problem_size_1;
    cutlass::gemm::GemmCoord* grid_tiled_shape;
    typename B2bMma::IteratorA0::TensorRef* ref_A0;
    typename B2bMma::IteratorB0::TensorRef* ref_B0;
    typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
    typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
    typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
    typename B2bMma::IteratorB1::TensorRef* ref_B1;
    typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
    typename Epilogue::OutputTileIterator::TensorRef* ref_D1;

    // Epilogue params remain constant across all problems in the group. Thus,
    // the parameter here is not a pointer.
    typename OutputOp0::Params output_op_0;
    typename OutputOp1::Params output_op_1;

    using ProblemVisitor = typename detail::ProblemVisitorOrDefault<B2bMma, ThreadblockSwizzle>::value;
    typename ProblemVisitor::Params problem_visitor;
    int threadblock_count;
    int* workspace;

    CUTLASS_HOST_DEVICE
    GroupedParams() {}

    CUTLASS_HOST_DEVICE
    GroupedParams(
      GroupedArguments const &args,
      void *workspace = nullptr,
      int tile_count = 0
    ) :
      problem_size_0(args.problem_size_0), problem_size_1(args.problem_size_1),
      ref_A0(args.ref_A0), ref_B0(args.ref_B0), ref_C0(args.ref_C0),
      ref_Scale0(args.ref_Scale0), ref_Bias0(args.ref_Bias0), ref_B1(args.ref_B1), ref_C1(args.ref_C1), ref_D1(args.ref_D1),
      output_op_0(args.epilogue0), output_op_1(args.epilogue1),
      problem_visitor(args.problem_size_0, args.problem_size_1, args.problem_count, workspace, tile_count),
      threadblock_count(args.threadblock_count),
      workspace(reinterpret_cast<int*>(workspace)) {}

    CUTLASS_HOST_DEVICE
    void transpose() {
      // Only row-major outputs are currently supported, so no transpose is performed
    }

    /// Returns non-grouped parameters to be used as input to the kernel-level
    /// operator for the problem indicated by problem_visitor.
    CUTLASS_HOST_DEVICE
    Params to_single_params(const ProblemVisitor& problem_visitor) const {
      GemmCoord problem_size0 = problem_visitor.problem_size0();
      GemmCoord problem_size1 = problem_visitor.problem_size1();
      int32_t idx = problem_visitor.problem_index();
      GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1);

      return Params(
        cutlass::gemm::GemmUniversalMode::kGemm,
        problem_size0,
        problem_size1,
        grid_shape,
        ref_A0[idx],
        ref_B0[idx],
        ref_C0[idx],
        ref_Scale0[idx],
        ref_Bias0[idx],
        ref_B1[idx],
        ref_C1[idx],
        ref_D1[idx],
        0, 0, 0, 0, 0, 0, 0, // Batched B2B GEMMs within the grouped kernel are currently unsupported
        output_op_0,
        output_op_1,
        workspace
      );
    }
  };

  /// Shared memory storage structure
  union SharedStorage {
    typename B2bMma::B2bMmaSharedStorage main_loop;
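`GroupedParams`, the problem visitor, and `to_single_params` together implement a persistent-threadblock schedule: a fixed pool of thread blocks repeatedly asks the visitor for the next tile, rebuilds single-problem `Params` for the problem that tile belongs to, and runs the ordinary B2b GEMM body on it. The device-side loop lives in `kernel/grouped.h`, which this diff does not show, so the host-side sketch below is only a conceptual illustration of that schedule and not the actual kernel.

```cpp
#include <cstdint>
#include <vector>

// Conceptual sketch (not the real GroupedKernel): distribute all tiles of all
// problems round-robin over a fixed pool of persistent "thread blocks".
struct TileWork { int32_t problem_idx; int32_t tile_idx; };

std::vector<std::vector<TileWork>> schedule_tiles(
    std::vector<int32_t> const &tiles_per_problem, int threadblock_count) {
  std::vector<std::vector<TileWork>> per_block(threadblock_count);
  int64_t global_tile = 0;
  for (int32_t p = 0; p < int32_t(tiles_per_problem.size()); ++p) {
    for (int32_t t = 0; t < tiles_per_problem[p]; ++t, ++global_tile) {
      // Each persistent thread block processes every threadblock_count-th tile,
      // mirroring how the problem visitor advances by the launched grid size.
      per_block[global_tile % threadblock_count].push_back(TileWork{p, t});
    }
  }
  return per_block;
}
// On the device, each scheduled entry corresponds to: build Params via
// GroupedParams::to_single_params(visitor) and invoke run_with_swizzle().
```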
@@ -266,9 +548,13 @@ struct B2bGemm {
  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;
    run_with_swizzle(params, shared_storage, threadblock_swizzle);
  }

  /// Executes one GEMM with an externally-provided swizzling function
  CUTLASS_DEVICE
  void run_with_swizzle(Params const &params, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) {

    cutlass::gemm::GemmCoord threadblock_tile_offset =
        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
@@ -391,14 +677,17 @@ struct B2bGemm {
      )
    );


    //
    // Main loop
    //

    OutputOp0 output_op_0(params.output_op_0);

    if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle>::value) {
      // Wait for all threads to finish their epilogue phases from the previous tile.
      __syncthreads();
    }

    // Construct thread-scoped matrix multiply
    B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n());

@@ -0,0 +1,157 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Scheduler for grouped B2b GEMMs
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape,
          GroupScheduleMode GroupScheduleMode_,
          int PrefetchTileCount,
          int ThreadCount,
          bool Transposed = false>
struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor<
                                        detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
                                        ThreadblockShape,
                                        GroupScheduleMode_,
                                        PrefetchTileCount,
                                        ThreadCount> {

  using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
  using Base = GroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
  using BaseParams = typename Base::Params;
  using SharedStorage = typename Base::SharedStorage;
  static bool const kTransposed = Transposed;

  cutlass::gemm::GemmCoord const *problem_sizes0;
  cutlass::gemm::GemmCoord const *problem_sizes1;

  struct Params {
    cutlass::gemm::GemmCoord const *problem_sizes0;
    cutlass::gemm::GemmCoord const *problem_sizes1;
    int32_t problem_count;
    void const *workspace;
    int32_t tile_count;

    //
    // Methods
    //

    /// Ctor
    CUTLASS_HOST_DEVICE
    Params(): problem_sizes0(nullptr), problem_sizes1(nullptr),
              problem_count(0), workspace(nullptr), tile_count(0) { }

    /// Ctor
    CUTLASS_HOST_DEVICE
    Params(
      cutlass::gemm::GemmCoord const *problem_sizes0,
      cutlass::gemm::GemmCoord const *problem_sizes1,
      int32_t problem_count,
      void const *workspace = nullptr,
      int32_t tile_count = 0
    ):
      problem_sizes0(problem_sizes0),
      problem_sizes1(problem_sizes1),
      problem_count(problem_count),
      workspace(workspace),
      tile_count(tile_count)
    {}

    /// Convert the B2b-GEMM-specific parameters to those used by the base class
    CUTLASS_HOST_DEVICE
    BaseParams to_base() const {
      return BaseParams(// Set problem_sizes as problem_sizes0 because these determine
                        // shape of the grid used in the non-grouped B2b GEMM
                        problem_sizes0,
                        problem_count,
                        workspace,
                        tile_count);
    }

  };

  //
  // Methods
  //
  CUTLASS_DEVICE
  B2bGemmGroupedProblemVisitor(
    Params const &params_,
    SharedStorage &shared_storage_,
    int32_t block_idx
  ): Base (
       params_.to_base(),
       shared_storage_, block_idx),
     problem_sizes0(params_.problem_sizes0),
     problem_sizes1(params_.problem_sizes1)
  {}

  /// Returns the problem size 0 for the current problem
  CUTLASS_HOST_DEVICE
  cutlass::gemm::GemmCoord problem_size0() const {
    GemmCoord problem = problem_sizes0[this->problem_idx];
    ProblemSizeHelper::possibly_transpose_problem(problem);
    return problem;
  }

  /// Returns the problem size 1 for the current problem
  CUTLASS_HOST_DEVICE
  cutlass::gemm::GemmCoord problem_size1() const {
    GemmCoord problem = problem_sizes1[this->problem_idx];
    ProblemSizeHelper::possibly_transpose_problem(problem);
    return problem;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -63,7 +63,9 @@
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"

#include "kernel/b2b_gemm.h"
#include "kernel/grouped.h"
#include "threadblock/default_b2b_mma.h"
#include "threadblock/grouped_threadblock_swizzle.h"

////////////////////////////////////////////////////////////////////////////////

@@ -73,6 +75,9 @@ namespace kernel {

////////////////////////////////////////////////////////////////////////////////

template <typename T>
using IsGroupedSwizzle = cutlass::gemm::threadblock::detail::IsGroupedSwizzle<T>;

template <
  /// Element type for A matrix operand
  typename ElementA_,
@@ -117,7 +122,9 @@ template <
   /// Operation performed by GEMM
   typename Operator,
   /// Stage accumulator in shared memory
-  bool SmemAccumulator = false
+  bool SmemAccumulator = false,
+  /// Whether or not the operation is grouped
+  typename Enable = void
 >
 struct DefaultB2bGemm;

@@ -166,7 +173,7 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
     arch::Sm80, ThreadblockShape0, ThreadblockShape1,
     WarpShape0, WarpShape1, InstructionShape,
     EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
-    Operator> {
+    Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
   /// Define the threadblock-scoped matrix multiply-accumulate
   using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
     ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
@ -186,6 +193,71 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
/// Partial specialization for Ampere Architecture with grouped operation
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
||||
Operator, false, typename platform::enable_if<IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue =
|
||||
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using UnderlyingB2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
|
||||
using B2bGemmKernel = kernel::GroupedKernel<UnderlyingB2bGemmKernel>;
|
||||
};
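
The two Sm80 specializations above are disambiguated solely by the trailing enable_if on IsGroupedSwizzle. A standalone sketch of that dispatch pattern follows, using placeholder type names rather than the CUTLASS types; it is illustrative only.

// Illustration of the enable_if dispatch used above (placeholder names, not CUTLASS types).
#include <type_traits>

struct GroupedSwizzleBase {};                 // analogous to detail::GroupedThreadblockSwizzleBase
struct RegularSwizzle {};
struct GroupedSwizzle : GroupedSwizzleBase {};

template <typename Swizzle, typename Enable = void>
struct SelectKernel {
  static constexpr bool kGrouped = false;     // non-grouped B2b GEMM path
};

template <typename Swizzle>
struct SelectKernel<Swizzle,
    typename std::enable_if<std::is_base_of<GroupedSwizzleBase, Swizzle>::value>::type> {
  static constexpr bool kGrouped = true;      // grouped path, analogous to kernel::GroupedKernel
};

static_assert(!SelectKernel<RegularSwizzle>::kGrouped, "regular swizzle picks the non-grouped path");
static_assert(SelectKernel<GroupedSwizzle>::kGrouped, "grouped swizzle picks the grouped path");
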


////////////////////////////////////////////////////////////////////////////////

@ -242,7 +314,9 @@ struct DefaultB2bGemm<
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  2,
  Operator
  Operator,
  false,
  typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type
> {

  /// Define the threadblock-scoped matrix multiply-accumulate
@ -324,7 +398,8 @@ struct DefaultB2bGemm<
    arch::OpClassTensorOp, arch::Sm80,
    ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
    InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
    ThreadblockSwizzle, Stages, Operator> {
    ThreadblockSwizzle, Stages,
    Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
  using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
  using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
  using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -393,7 +468,8 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
    int32_t, arch::OpClassTensorOp, arch::Sm75,
    ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
    InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
    ThreadblockSwizzle, 2, Operator> {
    ThreadblockSwizzle, 2, Operator, false,
    typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
  using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
  using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
  using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;

168
examples/13_two_tensor_op_fusion/kernel/grouped.h
Normal file
@ -0,0 +1,168 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard CUTLASS BSD-3-Clause license header, identical to the one above.)
 **************************************************************************************************/

/*! \file
  \brief High-level interface for running a grouped version of a CUTLASS kernel
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// High-level interface for running a grouped version of a CUTLASS kernel
template <
  typename BaseKernel_   ///! Kernel-scoped matrix multiply-accumulate
>
struct GroupedKernel {
public:

  using BaseKernel = BaseKernel_;
  using Epilogue = typename BaseKernel::Epilogue;

  /// Types that need to be exported to work properly with device::BaseGrouped
  using ElementA = typename BaseKernel::ElementA;
  using LayoutA = typename BaseKernel::LayoutA;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  static ComplexTransform const kTransformA = BaseKernel::kTransformA;
  static int const kAlignmentA = BaseKernel::kAlignmentA;

  using ElementB = typename BaseKernel::ElementB;
  using LayoutB = typename BaseKernel::LayoutB;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  static ComplexTransform const kTransformB = BaseKernel::kTransformB;
  static int const kAlignmentB = BaseKernel::kAlignmentB;

  using ElementC = typename BaseKernel::ElementC;
  using LayoutC = typename BaseKernel::LayoutC;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  static int const kAlignmentC = BaseKernel::kAlignmentC;

  using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;

  using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
  using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle;

  using Operator = typename BaseKernel::Operator;
  using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;

  using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
  using MathOperator = typename WarpMmaOperator::MathOperator;
  using OperatorClass = typename WarpMmaOperator::OperatorClass;
  using ArchTag = typename WarpMmaOperator::ArchTag;
  using ThreadblockShape = typename BaseKernel::Mma::Shape;
  using WarpShape = typename BaseKernel::WarpShape;
  using InstructionShape = typename BaseKernel::InstructionShape;
  static int const kStages = BaseKernel::Mma::kStages;

  using Mma = typename BaseKernel::Mma;

  using Arguments = typename BaseKernel::GroupedArguments;
  using Params = typename BaseKernel::GroupedParams;
  using ProblemVisitor = typename ThreadblockSwizzle::ProblemVisitor;

  static int const kThreadCount = BaseKernel::kThreadCount;

  /// Shared memory storage structure
  struct SharedStorage {
    typename BaseKernel::SharedStorage kernel;

    // ProblemVisitor shared storage can't be overlapped with others
    typename ProblemVisitor::SharedStorage problem_visitor;
  };

public:

  //
  // Methods
  //

  CUTLASS_DEVICE
  GroupedKernel() { }

  /// Determines whether kernel satisfies alignment
  static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
    return Status::kSuccess;
  }

  static Status can_implement(Arguments const &args) {
    return Status::kSuccess;
  }

  /// Executes a kernel-level GEMM in a loop
  CUTLASS_DEVICE
  void operator()(Params &params, SharedStorage &shared_storage) {

    ThreadblockSwizzle swizzle(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);

    if (ProblemVisitor::kTransposed) {
      params.transpose();
    }

    BaseKernel mma;

    // Outer 'persistent' loop to iterate over tiles
    while (swizzle.problem_visitor.next_tile()) {

      typename BaseKernel::Params mma_params = params.to_single_params(swizzle.problem_visitor);
      mma.run_with_swizzle(mma_params, shared_storage.kernel, swizzle);

      // Next tile
      swizzle.problem_visitor.advance(gridDim.x);
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
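
As a rough illustration of the scheduling loop in GroupedKernel::operator() above, the following host-only sketch (not part of the commit; the tile counts and grid size are made up) mimics how each threadblock starts at its blockIdx.x-th tile and strides by gridDim.x until every tile of every problem in the group has been visited.

// Host-side analogue of the persistent tile-scheduling loop (illustration only).
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> tiles_per_problem = {6, 3, 9};   // hypothetical grid sizes per problem
  int total_tiles = 0;
  for (int t : tiles_per_problem) total_tiles += t;

  int grid_dim = 4;                                 // stand-in for gridDim.x
  for (int block_idx = 0; block_idx < grid_dim; ++block_idx) {
    // Equivalent of: while (problem_visitor.next_tile()) { ...; advance(gridDim.x); }
    for (int tile = block_idx; tile < total_tiles; tile += grid_dim) {
      // Map the flat tile index back to (problem, tile-within-problem).
      int problem = 0, offset = tile;
      while (offset >= tiles_per_problem[problem]) {
        offset -= tiles_per_problem[problem];
        ++problem;
      }
      std::printf("block %d -> problem %d, tile %d\n", block_idx, problem, offset);
    }
  }
  return 0;
}
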
@ -119,8 +119,10 @@ public:
  using Shape0 = Shape0_;
  ///< Iterates over tiles of A operand in global memory
  using IteratorA0 = IteratorA0_;
  using IteratorA = IteratorA0;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB0 = IteratorB0_;
  using IteratorB = IteratorB0;
  ///< Policy describing tuning details
  using Policy0 = Policy0_;

@ -139,6 +141,10 @@ public:
  using IteratorB1 = IteratorB1_;
  ///< Policy describing tuning details
  using Policy1 = Policy1_;

  ///< Export Policy0 as the threadblock-level Mma's policy
  using Policy = Policy0;
  using Shape = Shape0;

  using SmemIteratorB1 = SmemIteratorB1_;

@ -188,6 +194,10 @@ public:
  /// Complex transform on B operand
  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  /// Complex transform exports needed by higher-level kernels
  static ComplexTransform const kTransformA = kTransformA0;
  static ComplexTransform const kTransformB = kTransformB0;

  /// Internal structure exposed for introspection.
  struct Detail {


@ -121,8 +121,10 @@ public:
  using Shape0 = Shape0_;
  ///< Iterates over tiles of A operand in global memory
  using IteratorA0 = IteratorA0_;
  using IteratorA = IteratorA0;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB0 = IteratorB0_;
  using IteratorB = IteratorB0;
  ///< Iterates over tiles of the scale and bias vectors in global memory
  using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
  ///< Policy describing tuning details
@ -141,6 +143,10 @@ public:
  ///< Policy describing tuning details
  using Policy1 = Policy1_;

  ///< Export Policy0 as the threadblock-level Mma's policy
  using Policy = Policy0;
  using Shape = Shape0;

  using SmemIteratorB1 = SmemIteratorB1_;
  using WarpIteratorA1 = WarpIteratorA1_;   ///< Iterates over the intermediate accumulator tile in shared memory

@ -194,6 +200,10 @@ public:
  /// Complex transform on B operand
  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  /// Complex transform exports needed by higher-level kernels
  static ComplexTransform const kTransformA = kTransformA0;
  static ComplexTransform const kTransformB = kTransformB0;

  /// Internal structure exposed for introspection.
  struct Detail {


@ -126,7 +126,9 @@ public:

  using Shape0 = Shape0_;             ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA0 = IteratorA0_;     ///< Iterates over tiles of A operand in global memory
  using IteratorA = IteratorA0;
  using IteratorB0 = IteratorB0_;     ///< Iterates over tiles of B operand in global memory
  using IteratorB = IteratorB0;
  using Policy0 = Policy0_;           ///< Policy describing tuning details

  using SmemIteratorA0 = SmemIteratorA0_;
@ -139,6 +141,8 @@ public:
    FragmentIteratorA1ScaleBias_;     ///< WarpIterator to load Scale or Bias vector from the threadblock fragment
  using IteratorB1 = IteratorB1_;     ///< Iterates over tiles of B operand in global memory
  using Policy1 = Policy1_;           ///< Policy describing tuning details
  using Policy = Policy1;             ///< Export Policy1 as the threadblock-level Mma's policy
  using Shape = Shape1;

  using SmemIteratorB1 = SmemIteratorB1_;

@ -195,6 +199,10 @@ public:
  /// Complex transform on B1 operand
  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  /// Complex transform exports needed by higher-level kernels
  static ComplexTransform const kTransformA = kTransformA0;
  static ComplexTransform const kTransformB = kTransformB0;

  /// statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
  static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");


@ -128,7 +128,9 @@ public:

  using Shape0 = Shape0_;             ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA0 = IteratorA0_;     ///< Iterates over tiles of A operand in global memory
  using IteratorA = IteratorA0;
  using IteratorB0 = IteratorB0_;     ///< Iterates over tiles of B operand in global memory
  using IteratorB = IteratorB0;
  using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;   ///< Iterates over tiles of the scale and bias vectors in global memory
  using Policy0 = Policy0_;           ///< Policy0 describing tuning details

@ -141,6 +143,8 @@ public:
  using Shape1 = Shape1_;             ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorB1 = IteratorB1_;     ///< Iterates over tiles of B operand in global memory
  using Policy1 = Policy1_;           ///< Policy1 describing tuning details
  using Policy = Policy1;             ///< Export Policy1 as the threadblock-level Mma's policy
  using Shape = Shape1;

  using SmemIteratorB1 = SmemIteratorB1_;
  using WarpIteratorA1 = WarpIteratorA1_;   ///< Iterates over the intermediate accumulator tile in shared memory
@ -192,6 +196,10 @@ public:
  /// Complex transform on B1 operand
  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  /// Complex transform exports needed by higher-level kernels
  static ComplexTransform const kTransformA = kTransformA0;
  static ComplexTransform const kTransformB = kTransformB0;

  /// statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
  static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");

@ -0,0 +1,153 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * (Standard CUTLASS BSD-3-Clause license header, identical to the one above.)
 **************************************************************************************************/
/*! \file
  \brief Implements several threadblock-swizzling functions for grouped kernels
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "kernel/b2b_gemm_grouped_problem_visitor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

struct GroupedThreadblockSwizzleBase {};

/// Helper for determining if a swizzling function is specialized for grouped operation
template <typename ThreadblockSwizzle>
struct IsGroupedSwizzle {
  static bool const value = cutlass::platform::is_base_of<GroupedThreadblockSwizzleBase, ThreadblockSwizzle>::value;
};

} // namespace detail

/// Swizzling function for grouped kernels
template <typename ProblemVisitor_>
struct GroupedThreadblockSwizzle : detail::GroupedThreadblockSwizzleBase {

  using ProblemVisitor = ProblemVisitor_;
  ProblemVisitor problem_visitor;

  CUTLASS_HOST_DEVICE
  GroupedThreadblockSwizzle(typename ProblemVisitor::Params& params,
                            typename ProblemVisitor::SharedStorage& shared_storage,
                            int block_idx) : problem_visitor(params, shared_storage, block_idx) {}

  /// Obtains the threadblock offset (in units of threadblock-scoped tiles)
  CUTLASS_DEVICE
  GemmCoord get_tile_offset(int /*log_tile*/) const {
    GemmCoord problem_size = problem_visitor.problem_size();
    int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
    GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);

    return GemmCoord(int(threadblock_idx / grid_shape.n()),
                     int(threadblock_idx % grid_shape.n()),
                     0);
  }

  /// Dummy method to satisfy API for threadblock swizzling functions
  CUTLASS_HOST_DEVICE
  static int get_log_tile(GemmCoord /*tiled_shape*/) {
    return 0;
  }
};

template <
  typename ThreadblockShape,
  typename LayoutC,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
  int PrefetchTileCount = 128,
  int ThreadCount = PrefetchTileCount>
struct GemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle<
                                         cutlass::gemm::kernel::GemmGroupedProblemVisitor<
                                           ThreadblockShape,
                                           GroupScheduleMode_,
                                           PrefetchTileCount,
                                           ThreadCount,
                                           platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value
                                         >
                                       > {
  using Base = GroupedThreadblockSwizzle<cutlass::gemm::kernel::GemmGroupedProblemVisitor<
                                           ThreadblockShape,
                                           GroupScheduleMode_,
                                           PrefetchTileCount,
                                           ThreadCount,
                                           platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value>>;

  CUTLASS_HOST_DEVICE
  GemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params,
                                typename Base::ProblemVisitor::SharedStorage& shared_storage,
                                int block_idx) : Base(params, shared_storage, block_idx) {}
};

template <
  typename ThreadblockShape,
  typename LayoutC,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
  int PrefetchTileCount = 128,
  int ThreadCount = PrefetchTileCount>
struct B2bGemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle<
                                            cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<
                                              ThreadblockShape,
                                              GroupScheduleMode_,
                                              PrefetchTileCount,
                                              ThreadCount,
                                              platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value
                                            >
                                          > {
  using Base = GroupedThreadblockSwizzle<cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<
                                           ThreadblockShape,
                                           GroupScheduleMode_,
                                           PrefetchTileCount,
                                           ThreadCount,
                                           platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value>>;

  CUTLASS_HOST_DEVICE
  B2bGemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params,
                                   typename Base::ProblemVisitor::SharedStorage& shared_storage,
                                   int block_idx) : Base(params, shared_storage, block_idx) {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
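
The row-major decomposition performed by GroupedThreadblockSwizzle::get_tile_offset above reduces to a divide/modulo over the problem's grid of threadblock tiles. A tiny standalone sketch (illustrative only, plain ints in place of GemmCoord):

// Flat tile index -> (tile row, tile column) over an M x N grid of threadblock tiles.
#include <cassert>
#include <utility>

std::pair<int, int> tile_offset(int threadblock_idx, int grid_m, int grid_n) {
  assert(threadblock_idx < grid_m * grid_n);
  return { threadblock_idx / grid_n,     // tile row (M dimension)
           threadblock_idx % grid_n };   // tile column (N dimension)
}
// e.g. for a 2 x 3 grid, indices 0..5 map to (0,0) (0,1) (0,2) (1,0) (1,1) (1,2).
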
@ -290,7 +290,6 @@ public:
    int available_sm_count=-1) {
    // Determine the number of blocks that would be launched to fill up a single
    // wave on the GPU with each SM having maximum occupancy.
    cudaDeviceProp properties;
    int device_idx;
    cudaError_t result = cudaGetDevice(&device_idx);
    if (result != cudaSuccess) {

@ -114,7 +114,7 @@ struct GemmIdentityThreadblockSwizzle {

  /// Calculates optimal swizzle width
  CUTLASS_HOST_DEVICE
  int get_log_tile(GemmCoord tiled_shape) const {
  static int get_log_tile(GemmCoord tiled_shape) {
    auto n = tiled_shape.n();
    // Thresholds picked so that it doesn't cause too many no-op CTAs
    if (N >= 8 && n >= 6)
@ -187,7 +187,7 @@ struct GemmHorizontalThreadblockSwizzle {

  /// Calculates optimal swizzle width
  CUTLASS_HOST_DEVICE
  int get_log_tile(GemmCoord tiled_shape) const {
  static int get_log_tile(GemmCoord tiled_shape) {
    return 0;
  }

@ -228,7 +228,7 @@ struct GemmBatchedIdentityThreadblockSwizzle {

  /// Calculates optimal swizzle width
  CUTLASS_HOST_DEVICE
  int get_log_tile(GemmCoord tiled_shape) const {
  static int get_log_tile(GemmCoord tiled_shape) {
    return 0;
  }

@ -284,7 +284,7 @@ struct GemmSplitKIdentityThreadblockSwizzle {

  /// Calculates optimal swizzle width
  CUTLASS_HOST_DEVICE
  int get_log_tile(GemmCoord tiled_shape) const {
  static int get_log_tile(GemmCoord tiled_shape) {
    auto n = tiled_shape.n();
    // Thresholds picked so that it doesn't cause too many no-op CTAs
    if (N >= 8 && n >= 6)
@ -361,7 +361,7 @@ struct GemmSplitKHorizontalThreadblockSwizzle {

  /// Calculates optimal swizzle width
  CUTLASS_HOST_DEVICE
  int get_log_tile(GemmCoord tiled_shape) const {
  static int get_log_tile(GemmCoord tiled_shape) {
    return 0;
  }

@ -412,7 +412,7 @@ struct GemvBatchedStridedThreadblockDefaultSwizzle {

  /// Calculates optimal swizzle width
  CUTLASS_HOST_DEVICE
  int get_log_tile(GemmCoord tiled_shape) const {
  static int get_log_tile(GemmCoord tiled_shape) {
    return 0;
}
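
The final set of hunks changes get_log_tile from a const member function to a static one across the existing swizzles, matching the static get_log_tile stub on GroupedThreadblockSwizzle and, presumably, letting callers query the swizzle width from the type alone. A hedged usage sketch, assuming the standard CUTLASS header and the default template argument of GemmIdentityThreadblockSwizzle:

// Illustrative only: querying the swizzle width without constructing a swizzle object.
#include "cutlass/gemm_coord.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

void example(cutlass::gemm::GemmCoord tiled_shape) {
  int log_tile =
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>::get_log_tile(tiled_shape);
  (void)log_tile;
}
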