Add grouped b2b GEMM (#970)

parent fde824af21
commit 87349d3496
@@ -64,6 +64,7 @@ endforeach()

 foreach(FUSION_GEMM_EXAMPLE
   fused_two_gemms_f16_sm75_rf
   fused_two_gemms_f16_sm75_shmem
+  fused_two_gemms_grouped_f16_sm80_rf
   fused_two_gemms_f16_sm80_rf
   fused_two_gemms_f16_sm80_shmem
   fused_two_gemms_s8_sm75_rf
examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h (new file, 450 lines)
@@ -0,0 +1,450 @@

/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Containers for running grouped back-to-back GEMMs
*/

#pragma once

#include <iostream>
#include <fstream>
#include <sstream>

#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"

#include "reference/device/tensor_scale_bias.h"
#include "helper.h"

#define CHECK_GT(val1, val2) \
    if((val1) <= (val2)) \
        std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
#define CHECK_TRUE(val) \
    if(!(val)) \
        std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";

////////////////////////////////////////////////////////////////////////////////

template <typename B2bGemm_>
struct B2bFusedGroupedGemmRun
{

  using B2bGemm = B2bGemm_;
  using ElementAccumulator = typename B2bGemm::ElementAccumulator;
  using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute;

  /// Initialization
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  cutlass::Distribution::Kind init_Scale;
  cutlass::Distribution::Kind init_Bias;
  uint64_t seed;

  //
  // Methods
  //

  B2bFusedGroupedGemmRun(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_),
    init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
  bool initialize_tensor(
    cutlass::TensorView<Element, Layout> view,
    cutlass::Distribution::Kind dist_kind,
    uint64_t seed) {

    if (dist_kind == cutlass::Distribution::Uniform) {
      cutlass::reference::host::TensorFillRandomUniform(
        view, seed, 2, -2, 0);
    }
    else if (dist_kind == cutlass::Distribution::Identity) {
      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {
      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {
      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity());
    }
    else if (dist_kind == cutlass::Distribution::AllZeros) {
      cutlass::reference::host::TensorFill(view, Element(0));
    }
    else if (dist_kind == cutlass::Distribution::AllOnes) {
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      std::cerr << "Not implemented\n";
      return false;
    }

    return true;
  }

  /// Executes one test
  bool run(
    std::vector<cutlass::gemm::GemmCoord> problem_sizes_0,
    std::vector<cutlass::gemm::GemmCoord> problem_sizes_1,
    ElementCompute alpha0 = ElementCompute(1),
    ElementCompute beta0 = ElementCompute(0),
    ElementCompute alpha1 = ElementCompute(1),
    ElementCompute beta1 = ElementCompute(0),
    bool relu = true,
    int warm_ups = 1,
    int runs = 100) {

    using HostTensorA = cutlass::HostTensor<typename B2bGemm::ElementA, typename B2bGemm::LayoutA>;
    using HostTensorB = cutlass::HostTensor<typename B2bGemm::ElementB, typename B2bGemm::LayoutB>;
    using HostTensorC = cutlass::HostTensor<typename B2bGemm::ElementC, typename B2bGemm::LayoutC>;
    using HostTensorScale = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;
    using HostTensorZ = cutlass::HostTensor<ElementAccumulator, typename B2bGemm::LayoutC>;
    using HostTensorBias = cutlass::HostTensor<ElementCompute, typename B2bGemm::LayoutC>;

    int problem_count = (int)problem_sizes_0.size();

    std::vector<HostTensorA> host_tensor_A0(problem_count);
    std::vector<HostTensorB> host_tensor_B0(problem_count);
    std::vector<HostTensorC> host_tensor_C0(problem_count);
    std::vector<HostTensorScale> host_tensor_Scale0(problem_count);
    std::vector<HostTensorScale> host_tensor_Bias0(problem_count);
    std::vector<HostTensorB> host_tensor_B1(problem_count);
    std::vector<HostTensorC> host_tensor_C1(problem_count);
    std::vector<HostTensorBias> host_tensor_Bias1(problem_count);
    std::vector<HostTensorC> host_tensor_D1(problem_count);
    std::vector<HostTensorZ> host_tensor_Z(problem_count);
    std::vector<HostTensorC> host_tensor_ref_D0(problem_count);
    std::vector<HostTensorC> host_tensor_ref_D1(problem_count);

    std::vector<typename HostTensorA::TensorRef> ref_A0(problem_count);
    std::vector<typename HostTensorB::TensorRef> ref_B0(problem_count);
    std::vector<typename HostTensorC::TensorRef> ref_C0(problem_count);
    std::vector<typename HostTensorScale::TensorRef> ref_Scale0(problem_count);
    std::vector<typename HostTensorScale::TensorRef> ref_Bias0(problem_count);
    std::vector<typename HostTensorB::TensorRef> ref_B1(problem_count);
    std::vector<typename HostTensorC::TensorRef> ref_C1(problem_count);
    std::vector<typename HostTensorBias::TensorRef> ref_Bias1(problem_count);
    std::vector<typename HostTensorC::TensorRef> ref_D1(problem_count);
    std::vector<typename HostTensorZ::TensorRef> ref_Z(problem_count);
    std::vector<typename HostTensorC::TensorRef> ref_ref_D0(problem_count);
    std::vector<typename HostTensorC::TensorRef> ref_ref_D1(problem_count);

    for (int i = 0; i < problem_count; ++i) {
      //
      // Allocate the GEMM workspace
      //

      auto problem_size_0 = problem_sizes_0[i];
      auto problem_size_1 = problem_sizes_1[i];

      host_tensor_A0.at(i) = HostTensorA(problem_size_0.mk());
      host_tensor_B0.at(i) = HostTensorB(problem_size_0.kn());
      host_tensor_C0.at(i) = HostTensorC(problem_size_0.mn());
      if (alpha0 == ElementCompute(0)) //per-channel scale
        host_tensor_Scale0.at(i) = HostTensorScale(typename HostTensorZ::Layout::TensorCoord{1, problem_size_0.n()});
      host_tensor_Bias0.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_0.n()});
      host_tensor_Z.at(i) = HostTensorZ(problem_size_0.mn());
      host_tensor_ref_D0.at(i) = HostTensorC(problem_size_0.mn());
      host_tensor_B1.at(i) = HostTensorB(problem_size_1.kn());
      host_tensor_C1.at(i) = HostTensorC(problem_size_1.mn());
      host_tensor_Bias1.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_1.n()});
      host_tensor_D1.at(i) = HostTensorC(problem_size_1.mn());
      host_tensor_ref_D1.at(i) = HostTensorC(problem_size_1.mn());

      CHECK_TRUE(initialize_tensor(host_tensor_A0.at(i).host_view(), init_A, seed + 2019));
      CHECK_TRUE(initialize_tensor(host_tensor_B0.at(i).host_view(), init_B, seed + 2018));
      CHECK_TRUE(initialize_tensor(host_tensor_C0.at(i).host_view(), init_C, seed + 2017));
      if (alpha0 == ElementCompute(0)) //per-channel scale
        CHECK_TRUE(initialize_tensor(host_tensor_Scale0.at(i).host_view(), init_Scale, seed + 2014));
      CHECK_TRUE(initialize_tensor(host_tensor_Bias0.at(i).host_view(), init_Bias, seed + 2013));
      CHECK_TRUE(initialize_tensor(host_tensor_B1.at(i).host_view(), init_B, seed + 2016));
      CHECK_TRUE(initialize_tensor(host_tensor_C1.at(i).host_view(), init_C, seed + 2015));
      CHECK_TRUE(initialize_tensor(host_tensor_Bias1.at(i).host_view(), init_Bias, seed + 2012));

      cutlass::reference::host::TensorFill(
        host_tensor_D1.at(i).host_view());
      cutlass::reference::host::TensorFill(
        host_tensor_ref_D0.at(i).host_view());
      cutlass::reference::host::TensorFill(
        host_tensor_ref_D1.at(i).host_view());

      host_tensor_A0.at(i).sync_device();
      host_tensor_B0.at(i).sync_device();
      host_tensor_C0.at(i).sync_device();
      if (alpha0 == ElementCompute(0)) //per-channel scale
        host_tensor_Scale0.at(i).sync_device();
      host_tensor_Bias0.at(i).sync_device();
      host_tensor_B1.at(i).sync_device();
      host_tensor_C1.at(i).sync_device();
      host_tensor_Bias1.at(i).sync_device();
      host_tensor_D1.at(i).sync_device();
      host_tensor_ref_D0.at(i).sync_device();
      host_tensor_ref_D1.at(i).sync_device();

      ref_A0.at(i) = (host_tensor_A0.at(i).device_ref());
      ref_B0.at(i) = (host_tensor_B0.at(i).device_ref());
      ref_C0.at(i) = (host_tensor_C0.at(i).device_ref());
      if (alpha0 == ElementCompute(0)) //per-channel scale
        ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref());
      ref_Bias0.at(i) = (host_tensor_Bias0.at(i).device_ref());
      ref_B1.at(i) = (host_tensor_B1.at(i).device_ref());
      ref_C1.at(i) = {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)};
      ref_Bias1.at(i) = (host_tensor_Bias1.at(i).device_ref());
      ref_D1.at(i) = (host_tensor_D1.at(i).device_ref());
      ref_Z.at(i) = (host_tensor_Z.at(i).device_ref());
      ref_ref_D0.at(i) = (host_tensor_ref_D0.at(i).device_ref());
      ref_ref_D1.at(i) = (host_tensor_ref_D1.at(i).device_ref());
    }

    //
    // Initialize the GEMM operator
    //

    cutlass::DeviceAllocation<typename HostTensorA::TensorRef> device_ref_A0(problem_count);
    device_ref_A0.copy_from_host(ref_A0.data());
    cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B0(problem_count);
    device_ref_B0.copy_from_host(ref_B0.data());
    cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C0(problem_count);
    device_ref_C0.copy_from_host(ref_C0.data());
    cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Scale0(problem_count);
    device_ref_Scale0.copy_from_host(ref_Scale0.data());
    cutlass::DeviceAllocation<typename HostTensorScale::TensorRef> device_ref_Bias0(problem_count);
    device_ref_Bias0.copy_from_host(ref_Bias0.data());
    cutlass::DeviceAllocation<typename HostTensorB::TensorRef> device_ref_B1(problem_count);
    device_ref_B1.copy_from_host(ref_B1.data());
    cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_C1(problem_count);
    device_ref_C1.copy_from_host(ref_C1.data());
    cutlass::DeviceAllocation<typename HostTensorBias::TensorRef> device_ref_Bias1(problem_count);
    device_ref_Bias1.copy_from_host(ref_Bias1.data());
    cutlass::DeviceAllocation<typename HostTensorC::TensorRef> device_ref_D1(problem_count);
    device_ref_D1.copy_from_host(ref_D1.data());

    cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_0(problem_count);
    device_problem_sizes_0.copy_from_host(problem_sizes_0.data());
    cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> device_problem_sizes_1(problem_count);
    device_problem_sizes_1.copy_from_host(problem_sizes_1.data());

    B2bGemm b2b_gemm_op;

    int threadblock_count = B2bGemm::sufficient(problem_sizes_1.data(), problem_count);
    if (!threadblock_count) {
      std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl;
      return false;
    }

    typename B2bGemm::Arguments arguments{
      problem_count,
      device_problem_sizes_0.get(),
      device_problem_sizes_1.get(),
      device_ref_A0.get(),
      device_ref_B0.get(),
      device_ref_C0.get(),
      device_ref_Scale0.get(),
      device_ref_Bias0.get(),
      device_ref_B1.get(),
      device_ref_C1.get(),
      device_ref_D1.get(),
      {alpha0, beta0},
      {alpha1, beta1},
      threadblock_count
    };

    cutlass::Status status = b2b_gemm_op.can_implement(arguments);

    if(status != cutlass::Status::kSuccess) {
      std::cout << "Problem sizes not supported.\n"
        << "Requirements:\n"
        << "    problem_size_0.M = problem_size_1.M\n"
        << "    problem_size_0.N = problem_size_1.K\n"
        << "    ThreadblockShape0::kN = problem_size_0.N\n"
        << "    ThreadblockShape1::kN = problem_size_1.N" << std::endl;
    }

    status = b2b_gemm_op.initialize(arguments);

    CUTLASS_CHECK(status);

    for(int i = 0; i < warm_ups; i++) {
      status = b2b_gemm_op();
      CUTLASS_CHECK(status);
    }

    //
    // Run the GEMM
    //

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);

    for(int i = 0; i < runs; i++) {
      status = b2b_gemm_op();
      CUTLASS_CHECK(status);
    }

    cudaEventRecord(stop);
    cudaDeviceSynchronize();
    float gemmTime;
    cudaEventElapsedTime(&gemmTime, start, stop);
    std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";

    for (int i = 0; i < problem_count; ++i) {
      host_tensor_D1.at(i).sync_host();

      //
      // Verify
      //

      cutlass::reference::device::Gemm<
          typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
          typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
          ElementAccumulator, typename B2bGemm::LayoutC,
          ElementAccumulator, ElementAccumulator>
          reference_gemm_0;

      cutlass::reference::device::Gemm<
          typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
          typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
          typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
          ElementAccumulator>
          reference_gemm_1;

      auto problem_size_0 = problem_sizes_0[i];
      auto problem_size_1 = problem_sizes_1[i];

      reference_gemm_0(
        problem_size_0,
        ElementAccumulator(1), //intermediate alpha=1
        ref_A0.at(i),
        ref_B0.at(i),
        ElementAccumulator(0), //beta = 0
        ref_Z.at(i),
        ref_Z.at(i),
        ElementAccumulator(0)
      );

      cutlass::reference::device::TensorScaleBiasGemm<
        ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
        ElementCompute, typename B2bGemm::LayoutC
      > (
        problem_size_0,
        ref_Z.at(i),
        ref_ref_D0.at(i),
        alpha0,
        ref_Scale0.at(i),
        ref_Bias0.at(i)
      );

      if(relu) {
        cutlass::reference::device::TensorReLu(host_tensor_ref_D0.at(i).device_view());
      }

      reference_gemm_1(
        problem_size_1,
        alpha1,
        ref_ref_D0.at(i),
        ref_B1.at(i),
        beta1,
        {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)},
        ref_ref_D1.at(i)
      );
      if(relu) {
        cutlass::reference::device::TensorReLu(host_tensor_ref_D1.at(i).device_view());
      }
      cudaDeviceSynchronize();
      host_tensor_ref_D0.at(i).sync_host();
      host_tensor_ref_D1.at(i).sync_host();

      CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D0.at(i).host_view()), 0);
      CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_D1.at(i).host_view()), 0);
      CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D1.at(i).host_view()), 0);

      bool passed = cutlass::reference::host::TensorEquals(
        host_tensor_ref_D1.at(i).host_view(),
        host_tensor_D1.at(i).host_view());

      CHECK_TRUE(passed);
      if (!passed)
      {
        std::stringstream fname;

        fname << "error_B2bGemm_device_fused.txt";
        std::cerr << "Check failed for GEMM " << i << " in the group." << std::endl;
        std::cerr << "Dumping results in " << fname.str() << "\n";

        std::ofstream file(fname.str());

        file
          << "GEMM " << i << " in group\n"
          << "A0 =\n" << host_tensor_A0.at(i).host_view()
          << "\nB0 =\n" << host_tensor_B0.at(i).host_view()
          << "\nC0 =\n" << host_tensor_C0.at(i).host_view()
          << "\nScale0:\n" << host_tensor_Scale0.at(i).host_view() << "\n"
          << "\nBias0:\n" << host_tensor_Bias0.at(i).host_view() << "\n"
          << "\nB1 =\n" << host_tensor_B1.at(i).host_view()
          << "\nC1 =\n" << host_tensor_C1.at(i).host_view()
          << "\nBias1:\n" << host_tensor_Bias1.at(i).host_view() << "\n"
          << "\n\nReference =\n" << host_tensor_ref_D1.at(i).host_view()
          << "\nComputed =\n" << host_tensor_D1.at(i).host_view();

        return false;
      }
    }
    return true;
  }

};

////////////////////////////////////////////////////////////////////////////////
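For orientation, the harness above is driven with two parallel lists of problem sizes in which each second GEMM consumes the output of its paired first GEMM (same M, and N of GEMM0 equal to K of GEMM1). The following is a minimal sketch only; MyB2bGemm is a placeholder for a device-level grouped B2b GEMM instantiation such as the one composed later in this commit, and the concrete sizes are illustrative:

// Sketch: pair up problem sizes so GEMM0's output tile feeds GEMM1.
std::vector<cutlass::gemm::GemmCoord> sizes0, sizes1;
for (int i = 0; i < 4; ++i) {
  int m = 128 * (i + 1);
  sizes0.emplace_back(m, 64, 576);   // GEMM0: M x N0 x K0 (N0 matches ThreadblockShape0::kN)
  sizes1.emplace_back(m, 128, 64);   // GEMM1: M x N1 x K1, with K1 == N0
}

B2bFusedGroupedGemmRun<MyB2bGemm> runner;
bool ok = runner.run(sizes0, sizes1,
                     /*alpha0=*/cutlass::half_t(1), /*beta0=*/cutlass::half_t(0),
                     /*alpha1=*/cutlass::half_t(1), /*beta1=*/cutlass::half_t(1));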
@@ -185,96 +185,7 @@ class B2bGemm {
     SmemAccumulator
   >::B2bGemmKernel;

   /// Argument structure
-  struct Arguments {
-
-    //
-    // Data members
-    //
-
-    GemmUniversalMode mode;
-    GemmCoord problem_size_0;
-    GemmCoord problem_size_1;
-    TensorRef<ElementA const, LayoutA> ref_A0;
-    TensorRef<ElementB const, LayoutB> ref_B0;
-    TensorRef<ElementC const, LayoutC> ref_C0;
-    TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0;
-    TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0;
-    TensorRef<ElementB const, LayoutB> ref_B1;
-    TensorRef<ElementC const, LayoutC> ref_C1;
-    TensorRef<ElementC, LayoutC> ref_D1;
-    int64_t batch_stride_A0;
-    int64_t batch_stride_B0;
-    int64_t batch_stride_B1;
-    int64_t batch_stride_C1;
-    int64_t batch_stride_D1;
-    int64_t batch_stride_Bias0;
-    int64_t batch_stride_Scale0;
-    typename EpilogueOutputOp0::Params epilogue0;
-    typename EpilogueOutputOp1::Params epilogue1;
-    int batch_count;
-
-    //
-    // Methods
-    //
-
-    /// Default ctor
-    CUTLASS_HOST_DEVICE
-    Arguments(): mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {
-    }
-
-    /// Constructs an Arguments structure
-    CUTLASS_HOST_DEVICE
-    Arguments(
-      GemmUniversalMode mode_,
-      GemmCoord problem_size_0_,
-      GemmCoord problem_size_1_,
-      TensorRef<ElementA const, LayoutA> ref_A0_,
-      TensorRef<ElementB const, LayoutB> ref_B0_,
-      TensorRef<ElementC const, LayoutC> ref_C0_,
-      TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0_,
-      TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0_,
-      TensorRef<ElementB const, LayoutB> ref_B1_,
-      TensorRef<ElementC const, LayoutC> ref_C1_,
-      TensorRef<ElementC, LayoutC> ref_D1_,
-      int64_t batch_stride_A0_,
-      int64_t batch_stride_B0_,
-      int64_t batch_stride_B1_,
-      int64_t batch_stride_C1_,
-      int64_t batch_stride_D1_,
-      int64_t batch_stride_Bias0_,
-      int64_t batch_stride_Scale0_,
-      typename EpilogueOutputOp0::Params epilogue0_ =
-        typename EpilogueOutputOp0::Params(),
-      typename EpilogueOutputOp1::Params epilogue1_ =
-        typename EpilogueOutputOp1::Params(),
-      int batch_count_ = 1
-    ):
-      mode(mode_),
-      problem_size_0(problem_size_0_),
-      problem_size_1(problem_size_1_),
-      ref_A0(ref_A0_),
-      ref_B0(ref_B0_),
-      ref_C0(ref_C0_),
-      ref_Scale0(ref_Scale0_),
-      ref_Bias0(ref_Bias0_),
-      ref_B1(ref_B1_),
-      ref_C1(ref_C1_),
-      ref_D1(ref_D1_),
-      batch_stride_A0(batch_stride_A0_),
-      batch_stride_B0(batch_stride_B0_),
-      batch_stride_B1(batch_stride_B1_),
-      batch_stride_C1(batch_stride_C1_),
-      batch_stride_D1(batch_stride_D1_),
-      batch_stride_Bias0(batch_stride_Bias0_),
-      batch_stride_Scale0(batch_stride_Scale0_),
-      epilogue0(epilogue0_),
-      epilogue1(epilogue1_),
-      batch_count(batch_count_) {
-    }
-  };
+  using Arguments = typename B2bGemmKernel::Arguments;

 private:
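Because the device-level Arguments is now an alias for the kernel-level structure (which this commit extends with the same members and constructor order), non-grouped call sites can keep constructing it positionally. A hedged sketch with placeholder tensor refs and scalars, following the kernel constructor shown further below:

// Sketch only: ref_* and alpha/beta values are assumed to be defined by the caller.
typename B2bGemm::Arguments args(
  cutlass::gemm::GemmUniversalMode::kGemm,
  problem_size_0, problem_size_1,
  ref_A0, ref_B0, ref_C0,
  ref_Scale0, ref_Bias0,
  ref_B1, ref_C1, ref_D1,
  /*batch_stride_A0=*/0, /*batch_stride_B0=*/0, /*batch_stride_B1=*/0,
  /*batch_stride_C1=*/0, /*batch_stride_D1=*/0,
  /*batch_stride_Bias0=*/0, /*batch_stride_Scale0=*/0,
  {alpha0, beta0}, {alpha1, beta1});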
@@ -0,0 +1,297 @@

/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Example of running grouped back-to-back GEMMs when intermediate results are RF resident
*/

#include <iostream>
#include <vector>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/base_grouped.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "device/b2b_gemm.h"
#include "kernel/default_b2b_gemm.h"
#include "threadblock/grouped_threadblock_swizzle.h"
#include "b2b_grouped_gemm_run.h"
#include "test_run.h"

////////////////////////////////////////////////////////////////////////////////

std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_0;
std::vector<cutlass::gemm::GemmCoord> gemm_f16_sm80_problem_sizes_1;

// Constraints:
// 1. Warp shape N must equal thread block shape N
// 2. Problem size N must equal thread block shape N
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;

// Command line options parsing
struct Options {

  bool help;
  bool error;
  bool reference_check;
  int alignment = 8;

  std::vector<cutlass::gemm::GemmCoord> problem_sizes0;
  std::vector<cutlass::gemm::GemmCoord> problem_sizes1;

  int problem_count;
  bool verbose;

  //
  // Methods
  //

  Options():
    help(false),
    error(false),
    reference_check(true),
    problem_count(15),
    verbose(false)
  { }

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
      return;
    }

    cmd.get_cmd_line_argument("problems", problem_count, 15);
    cmd.get_cmd_line_argument("reference-check", reference_check, true);
    cmd.get_cmd_line_argument("verbose", verbose, false);

    randomize_problems(cmd);
  }

  void randomize_problems(cutlass::CommandLine &cmd) {

    //
    // For now, randomly choose the problem sizes.
    //

    int cmd_line_m = -1;
    int cmd_line_k = -1;

    cmd.get_cmd_line_argument("m", cmd_line_m);
    cmd.get_cmd_line_argument("k", cmd_line_k);

    problem_sizes0.reserve(problem_count);
    problem_sizes1.reserve(problem_count);

    for (int i = 0; i < problem_count; ++i) {

      int m = cmd_line_m;
      int k = cmd_line_k;

      if (m < 1) {
        m = alignment * ((rand() % 256) + 1);
      }

      if (k < 1) {
        k = alignment * ((rand() % 256) + 1);
      }

      cutlass::gemm::GemmCoord problem0(m, ThreadblockShape0::kN, k);
      cutlass::gemm::GemmCoord problem1(m, ThreadblockShape1::kN, ThreadblockShape0::kN);

      problem_sizes0.push_back(problem0);
      problem_sizes1.push_back(problem1);
    }

    if (verbose) {
      print_problem_sizes();
    }
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n"
      << "  This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs are\n"
      << "  run in a single kernel. Each individual problem in the group is subject to the same constraints that non-grouped\n"
      << "  back-to-back GEMMs are subject to.\n\n"
      << "Options:\n\n"
      << "  --help                      If specified, displays this usage statement.\n\n"
      << "  --problems=<int>            Number of individual GEMM problems (default: --problems=15)\n"
      << "  --m=<int>                   Sets the M dimension of both GEMMs for all groups. Otherwise, it is selected randomly\n"
      << "  --k=<int>                   Sets the K dimension of the first GEMM for all groups. Otherwise, it is selected randomly\n"
      << "  --verbose=<bool>            If true, prints problem sizes.\n";

    out << "\n\nExamples:\n\n"

      << "# Runs a grouped B2b GEMM with 10 random problem sizes\n"
      << "$ ./examples/13_two_tensor_op_fusion/13_fused_two_gemms_grouped_f16_sm80_rf --problems=10\n\n";

    return out;
  }

  void print_problem_sizes() {
    std::cout << std::endl;
    std::cout << "Executing " << problem_count << " independent back-to-back GEMMs in a group" << std::endl;
    for (int i = 0; i < problem_count; ++i) {
      cutlass::gemm::GemmCoord problem0 = problem_sizes0.at(i);
      cutlass::gemm::GemmCoord problem1 = problem_sizes1.at(i);
      std::cout << "Problem " << i
                << "\t\tGEMM0: " << problem0.m() << 'x' << problem0.n() << 'x' << problem0.k()
                << "\t\tGEMM1: " << problem1.m() << 'x' << problem1.n() << 'x' << problem1.k()
                << std::endl;
    }
  }
};

bool run_fused_grouped_gemm_f16_sm80_rf_res() {

  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(1);
  //Fused kernel has built-in bias, setting beta=0
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  using EpilogueOutputOp0 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput,
      InstructionShape::kM * InstructionShape::kN / 32,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    >;

  using EpilogueOutputOp1 =
    cutlass::epilogue::thread::LinearCombinationRelu<
      ElementOutput,
      128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
    >;

  using GroupedThreadblockSwizzle = cutlass::gemm::threadblock::B2bGemmGroupedThreadblockSwizzle<
    ThreadblockShape0,
    cutlass::layout::RowMajor // LayoutC
  >;

  const int kAlignment = 128 / cutlass::sizeof_bits<ElementOutput>::value;
  const int kStages = 3;
  using B2bGemmKernel = cutlass::gemm::kernel::DefaultB2bGemm<
    cutlass::half_t,
    cutlass::layout::RowMajor,
    kAlignment,
    cutlass::half_t,
    cutlass::layout::ColumnMajor,
    kAlignment,
    cutlass::half_t,
    cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape0,
    ThreadblockShape1,
    WarpShape0,
    WarpShape1,
    InstructionShape,
    EpilogueOutputOp0,
    EpilogueOutputOp1,
    GroupedThreadblockSwizzle,
    kStages,
    cutlass::arch::OpMultiplyAdd
  >::B2bGemmKernel;

  using B2bGemm = cutlass::gemm::device::BaseGrouped<B2bGemmKernel>;

  B2bFusedGroupedGemmRun<B2bGemm> fusedGemm;

  std::cout << "Running Fused back-to-back FP16 TN Grouped GEMMs with RF residency...\n";
  bool passed = fusedGemm.run(gemm_f16_sm80_problem_sizes_0, gemm_f16_sm80_problem_sizes_1, alpha0, beta0, alpha1, beta1);
  if(passed)
    std::cout << "Pass\n";
  else
    std::cout << "Fail\n";

  return passed;
}

int main(int argc, char const **args) {

  //
  // Parse options
  //

  Options options;

  options.parse(argc, args);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  if (options.error) {
    std::cerr << "Aborting execution." << std::endl;
    return -1;
  }

  gemm_f16_sm80_problem_sizes_0 = options.problem_sizes0;
  gemm_f16_sm80_problem_sizes_1 = options.problem_sizes1;

  std::vector<bool (*)()> funcs = {
    &run_fused_grouped_gemm_f16_sm80_rf_res
  };

  return testRun(80, funcs, "grouped gemm f16 RF residency");
}

////////////////////////////////////////////////////////////////////////////////
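The randomly generated sizes above bake in the fusion constraints that can_implement() reports (shared M, N0 feeding K1, and both N dimensions pinned to the threadblock tile widths). A small helper like the following, which is hypothetical and not part of the commit, makes those pairing rules explicit:

// Sketch: returns true if a (GEMM0, GEMM1) pair satisfies the B2b fusion constraints.
bool b2b_sizes_compatible(cutlass::gemm::GemmCoord const &p0,
                          cutlass::gemm::GemmCoord const &p1) {
  return p0.m() == p1.m() &&                 // both GEMMs share M
         p0.n() == p1.k() &&                 // GEMM0's output is GEMM1's K extent
         p0.n() == ThreadblockShape0::kN &&  // N0 fixed to the first tile's N
         p1.n() == ThreadblockShape1::kN;    // N1 fixed to the second tile's N
}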
@@ -40,12 +40,60 @@
 #include "cutlass/matrix_coord.h"
 #include "cutlass/semaphore.h"

 #include "kernel/b2b_gemm_grouped_problem_visitor.h"
 #include "threadblock/grouped_threadblock_swizzle.h"

 /////////////////////////////////////////////////////////////////////////////////////////////////

 namespace cutlass {
 namespace gemm {
 namespace kernel {

 namespace detail {

 /// Utility struct for returning the type of the problem visitor used by the swizzling function,
 /// if it is a grouped swizzling function, or a default visitor. This is used only for defining
 /// the parameters of the problem visitor used in GroupedParams.
 template <
   typename B2bMma_,
   typename ThreadblockSwizzle_,
   typename Enable = void
 >
 struct ProblemVisitorOrDefault;

 /// Return a generic problem visitor for GEMM problems
 template <
   typename B2bMma_,
   typename ThreadblockSwizzle_
 >
 struct ProblemVisitorOrDefault<B2bMma_,
                                ThreadblockSwizzle_,
                                typename platform::enable_if<
                                  ! cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
                                >::type> {
   using value = B2bGemmGroupedProblemVisitor<typename B2bMma_::Shape,
                                              GroupScheduleMode::kDeviceOnly,
                                              128,
                                              128,
                                              platform::is_same<typename B2bMma_::LayoutC,
                                                                cutlass::layout::ColumnMajor>::value>;
 };

 /// Return the problem visitor specified by the swizzling function
 template <
   typename B2bMma_,
   typename ThreadblockSwizzle_
 >
 struct ProblemVisitorOrDefault<B2bMma_,
                                ThreadblockSwizzle_,
                                typename platform::enable_if<
                                  cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle_>::value
                                >::type> {
   using value = typename ThreadblockSwizzle_::ProblemVisitor;
 };

 } // namespace detail

 /////////////////////////////////////////////////////////////////////////////////////////////////

 template <

@@ -72,10 +120,169 @@ struct B2bGemm {

   using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;

   /// Data types needed for higher-level containers. In some cases, a single type must be exposed
   /// despite the B2b GEMM using two GEMMs under the hood. In such cases, we select the values from
   /// the second GEMM (other than for ElementA/ElementB)
   using ElementA = typename B2bMma::IteratorA0::Element;
   using LayoutA = typename B2bMma::IteratorA0::Layout;
   using ElementB = typename B2bMma::IteratorB0::Element;
   using LayoutB = typename B2bMma::IteratorB0::Layout;

   static ComplexTransform const kTransformA = B2bMma::kTransformA;
   static ComplexTransform const kTransformB = B2bMma::kTransformB;
   using Operator = typename B2bMma::Operator0;

   using OperatorClass = typename Operator::OperatorClass;
   using ThreadblockShape = typename B2bMma::Shape0;
   using WarpShape = typename Operator::Shape;
   using InstructionShape = typename Operator::InstructionShape;
   using ArchTag = typename B2bMma::ArchTag;

   static int const kStages = B2bMma::kStages;
   static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements;
   static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements;
   static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

   using Mma = B2bMma;
   using EpilogueOutputOp = OutputOp1;

   /// Warp count (concept: GemmShape)
   using WarpCount0 = typename B2bMma::WarpCount0;
   static int const kThreadCount = 32 * WarpCount0::kCount;

   /// Argument structure
   struct Arguments {

     //
     // Data members
     //

     GemmUniversalMode mode;
     GemmCoord problem_size_0;
     GemmCoord problem_size_1;
     typename B2bMma::IteratorA0::TensorRef ref_A0;
     typename B2bMma::IteratorB0::TensorRef ref_B0;
     typename Epilogue::OutputTileIterator::TensorRef ref_C0;
     typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
     typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
     typename B2bMma::IteratorB1::TensorRef ref_B1;
     typename Epilogue::OutputTileIterator::TensorRef ref_C1;
     typename Epilogue::OutputTileIterator::TensorRef ref_D1;
     int64_t batch_stride_A0;
     int64_t batch_stride_B0;
     int64_t batch_stride_B1;
     int64_t batch_stride_C1;
     int64_t batch_stride_D1;
     int64_t batch_stride_Bias0;
     int64_t batch_stride_Scale0;
     typename OutputOp0::Params epilogue0;
     typename OutputOp1::Params epilogue1;
     int batch_count;

     //
     // Methods
     //

     /// Default ctor
     CUTLASS_HOST_DEVICE
     Arguments() : mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {}

     /// Constructs an Arguments structure
     CUTLASS_HOST_DEVICE
     Arguments(
       GemmUniversalMode mode_,
       GemmCoord problem_size_0_,
       GemmCoord problem_size_1_,
       typename B2bMma::IteratorA0::TensorRef ref_A0_,
       typename B2bMma::IteratorB0::TensorRef ref_B0_,
       typename Epilogue::OutputTileIterator::TensorRef ref_C0_,
       typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_,
       typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_,
       typename B2bMma::IteratorB1::TensorRef ref_B1_,
       typename Epilogue::OutputTileIterator::TensorRef ref_C1_,
       typename Epilogue::OutputTileIterator::TensorRef ref_D1_,
       int64_t batch_stride_A0_,
       int64_t batch_stride_B0_,
       int64_t batch_stride_B1_,
       int64_t batch_stride_C1_,
       int64_t batch_stride_D1_,
       int64_t batch_stride_Bias0_,
       int64_t batch_stride_Scale0_,
       typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
       typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
       int batch_count_ = 1
     ):
       mode(mode_),
       problem_size_0(problem_size_0_),
       problem_size_1(problem_size_1_),
       ref_A0(ref_A0_),
       ref_B0(ref_B0_),
       ref_C0(ref_C0_),
       ref_Scale0(ref_Scale0_),
       ref_Bias0(ref_Bias0_),
       ref_B1(ref_B1_),
       ref_C1(ref_C1_),
       ref_D1(ref_D1_),
       batch_stride_A0(batch_stride_A0_),
       batch_stride_B0(batch_stride_B0_),
       batch_stride_B1(batch_stride_B1_),
       batch_stride_C1(batch_stride_C1_),
       batch_stride_D1(batch_stride_D1_),
       batch_stride_Bias0(batch_stride_Bias0_),
       batch_stride_Scale0(batch_stride_Scale0_),
       epilogue0(epilogue0_),
       epilogue1(epilogue1_),
       batch_count(batch_count_) {
     }
   };

   // Arguments structure for grouped B2B problems
   struct GroupedArguments {
     GemmCoord* problem_size_0;
     GemmCoord* problem_size_1;
     typename B2bMma::IteratorA0::TensorRef* ref_A0;
     typename B2bMma::IteratorB0::TensorRef* ref_B0;
     typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
     typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
     typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
     typename B2bMma::IteratorB1::TensorRef* ref_B1;
     typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
     typename Epilogue::OutputTileIterator::TensorRef* ref_D1;

     // Epilogue params remain constant across all problems in the group. Thus,
     // the parameter here is not a pointer.
     typename OutputOp0::Params epilogue0;
     typename OutputOp1::Params epilogue1;

     int problem_count;
     int threadblock_count;
     GemmCoord* host_problem_sizes;

     CUTLASS_HOST_DEVICE
     GroupedArguments(
       int problem_count,
       GemmCoord* problem_size_0_,
       GemmCoord* problem_size_1_,
       typename B2bMma::IteratorA0::TensorRef* ref_A0_,
       typename B2bMma::IteratorB0::TensorRef* ref_B0_,
       typename Epilogue::OutputTileIterator::TensorRef* ref_C0_,
       typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_,
       typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_,
       typename B2bMma::IteratorB1::TensorRef* ref_B1_,
       typename Epilogue::OutputTileIterator::TensorRef* ref_C1_,
       typename Epilogue::OutputTileIterator::TensorRef* ref_D1_,
       typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(),
       typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(),
       int threadblock_count = 0
     ) : problem_size_0(problem_size_0_), problem_size_1(problem_size_1_),
         ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_),
         ref_Scale0(ref_Scale0_), ref_Bias0(ref_Bias0_), ref_B1(ref_B1_),
         ref_C1(ref_C1_), ref_D1(ref_D1_), epilogue0(epilogue0_), epilogue1(epilogue1_),
         problem_count(problem_count),
         threadblock_count(threadblock_count)
     {}
   };

   /// Parameters structure
   struct Params {
     cutlass::gemm::GemmUniversalMode mode;

@@ -149,7 +356,7 @@ struct B2bGemm {
       problem_size_0(problem_size_0),
       problem_size_1(problem_size_1),
       grid_tiled_shape(grid_tiled_shape),
-      swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
+      swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)),
       params_A0(ref_A0.layout()),
       ref_A0(ref_A0),
       params_B0(ref_B0.layout()),

@@ -185,6 +392,81 @@ struct B2bGemm {
     }
   };

   struct GroupedParams {
     cutlass::gemm::GemmCoord* problem_size_0;
     cutlass::gemm::GemmCoord* problem_size_1;
     cutlass::gemm::GemmCoord* grid_tiled_shape;
     typename B2bMma::IteratorA0::TensorRef* ref_A0;
     typename B2bMma::IteratorB0::TensorRef* ref_B0;
     typename Epilogue::OutputTileIterator::TensorRef* ref_C0;
     typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0;
     typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0;
     typename B2bMma::IteratorB1::TensorRef* ref_B1;
     typename Epilogue::OutputTileIterator::TensorRef* ref_C1;
     typename Epilogue::OutputTileIterator::TensorRef* ref_D1;

     // Epilogue params remain constant across all problems in the group. Thus,
     // the parameter here is not a pointer.
     typename OutputOp0::Params output_op_0;
     typename OutputOp1::Params output_op_1;

     using ProblemVisitor = typename detail::ProblemVisitorOrDefault<B2bMma, ThreadblockSwizzle>::value;
     typename ProblemVisitor::Params problem_visitor;
     int threadblock_count;
     int* workspace;

     CUTLASS_HOST_DEVICE
     GroupedParams() {}

     CUTLASS_HOST_DEVICE
     GroupedParams(
       GroupedArguments const &args,
       void *workspace = nullptr,
       int tile_count = 0
     ) :
       problem_size_0(args.problem_size_0), problem_size_1(args.problem_size_1),
       ref_A0(args.ref_A0), ref_B0(args.ref_B0), ref_C0(args.ref_C0),
       ref_Scale0(args.ref_Scale0), ref_Bias0(args.ref_Bias0), ref_B1(args.ref_B1), ref_C1(args.ref_C1), ref_D1(args.ref_D1),
       output_op_0(args.epilogue0), output_op_1(args.epilogue1),
       problem_visitor(args.problem_size_0, args.problem_size_1, args.problem_count, workspace, tile_count),
       threadblock_count(args.threadblock_count),
       workspace(reinterpret_cast<int*>(workspace)) {}

     CUTLASS_HOST_DEVICE
     void transpose() {
       // Only row-major outputs are currently supported, so no transpose is performed
     }

     /// Returns non-grouped parameters to be used as input to the kernel-level
     /// operator for the problem indicated by problem_visitor.
     CUTLASS_HOST_DEVICE
     Params to_single_params(const ProblemVisitor& problem_visitor) const {
       GemmCoord problem_size0 = problem_visitor.problem_size0();
       GemmCoord problem_size1 = problem_visitor.problem_size1();
       int32_t idx = problem_visitor.problem_index();
       GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1);

       return Params(
         cutlass::gemm::GemmUniversalMode::kGemm,
         problem_size0,
         problem_size1,
         grid_shape,
         ref_A0[idx],
         ref_B0[idx],
         ref_C0[idx],
         ref_Scale0[idx],
         ref_Bias0[idx],
         ref_B1[idx],
         ref_C1[idx],
         ref_D1[idx],
         0, 0, 0, 0, 0, 0, 0, // Batched B2B GEMMs within the grouped kernel are currently unsupported
         output_op_0,
         output_op_1,
         workspace
       );
     }
   };

   /// Shared memory storage structure
   union SharedStorage {
     typename B2bMma::B2bMmaSharedStorage main_loop;

@@ -266,9 +548,13 @@ struct B2bGemm {
   /// Executes one GEMM
   CUTLASS_DEVICE
   void operator()(Params const &params, SharedStorage &shared_storage) {
-
-    // Compute threadblock location
     ThreadblockSwizzle threadblock_swizzle;
+    run_with_swizzle(params, shared_storage, threadblock_swizzle);
+  }
+
+  /// Executes one GEMM with an externally-provided swizzling function
+  CUTLASS_DEVICE
+  void run_with_swizzle(Params const &params, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) {

     cutlass::gemm::GemmCoord threadblock_tile_offset =
         threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

@@ -391,14 +677,17 @@ struct B2bGemm {
       )
     );

     //
     // Main loop
     //

     OutputOp0 output_op_0(params.output_op_0);

+    if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle<ThreadblockSwizzle>::value) {
+      // Wait for all threads to finish their epilogue phases from the previous tile.
+      __syncthreads();
+    }
+
     // Construct thread-scoped matrix multiply
     B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n());
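To see how GroupedParams, to_single_params(), and run_with_swizzle() compose, the grouped kernel driven by the grouped swizzle machinery can be pictured as the loop below. This is an illustrative sketch only, not the literal kernel in this commit; grouped_params, problem_visitor, b2b_gemm, shared_storage, and ThreadblockSwizzle are assumed to be in scope with the obvious meanings:

// Sketch: each threadblock walks the problem visitor and replays the
// non-grouped B2b GEMM body once per tile it claims.
while (problem_visitor.next_tile()) {
  // Translate the grouped description into ordinary per-problem Params.
  typename B2bGemm::Params params = grouped_params.to_single_params(problem_visitor);

  // Run the regular B2b GEMM body using the grouped swizzling function.
  ThreadblockSwizzle swizzle;
  b2b_gemm.run_with_swizzle(params, shared_storage, swizzle);

  // Advance to the next tile owned by this threadblock.
  problem_visitor.advance(gridDim.x);
}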
@ -0,0 +1,157 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Scheduler for grouped B2b GEMMs
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape,
          GroupScheduleMode GroupScheduleMode_,
          int PrefetchTileCount,
          int ThreadCount,
          bool Transposed = false>
struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor<
                                            detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
                                            ThreadblockShape,
                                            GroupScheduleMode_,
                                            PrefetchTileCount,
                                            ThreadCount> {

  using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
  using Base = GroupedProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
  using BaseParams = typename Base::Params;
  using SharedStorage = typename Base::SharedStorage;
  static bool const kTransposed = Transposed;

  cutlass::gemm::GemmCoord const *problem_sizes0;
  cutlass::gemm::GemmCoord const *problem_sizes1;

  struct Params {
    cutlass::gemm::GemmCoord const *problem_sizes0;
    cutlass::gemm::GemmCoord const *problem_sizes1;
    int32_t problem_count;
    void const *workspace;
    int32_t tile_count;

    //
    // Methods
    //

    /// Ctor
    CUTLASS_HOST_DEVICE
    Params(): problem_sizes0(nullptr), problem_sizes1(nullptr),
              problem_count(0), workspace(nullptr), tile_count(0) { }

    /// Ctor
    CUTLASS_HOST_DEVICE
    Params(
      cutlass::gemm::GemmCoord const *problem_sizes0,
      cutlass::gemm::GemmCoord const *problem_sizes1,
      int32_t problem_count,
      void const *workspace = nullptr,
      int32_t tile_count = 0
    ):
      problem_sizes0(problem_sizes0),
      problem_sizes1(problem_sizes1),
      problem_count(problem_count),
      workspace(workspace),
      tile_count(tile_count)
    {}

    /// Convert the B2b-GEMM-specific parameters to those used by the base class
    CUTLASS_HOST_DEVICE
    BaseParams to_base() const {
      return BaseParams(// Set problem_sizes as problem_sizes0 because these determine
                        // shape of the grid used in the non-grouped B2b GEMM
                        problem_sizes0,
                        problem_count,
                        workspace,
                        tile_count);
    }

  };

  //
  // Methods
  //
  CUTLASS_DEVICE
  B2bGemmGroupedProblemVisitor(
    Params const &params_,
    SharedStorage &shared_storage_,
    int32_t block_idx
  ): Base (
       params_.to_base(),
       shared_storage_, block_idx),
     problem_sizes0(params_.problem_sizes0),
     problem_sizes1(params_.problem_sizes1)
  {}

  /// Returns the problem size 0 for the current problem
  CUTLASS_HOST_DEVICE
  cutlass::gemm::GemmCoord problem_size0() const {
    GemmCoord problem = problem_sizes0[this->problem_idx];
    ProblemSizeHelper::possibly_transpose_problem(problem);
    return problem;
  }

  /// Returns the problem size 1 for the current problem
  CUTLASS_HOST_DEVICE
  cutlass::gemm::GemmCoord problem_size1() const {
    GemmCoord problem = problem_sizes1[this->problem_idx];
    ProblemSizeHelper::possibly_transpose_problem(problem);
    return problem;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
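On the host, this visitor's Params is the one place where the two per-group problem-size arrays meet: problem_sizes0 drives the tile and grid bookkeeping inherited from GroupedProblemVisitor, while problem_sizes1 is only consulted when a tile asks for the second GEMM's shape. A sketch of filling it in (the d_problem_sizes0/d_problem_sizes1 device pointers and the tile-shape and schedule choices are illustrative placeholders, not the example's actual configuration):

// Sketch: one GemmCoord per group for each of the two fused GEMMs.
using Visitor = cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<
    cutlass::gemm::GemmShape<64, 64, 32>,                   // ThreadblockShape (placeholder)
    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
    128,                                                    // PrefetchTileCount
    128>;                                                   // ThreadCount

typename Visitor::Params visitor_params(
    d_problem_sizes0,   // GemmCoord[problem_count] for GEMM 0 (device pointer)
    d_problem_sizes1,   // GemmCoord[problem_count] for GEMM 1 (device pointer)
    problem_count);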
@ -63,7 +63,9 @@
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"

#include "kernel/b2b_gemm.h"
#include "kernel/grouped.h"
#include "threadblock/default_b2b_mma.h"
#include "threadblock/grouped_threadblock_swizzle.h"

////////////////////////////////////////////////////////////////////////////////

@ -73,6 +75,9 @@ namespace kernel {

////////////////////////////////////////////////////////////////////////////////

template <typename T>
using IsGroupedSwizzle = cutlass::gemm::threadblock::detail::IsGroupedSwizzle<T>;

template <
  /// Element type for A matrix operand
  typename ElementA_,
@ -117,7 +122,9 @@ template <
  /// Operation performed by GEMM
  typename Operator,
  /// Stage accumulator in shared memory
  bool SmemAccumulator = false
  bool SmemAccumulator = false,
  /// Whether or not the operation is grouped
  typename Enable = void
>
struct DefaultB2bGemm;

@ -166,7 +173,7 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
    arch::Sm80, ThreadblockShape0, ThreadblockShape1,
    WarpShape0, WarpShape1, InstructionShape,
    EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
    Operator> {
    Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
  /// Define the threadblock-scoped matrix multiply-accumulate
  using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
    ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
@ -186,6 +193,71 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};

/// Partial specialization for Ampere Architecture with grouped operation
template <
  /// Element type for A matrix operand
  typename ElementA,
  /// Layout type for A matrix operand
  typename LayoutA,
  /// Access granularity of A matrix in units of elements
  int kAlignmentA,
  /// Element type for B matrix operand
  typename ElementB,
  /// Layout type for B matrix operand
  typename LayoutB,
  /// Access granularity of A matrix in units of elements
  int kAlignmentB,
  /// Element type for C and D matrix operands
  typename ElementC,
  /// Element type for internal accumulation
  typename ElementAccumulator,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape0,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape1,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape0,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape1,
  /// Warp-level tile size (concept: GemmShape)
  typename InstructionShape,
  /// Epilogue output operator
  typename EpilogueOutputOp0,
  /// Epilogue output operator
  typename EpilogueOutputOp1,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Operation performed by GEMM
  typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
                      layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
                      arch::Sm80, ThreadblockShape0, ThreadblockShape1,
                      WarpShape0, WarpShape1, InstructionShape,
                      EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
                      Operator, false, typename platform::enable_if<IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
  /// Define the threadblock-scoped matrix multiply-accumulate
  using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
      ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
      InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;

  static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;

  /// Define the epilogue
  using Epilogue =
      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
          ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
          EpilogueOutputOp1::kCount>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using UnderlyingB2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;

  using B2bGemmKernel = kernel::GroupedKernel<UnderlyingB2bGemmKernel>;
};

////////////////////////////////////////////////////////////////////////////////
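The trailing typename Enable = void parameter added above is what lets the grouped and non-grouped partial specializations coexist: each constrains itself with a mutually exclusive enable_if on IsGroupedSwizzle, so exactly one specialization is viable for any swizzle type. A self-contained sketch of the pattern with stand-in names (these are not the CUTLASS types themselves):

#include <type_traits>

// Stand-ins for the real swizzle types and the IsGroupedSwizzle trait.
struct RegularSwizzle {};
struct GroupedSwizzleBase {};
struct GroupedSwizzle : GroupedSwizzleBase {};

template <typename T>
struct IsGrouped : std::is_base_of<GroupedSwizzleBase, T> {};

// Trailing `Enable = void` parameter: exactly one specialization matches any
// Swizzle, so the two definitions never collide.
template <typename Swizzle, typename Enable = void>
struct PickKernel;

template <typename Swizzle>
struct PickKernel<Swizzle, typename std::enable_if<!IsGrouped<Swizzle>::value>::type> {
  static constexpr bool kGrouped = false;   // plays the role of kernel::B2bGemm
};

template <typename Swizzle>
struct PickKernel<Swizzle, typename std::enable_if<IsGrouped<Swizzle>::value>::type> {
  static constexpr bool kGrouped = true;    // plays the role of kernel::GroupedKernel<...>
};

static_assert(!PickKernel<RegularSwizzle>::kGrouped, "non-grouped path");
static_assert(PickKernel<GroupedSwizzle>::kGrouped, "grouped path");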
@ -242,7 +314,9 @@ struct DefaultB2bGemm<
  EpilogueOutputOp1,
  ThreadblockSwizzle,
  2,
  Operator
  Operator,
  false,
  typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type
> {

  /// Define the threadblock-scoped matrix multiply-accumulate
@ -324,7 +398,8 @@ struct DefaultB2bGemm<
  arch::OpClassTensorOp, arch::Sm80,
  ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
  InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
  ThreadblockSwizzle, Stages, Operator> {
  ThreadblockSwizzle, Stages,
  Operator, false, typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
  using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
  using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
  using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -393,7 +468,8 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
  int32_t, arch::OpClassTensorOp, arch::Sm75,
  ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
  InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
  ThreadblockSwizzle, 2, Operator> {
  ThreadblockSwizzle, 2, Operator, false,
  typename platform::enable_if<!IsGroupedSwizzle<ThreadblockSwizzle>::value>::type> {
  using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
  using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
  using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
168 examples/13_two_tensor_op_fusion/kernel/grouped.h Normal file
@ -0,0 +1,168 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief High-level interface for running a grouped version of a CUTLASS kernel
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// High-level interface for running a grouped version of a CUTLASS kernel
template <
  typename BaseKernel_   ///! Kernel-scoped matrix multiply-accumulate
>
struct GroupedKernel {
public:

  using BaseKernel = BaseKernel_;
  using Epilogue = typename BaseKernel::Epilogue;

  /// Types that need to be exported to work properly with device::BaseGrouped
  using ElementA = typename BaseKernel::ElementA;
  using LayoutA = typename BaseKernel::LayoutA;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  static ComplexTransform const kTransformA = BaseKernel::kTransformA;
  static int const kAlignmentA = BaseKernel::kAlignmentA;

  using ElementB = typename BaseKernel::ElementB;
  using LayoutB = typename BaseKernel::LayoutB;
  using TensorRefB = TensorRef<ElementB const, LayoutB>;
  static ComplexTransform const kTransformB = BaseKernel::kTransformB;
  static int const kAlignmentB = BaseKernel::kAlignmentB;

  using ElementC = typename BaseKernel::ElementC;
  using LayoutC = typename BaseKernel::LayoutC;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  static int const kAlignmentC = BaseKernel::kAlignmentC;

  using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;

  using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
  using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle;

  using Operator = typename BaseKernel::Operator;
  using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;

  using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
  using MathOperator = typename WarpMmaOperator::MathOperator;
  using OperatorClass = typename WarpMmaOperator::OperatorClass;
  using ArchTag = typename WarpMmaOperator::ArchTag;
  using ThreadblockShape = typename BaseKernel::Mma::Shape;
  using WarpShape = typename BaseKernel::WarpShape;
  using InstructionShape = typename BaseKernel::InstructionShape;
  static int const kStages = BaseKernel::Mma::kStages;

  using Mma = typename BaseKernel::Mma;

  using Arguments = typename BaseKernel::GroupedArguments;
  using Params = typename BaseKernel::GroupedParams;
  using ProblemVisitor = typename ThreadblockSwizzle::ProblemVisitor;

  static int const kThreadCount = BaseKernel::kThreadCount;

  /// Shared memory storage structure
  struct SharedStorage {
    typename BaseKernel::SharedStorage kernel;

    // ProblemVisitor shared storage can't be overlapped with others
    typename ProblemVisitor::SharedStorage problem_visitor;
  };

public:

  //
  // Methods
  //

  CUTLASS_DEVICE
  GroupedKernel() { }

  /// Determines whether kernel satisfies alignment
  static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
    return Status::kSuccess;
  }

  static Status can_implement(Arguments const &args) {
    return Status::kSuccess;
  }

  /// Executes a kernel-level GEMM in a loop
  CUTLASS_DEVICE
  void operator()(Params &params, SharedStorage &shared_storage) {

    ThreadblockSwizzle swizzle(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);

    if (ProblemVisitor::kTransposed) {
      params.transpose();
    }

    BaseKernel mma;

    // Outer 'persistent' loop to iterate over tiles
    while (swizzle.problem_visitor.next_tile()) {

      typename BaseKernel::Params mma_params = params.to_single_params(swizzle.problem_visitor);
      mma.run_with_swizzle(mma_params, shared_storage.kernel, swizzle);

      // Next tile
      swizzle.problem_visitor.advance(gridDim.x);
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
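Because each CTA of GroupedKernel keeps pulling tiles until the visitor is exhausted, the launch configuration is decoupled from the problem sizes: any grid large enough to occupy the device works. A bare-bones entry point illustrating the launch shape (a hand-rolled sketch; in practice the kernel is driven by a device-level wrapper such as device::BaseGrouped, which also sizes the grid from occupancy):

// Sketch: raw launch wrapper for a GroupedKernel instantiation.
template <typename GroupedKernelT>
__global__ void grouped_kernel_entry(typename GroupedKernelT::Params params) {
  // Shared storage covers both the underlying B2b GEMM and the problem visitor.
  extern __shared__ int smem_buf[];
  typename GroupedKernelT::SharedStorage *storage =
      reinterpret_cast<typename GroupedKernelT::SharedStorage *>(smem_buf);

  GroupedKernelT op;
  op(params, *storage);   // the persistent loop over tiles runs inside operator()
}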
@ -119,8 +119,10 @@ public:
  using Shape0 = Shape0_;
  ///< Iterates over tiles of A operand in global memory
  using IteratorA0 = IteratorA0_;
  using IteratorA = IteratorA0;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB0 = IteratorB0_;
  using IteratorB = IteratorB0;
  ///< Policy describing tuning details
  using Policy0 = Policy0_;

@ -140,6 +142,10 @@ public:
  ///< Policy describing tuning details
  using Policy1 = Policy1_;

  ///< Export Policy0 as the threadblock-level Mma's policy
  using Policy = Policy0;
  using Shape = Shape0;

  using SmemIteratorB1 = SmemIteratorB1_;

  ///< Data type of accumulator matrix
@ -188,6 +194,10 @@ public:
  /// Complex transform on B operand
  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  /// Complex transform exports needed by higher-level kernels
  static ComplexTransform const kTransformA = kTransformA0;
  static ComplexTransform const kTransformB = kTransformB0;

  /// Internal structure exposed for introspection.
  struct Detail {
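The aliases added here (IteratorA, IteratorB, Policy, Shape) and the kTransformA/kTransformB constants do not introduce new machinery; they re-export types and values the B2b Mma already carried under stage-specific names, so that kernel-agnostic wrappers such as the GroupedKernel above can introspect a fused Mma the same way they would a single-GEMM Mma. The same exports are added to the remaining B2b Mma variants in the hunks that follow. A small sketch of the kind of generic query this unlocks (MmaIntrospection is an illustrative name, not part of the change):

// Sketch: generic code that only knows "some threadblock Mma" can now ask the
// fused B2b Mma the same questions it asks a plain Mma.
template <typename Mma>
struct MmaIntrospection {
  using Shape     = typename Mma::Shape;       // aliases one of the stage shapes (Shape0 or Shape1)
  using IteratorA = typename Mma::IteratorA;   // same type as Mma::IteratorA0
  using IteratorB = typename Mma::IteratorB;   // same type as Mma::IteratorB0
  static constexpr auto kTransformA = Mma::kTransformA;
  static constexpr auto kTransformB = Mma::kTransformB;
};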
@ -121,8 +121,10 @@ public:
  using Shape0 = Shape0_;
  ///< Iterates over tiles of A operand in global memory
  using IteratorA0 = IteratorA0_;
  using IteratorA = IteratorA0;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB0 = IteratorB0_;
  using IteratorB = IteratorB0;
  ///< Iterates over tiles of the scale and bias vectors in global memory
  using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
  ///< Policy describing tuning details
@ -141,6 +143,10 @@ public:
  ///< Policy describing tuning details
  using Policy1 = Policy1_;

  ///< Export Policy0 as the threadblock-level Mma's policy
  using Policy = Policy0;
  using Shape = Shape0;

  using SmemIteratorB1 = SmemIteratorB1_;
  using WarpIteratorA1 = WarpIteratorA1_;   ///< Iterates over the intermediate accumulator tile in shared memory

@ -194,6 +200,10 @@ public:
  /// Complex transform on B operand
  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  /// Complex transform exports needed by higher-level kernels
  static ComplexTransform const kTransformA = kTransformA0;
  static ComplexTransform const kTransformB = kTransformB0;

  /// Internal structure exposed for introspection.
  struct Detail {
@ -126,7 +126,9 @@ public:

  using Shape0 = Shape0_;             ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA0 = IteratorA0_;     ///< Iterates over tiles of A operand in global memory
  using IteratorA = IteratorA0;
  using IteratorB0 = IteratorB0_;     ///< Iterates over tiles of B operand in global memory
  using IteratorB = IteratorB0;
  using Policy0 = Policy0_;           ///< Policy describing tuning details

  using SmemIteratorA0 = SmemIteratorA0_;
@ -139,6 +141,8 @@ public:
    FragmentIteratorA1ScaleBias_;     ///< WarpIterator to load Scale or Bias vector from the threadblock fragment
  using IteratorB1 = IteratorB1_;     ///< Iterates over tiles of B operand in global memory
  using Policy1 = Policy1_;           ///< Policy describing tuning details
  using Policy = Policy1;             ///< Export Policy1 as the threadblock-level Mma's policy
  using Shape = Shape1;

  using SmemIteratorB1 = SmemIteratorB1_;

@ -195,6 +199,10 @@ public:
  /// Complex transform on B1 operand
  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  /// Complex transform exports needed by higher-level kernels
  static ComplexTransform const kTransformA = kTransformA0;
  static ComplexTransform const kTransformB = kTransformB0;

  /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline)
  static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
@ -128,7 +128,9 @@ public:

  using Shape0 = Shape0_;             ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA0 = IteratorA0_;     ///< Iterates over tiles of A operand in global memory
  using IteratorA = IteratorA0;
  using IteratorB0 = IteratorB0_;     ///< Iterates over tiles of B operand in global memory
  using IteratorB = IteratorB0;
  using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;   ///< Iterates over tiles of the scale and bias vectors in global memory
  using Policy0 = Policy0_;           ///< Policy0 describing tuning details

@ -141,6 +143,8 @@ public:
  using Shape1 = Shape1_;             ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorB1 = IteratorB1_;     ///< Iterates over tiles of B operand in global memory
  using Policy1 = Policy1_;           ///< Policy1 describing tuning details
  using Policy = Policy1;             ///< Export Policy1 as the threadblock-level Mma's policy
  using Shape = Shape1;

  using SmemIteratorB1 = SmemIteratorB1_;
  using WarpIteratorA1 = WarpIteratorA1_;   ///< Iterates over the intermediate accumulator tile in shared memory
@ -192,6 +196,10 @@ public:
  /// Complex transform on B1 operand
  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  /// Complex transform exports needed by higher-level kernels
  static ComplexTransform const kTransformA = kTransformA0;
  static ComplexTransform const kTransformB = kTransformB0;

  /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline)
  static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
@ -0,0 +1,153 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements several threadblock-swizzling functions for grouped kernels
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "kernel/b2b_gemm_grouped_problem_visitor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

struct GroupedThreadblockSwizzleBase {};

/// Helper for determining if a swizzling function is specialized for grouped operation
template <typename ThreadblockSwizzle>
struct IsGroupedSwizzle {
  static bool const value = cutlass::platform::is_base_of<GroupedThreadblockSwizzleBase, ThreadblockSwizzle>::value;
};

} // namespace detail

/// Swizzling function for grouped kernels
template <typename ProblemVisitor_>
struct GroupedThreadblockSwizzle : detail::GroupedThreadblockSwizzleBase {

  using ProblemVisitor = ProblemVisitor_;
  ProblemVisitor problem_visitor;

  CUTLASS_HOST_DEVICE
  GroupedThreadblockSwizzle(typename ProblemVisitor::Params& params,
                            typename ProblemVisitor::SharedStorage& shared_storage,
                            int block_idx) : problem_visitor(params, shared_storage, block_idx) {}

  /// Obtains the threadblock offset (in units of threadblock-scoped tiles)
  CUTLASS_DEVICE
  GemmCoord get_tile_offset(int /*log_tile*/) const {
    GemmCoord problem_size = problem_visitor.problem_size();
    int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
    GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);

    return GemmCoord(int(threadblock_idx / grid_shape.n()),
                     int(threadblock_idx % grid_shape.n()),
                     0);
  }

  /// Dummy method to satisfy API for threadblock swizzling functions
  CUTLASS_HOST_DEVICE
  static int get_log_tile(GemmCoord /*tiled_shape*/) {
    return 0;
  }
};

template <
  typename ThreadblockShape,
  typename LayoutC,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
  int PrefetchTileCount = 128,
  int ThreadCount = PrefetchTileCount>
struct GemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle<
                                          cutlass::gemm::kernel::GemmGroupedProblemVisitor<
                                            ThreadblockShape,
                                            GroupScheduleMode_,
                                            PrefetchTileCount,
                                            ThreadCount,
                                            platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value
                                          >
                                        > {
  using Base = GroupedThreadblockSwizzle<cutlass::gemm::kernel::GemmGroupedProblemVisitor<
                                            ThreadblockShape,
                                            GroupScheduleMode_,
                                            PrefetchTileCount,
                                            ThreadCount,
                                            platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value>>;

  CUTLASS_HOST_DEVICE
  GemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params,
                                typename Base::ProblemVisitor::SharedStorage& shared_storage,
                                int block_idx) : Base(params, shared_storage, block_idx) {}
};

template <
  typename ThreadblockShape,
  typename LayoutC,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
  int PrefetchTileCount = 128,
  int ThreadCount = PrefetchTileCount>
struct B2bGemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle<
                                            cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<
                                              ThreadblockShape,
                                              GroupScheduleMode_,
                                              PrefetchTileCount,
                                              ThreadCount,
                                              platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value
                                            >
                                          > {
  using Base = GroupedThreadblockSwizzle<cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<
                                            ThreadblockShape,
                                            GroupScheduleMode_,
                                            PrefetchTileCount,
                                            ThreadCount,
                                            platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value>>;

  CUTLASS_HOST_DEVICE
  B2bGemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params,
                                   typename Base::ProblemVisitor::SharedStorage& shared_storage,
                                   int block_idx) : Base(params, shared_storage, block_idx) {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
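IsGroupedSwizzle is the single switch the enable_if dispatch in default_b2b_gemm.h keys on, so opting a B2b GEMM into grouped execution amounts to passing one of these swizzles as the ThreadblockSwizzle template argument. A hedged sketch (the tile shape and layout are placeholders, not the configuration used by the example):

// Sketch: a grouped swizzle derives from GroupedThreadblockSwizzleBase, which
// flips DefaultB2bGemm over to the specialization whose B2bGemmKernel is
// kernel::GroupedKernel<...>.
using GroupedSwizzle = cutlass::gemm::threadblock::B2bGemmGroupedThreadblockSwizzle<
    cutlass::gemm::GemmShape<64, 64, 32>,   // ThreadblockShape0 (placeholder)
    cutlass::layout::RowMajor>;             // LayoutC

static_assert(
    cutlass::gemm::threadblock::detail::IsGroupedSwizzle<GroupedSwizzle>::value,
    "grouped swizzles are detected through GroupedThreadblockSwizzleBase");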
@ -290,7 +290,6 @@ public:
    int available_sm_count=-1) {
    // Determine the number of blocks that would be launched to fill up a single
    // wave on the GPU with each SM having maximum occupancy.
    cudaDeviceProp properties;
    int device_idx;
    cudaError_t result = cudaGetDevice(&device_idx);
    if (result != cudaSuccess) {
@ -114,7 +114,7 @@ struct GemmIdentityThreadblockSwizzle {

  /// Calculates optimal swizzle width
  CUTLASS_HOST_DEVICE
  int get_log_tile(GemmCoord tiled_shape) const {
  static int get_log_tile(GemmCoord tiled_shape) {
    auto n = tiled_shape.n();
    // Thresholds picked so that it doesn't cause too many no-op CTAs
    if (N >= 8 && n >= 6)
@ -187,7 +187,7 @@ struct GemmHorizontalThreadblockSwizzle {

  /// Calculates optimal swizzle width
  CUTLASS_HOST_DEVICE
  int get_log_tile(GemmCoord tiled_shape) const {
  static int get_log_tile(GemmCoord tiled_shape) {
    return 0;
  }

@ -228,7 +228,7 @@ struct GemmBatchedIdentityThreadblockSwizzle {

  /// Calculates optimal swizzle width
  CUTLASS_HOST_DEVICE
  int get_log_tile(GemmCoord tiled_shape) const {
  static int get_log_tile(GemmCoord tiled_shape) {
    return 0;
  }

@ -284,7 +284,7 @@ struct GemmSplitKIdentityThreadblockSwizzle {

  /// Calculates optimal swizzle width
  CUTLASS_HOST_DEVICE
  int get_log_tile(GemmCoord tiled_shape) const {
  static int get_log_tile(GemmCoord tiled_shape) {
    auto n = tiled_shape.n();
    // Thresholds picked so that it doesn't cause too many no-op CTAs
    if (N >= 8 && n >= 6)
@ -361,7 +361,7 @@ struct GemmSplitKHorizontalThreadblockSwizzle {

  /// Calculates optimal swizzle width
  CUTLASS_HOST_DEVICE
  int get_log_tile(GemmCoord tiled_shape) const {
  static int get_log_tile(GemmCoord tiled_shape) {
    return 0;
  }

@ -412,7 +412,7 @@ struct GemvBatchedStridedThreadblockDefaultSwizzle {

  /// Calculates optimal swizzle width
  CUTLASS_HOST_DEVICE
  int get_log_tile(GemmCoord tiled_shape) const {
  static int get_log_tile(GemmCoord tiled_shape) {
    return 0;
  }
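Across all of these swizzles, get_log_tile changes from a const member function to a static function, so the swizzle width can be queried from the type alone; the grouped swizzle added earlier in this commit declares its own get_log_tile stub static for the same reason, which is plausibly the motivation for making the existing swizzles match. A tiny usage sketch under that assumption (the default template argument of GemmIdentityThreadblockSwizzle and the tiled shape are illustrative):

// Sketch: no swizzle instance is needed to compute the swizzle width.
cutlass::gemm::GemmCoord tiled_shape(16, 12, 1);   // illustrative grid, in units of tiles
int log_tile =
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>::get_log_tile(tiled_shape);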