Added support for b2b BMM (#849)

* Added support for b2b BMM (batched back-to-back GEMM)

* Fixed the Arguments and Params structures

* Added a batch_count argument (see the usage sketch below)

* Removed SplitKSerial and added a new test case for b2b BMM

* Fixed kBatched support and added a new test case with batch strides

* Added batch support for bias and scale

* make test

* Small changes

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Authored by Aleksandr Pivovar on 2023-04-15 05:20:02 +02:00; committed by GitHub
parent d572cc1aab
commit 4a68cf748e
10 changed files with 765 additions and 409 deletions
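
Taken together, these changes move the fused device-level B2bGemm toward the GemmUniversal convention: Arguments now begins with a GemmUniversalMode, carries per-operand batch strides, and ends with a batch_count that is reused as the split-K slice count in kGemm mode. The sketch below, modeled on the new batched test case added in this commit, shows how a kBatched launch would be assembled; the operands between A0 and B1 (B0, C0, and the first GEMM's scale/bias vectors) are not visible in the hunks shown here, so their order is an assumption, as is the initialize()/operator() flow.

// Illustrative sketch only, not the authoritative interface. B2bGemm is assumed
// to be a concrete cutlass::gemm::device::B2bGemm<...> instantiation, and the
// tensors, problem sizes, strides, and epilogue scalars are assumed to be set
// up as in the updated test harness.
B2bGemm::Arguments arguments{
  cutlass::gemm::GemmUniversalMode::kBatched,   // new leading mode field
  problem_size_0,
  problem_size_1,
  tensor_A0.device_ref(),
  tensor_B0.device_ref(),                       // assumed order of the operands
  tensor_C0.device_ref(),                       //   not shown in the hunks
  tensor_Scale0.device_ref(),
  tensor_Bias0.device_ref(),
  tensor_B1.device_ref(),
  {tensor_Bias1.device_data(), B2bGemm::LayoutC::Stride(0)},
  tensor_D1.device_ref(),
  batch_stride_A0,                              // new per-operand batch strides
  batch_stride_B0,
  batch_stride_B1,
  batch_stride_C1,
  batch_stride_D1,
  batch_stride_Bias0,
  batch_stride_Scale0,
  {alpha0, beta0},
  {alpha1, beta1},
  batch_count                                   // batches here; split-K slices under kGemm
};

B2bGemm b2b_gemm_op;
cutlass::Status status = b2b_gemm_op.initialize(arguments);
if (status == cutlass::Status::kSuccess) {
  status = b2b_gemm_op();                       // launch the fused, batched back-to-back GEMM
}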

View File

@ -42,6 +42,7 @@
#include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h" #include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_relu.h" #include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h" #include "reference/device/tensor_scale_bias.h"
@ -459,6 +460,20 @@ struct B2bFusedGemmRun
ElementCompute beta0 = ElementCompute(0), ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1), ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0), ElementCompute beta1 = ElementCompute(0),
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
// batch_count is used as split-k when mode is kGemm according
// to the GemmUniversal interface
int batch_count = 1,
int64_t batch_stride_A0 = 0,
int64_t batch_stride_B0 = 0,
int64_t batch_stride_C0 = 0,
int64_t batch_stride_B1 = 0,
int64_t batch_stride_C1 = 0,
int64_t batch_stride_D1 = 0,
int64_t batch_stride_Bias0 = 0,
int64_t batch_stride_Scale0 = 0,
bool relu = true, bool relu = true,
int warm_ups = 1, int warm_ups = 1,
int runs = 100) { int runs = 100) {
@ -467,56 +482,62 @@ struct B2bFusedGemmRun
// Allocate the GEMM workspace // Allocate the GEMM workspace
// //
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementA, typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk()); typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementB, typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn()); typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementScaleBias, typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0; typename B2bGemm::LayoutScaleBias> tensor_Scale0;
if(alpha0 == ElementCompute(0)) //per-channel scale if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.n()}); tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementScaleBias, typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()}); typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
cutlass::HostTensor< cutlass::HostTensor<
ElementAccumulator, ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn()); typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementB, typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn()); typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
cutlass::HostTensor<
ElementCompute,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn()); typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
@ -554,6 +575,7 @@ struct B2bFusedGemmRun
// //
typename B2bGemm::Arguments arguments{ typename B2bGemm::Arguments arguments{
mode,
problem_size_0, problem_size_0,
problem_size_1, problem_size_1,
tensor_A0.device_ref(), tensor_A0.device_ref(),
@ -564,8 +586,16 @@ struct B2bFusedGemmRun
tensor_B1.device_ref(), tensor_B1.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(), tensor_D1.device_ref(),
batch_stride_A0,
batch_stride_B0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1,
batch_stride_Bias0,
batch_stride_Scale0,
{alpha0, beta0}, {alpha0, beta0},
{alpha1, beta1}, {alpha1, beta1},
batch_count,
}; };
B2bGemm b2b_gemm_op; B2bGemm b2b_gemm_op;
@ -618,32 +648,31 @@ struct B2bFusedGemmRun
// Verify // Verify
// //
cutlass::reference::device::Gemm< cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA, typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB, typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC, ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator> ElementAccumulator, ElementAccumulator
reference_gemm_0; >(
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator, typename B2bGemm::Operator>
reference_gemm_1;
reference_gemm_0(
problem_size_0, problem_size_0,
ElementAccumulator(1), //intermediate alpha=1 ElementAccumulator(1), //intermediate alpha=1
tensor_A0.device_ref(), tensor_A0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B0.device_ref(), tensor_B0.device_ref(),
cutlass::ComplexTransform::kNone,
ElementAccumulator(0), //beta = 0 ElementAccumulator(0), //beta = 0
reference_Z0.device_ref(), reference_Z0.device_ref(),
reference_Z0.device_ref(), reference_Z0.device_ref(),
ElementAccumulator(0) ElementAccumulator(0),
int(batch_count),
batch_stride_A0,
batch_stride_B0,
batch_stride_C0,
batch_stride_C0
); );
cutlass::reference::device::TensorScaleBiasGemm< cutlass::reference::device::TensorScaleBiasGemmBatched<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutScaleBias ElementCompute, typename B2bGemm::LayoutScaleBias
> ( > (
@ -652,25 +681,45 @@ struct B2bFusedGemmRun
reference_D0.device_ref(), reference_D0.device_ref(),
alpha0, alpha0,
tensor_Scale0.device_ref(), tensor_Scale0.device_ref(),
tensor_Bias0.device_ref() tensor_Bias0.device_ref(),
int(batch_count),
batch_stride_C0,
batch_stride_C0,
batch_stride_Scale0,
batch_stride_Bias0
); );
if(relu) { if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view()); cutlass::reference::device::TensorReLu(reference_D0.device_view());
} }
reference_gemm_1( cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, ElementAccumulator
>(
problem_size_1, problem_size_1,
alpha1, alpha1, //intermediate alpha=1
reference_D0.device_ref(), reference_D0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B1.device_ref(), tensor_B1.device_ref(),
beta1, cutlass::ComplexTransform::kNone,
beta1, //beta = 0
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
reference_D1.device_ref() reference_D1.device_ref(),
ElementAccumulator(0),
int(batch_count),
batch_stride_C0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1
); );
if(relu) { if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view()); cutlass::reference::device::TensorReLu(reference_D1.device_view());
} }
cudaDeviceSynchronize(); cudaDeviceSynchronize();
reference_D0.sync_host(); reference_D0.sync_host();
reference_D1.sync_host(); reference_D1.sync_host();
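
One packing detail worth making explicit: rather than introducing a rank-3 tensor type, the updated harness keeps the batch_count slices of each operand contiguous inside a single 2-D HostTensor (the K or N extent is scaled by batch_count above), and a slice is reached by adding batch_index * batch_stride to the base pointer. A minimal standalone sketch of that addressing, using illustrative row-major sizes from the batched test added later in this commit (the interleaved INT8 layouts change the intra-slice indexing, not the batch-stride arithmetic):

#include <cstdint>

// Packed-batch addressing sketch (assumption: plain row-major slices).
int64_t packed_index(int batch_idx, int64_t batch_stride,
                     int row, int col, int cols) {
  return int64_t(batch_idx) * batch_stride   // jump to this batch's slice
       + int64_t(row) * cols + col;          // row-major offset inside the slice
}

int main() {
  int m1 = 256, n1 = 128;                      // problem_size_1.mn() in the batched test
  int64_t batch_stride_D1 = int64_t(m1) * n1;  // one full output slice per batch

  // Element (3, 5) of the second D1 slice:
  int64_t idx = packed_index(1, batch_stride_D1, 3, 5, n1);
  return idx == batch_stride_D1 + 3 * n1 + 5 ? 0 : 1;
}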

View File

@ -43,6 +43,7 @@
#include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/host_reorder.h" #include "cutlass/util/host_reorder.h"
#include "cutlass/util/reference/device/gemm.h" #include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_relu.h" #include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h" #include "reference/device/tensor_scale_bias.h"
@ -194,7 +195,6 @@ struct B2bInterleavedNonFusedGemmRun
typename Gemm1::ElementC, typename Gemm1::ElementC,
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
@ -476,6 +476,21 @@ struct B2bInterleavedFusedGemmRun
ElementCompute beta0 = ElementCompute(0), ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1), ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0), ElementCompute beta1 = ElementCompute(0),
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
// batch_count is used as split-k when mode is kGemm according
// to the GemmUniversal interface
int batch_count = 1,
int64_t batch_stride_A0 = 0,
int64_t batch_stride_B0 = 0,
int64_t batch_stride_C0 = 0,
int64_t batch_stride_B1 = 0,
int64_t batch_stride_C1 = 0,
int64_t batch_stride_D1 = 0,
int64_t batch_stride_Bias0 = 0,
int64_t batch_stride_Scale0 = 0,
bool relu = true, bool relu = true,
int warm_ups = 1, int warm_ups = 1,
int runs = 100) { int runs = 100) {
@ -484,64 +499,70 @@ struct B2bInterleavedFusedGemmRun
// Allocate the GEMM workspace // Allocate the GEMM workspace
// //
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementA, typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk()); typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementB, typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn()); typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementB, typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn()); typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementScaleBias, typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0; typename B2bGemm::LayoutScaleBias> tensor_Scale0;
if(alpha0 == ElementCompute(0)) //per-channel scale if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.n()}); tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementScaleBias, typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()}); typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
cutlass::HostTensor< cutlass::HostTensor<
ElementAccumulator, ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn()); typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementB, typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn()); typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementB, typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn()); typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()}); typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn()); typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
@ -556,9 +577,9 @@ struct B2bInterleavedFusedGemmRun
//Reorder B0 //Reorder B0
cutlass::reorder_column<16>( cutlass::reorder_column<16>(
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0); tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0);
cutlass::reorder_column<InterleavedK_>( cutlass::reorder_column<InterleavedK_>(
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1); tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1);
cutlass::reference::host::TensorFill( cutlass::reference::host::TensorFill(
tensor_D1.host_view()); tensor_D1.host_view());
@ -581,12 +602,14 @@ struct B2bInterleavedFusedGemmRun
tensor_D1.sync_device(); tensor_D1.sync_device();
reference_D0.sync_device(); reference_D0.sync_device();
reference_D1.sync_device(); reference_D1.sync_device();
// tensor_Bias0_batched.sync_device();
// //
// Initialize the GEMM operator // Initialize the GEMM operator
// //
typename B2bGemm::Arguments arguments{ typename B2bGemm::Arguments arguments{
mode,
problem_size_0, problem_size_0,
problem_size_1, problem_size_1,
tensor_A0.device_ref(), tensor_A0.device_ref(),
@ -597,8 +620,16 @@ struct B2bInterleavedFusedGemmRun
tensor_B1_reordered.device_ref(), tensor_B1_reordered.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(), tensor_D1.device_ref(),
batch_stride_A0,
batch_stride_B0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1,
batch_stride_Bias0,
batch_stride_Scale0,
{alpha0, beta0}, {alpha0, beta0},
{alpha1, beta1}, {alpha1, beta1},
batch_count,
}; };
B2bGemm b2b_gemm_op; B2bGemm b2b_gemm_op;
@ -651,32 +682,30 @@ struct B2bInterleavedFusedGemmRun
// Verify // Verify
// //
cutlass::reference::device::Gemm< cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA, typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB, typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC, ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator> ElementAccumulator, ElementAccumulator
reference_gemm_0; >(
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator, typename B2bGemm::Operator>
reference_gemm_1;
reference_gemm_0(
problem_size_0, problem_size_0,
ElementAccumulator(1), //intermediate alpha=1 ElementAccumulator(1), //intermediate alpha=1
tensor_A0.device_ref(), tensor_A0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B0.device_ref(), tensor_B0.device_ref(),
cutlass::ComplexTransform::kNone,
ElementAccumulator(0), //beta = 0 ElementAccumulator(0), //beta = 0
reference_Z0.device_ref(), reference_Z0.device_ref(),
reference_Z0.device_ref(), reference_Z0.device_ref(),
ElementAccumulator(0) ElementAccumulator(0),
int(batch_count),
batch_stride_A0,
batch_stride_B0,
batch_stride_C0,
batch_stride_C0
); );
cutlass::reference::device::TensorScaleBiasGemm< cutlass::reference::device::TensorScaleBiasGemmBatched<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutScaleBias ElementCompute, typename B2bGemm::LayoutScaleBias
> ( > (
@ -685,25 +714,45 @@ struct B2bInterleavedFusedGemmRun
reference_D0.device_ref(), reference_D0.device_ref(),
alpha0, alpha0,
tensor_Scale0.device_ref(), tensor_Scale0.device_ref(),
tensor_Bias0.device_ref() tensor_Bias0.device_ref(),
int(batch_count),
batch_stride_C0,
batch_stride_C0,
batch_stride_Scale0,
batch_stride_Bias0
); );
if(relu) { if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view()); cutlass::reference::device::TensorReLu(reference_D0.device_view());
} }
reference_gemm_1( cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, ElementAccumulator
>(
problem_size_1, problem_size_1,
alpha1, alpha1, //intermediate alpha=1
reference_D0.device_ref(), reference_D0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B1.device_ref(), tensor_B1.device_ref(),
beta1, cutlass::ComplexTransform::kNone,
beta1, //beta = 0
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
reference_D1.device_ref() reference_D1.device_ref(),
ElementAccumulator(0),
int(batch_count),
batch_stride_C0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1
); );
if(relu) { if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view()); cutlass::reference::device::TensorReLu(reference_D1.device_view());
} }
cudaDeviceSynchronize(); cudaDeviceSynchronize();
reference_D0.sync_host(); reference_D0.sync_host();
reference_D1.sync_host(); reference_D1.sync_host();

View File

@ -119,8 +119,6 @@ template <
int AlignmentB = int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_, DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB, ElementC_, ElementAccumulator_>::kAlignmentB,
/// If true, kernel supports split-K with serial reduction
bool SplitKSerial = false,
/// Operation performed by GEMM /// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration< typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
@ -154,7 +152,6 @@ class B2bGemm {
static int const kAlignmentA = AlignmentA; static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB; static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp1::kCount; static int const kAlignmentC = EpilogueOutputOp1::kCount;
static bool const kSplitKSerial = SplitKSerial;
static ComplexTransform const kTransformA = ComplexTransform::kNone; static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone; static ComplexTransform const kTransformB = ComplexTransform::kNone;
@ -184,7 +181,6 @@ class B2bGemm {
EpilogueOutputOp1, EpilogueOutputOp1,
ThreadblockSwizzle, ThreadblockSwizzle,
kStages, kStages,
kSplitKSerial,
Operator, Operator,
SmemAccumulator SmemAccumulator
>::B2bGemmKernel; >::B2bGemmKernel;
@ -196,6 +192,7 @@ class B2bGemm {
// Data members // Data members
// //
GemmUniversalMode mode;
GemmCoord problem_size_0; GemmCoord problem_size_0;
GemmCoord problem_size_1; GemmCoord problem_size_1;
TensorRef<ElementA const, LayoutA> ref_A0; TensorRef<ElementA const, LayoutA> ref_A0;
@ -206,9 +203,16 @@ class B2bGemm {
TensorRef<ElementB const, LayoutB> ref_B1; TensorRef<ElementB const, LayoutB> ref_B1;
TensorRef<ElementC const, LayoutC> ref_C1; TensorRef<ElementC const, LayoutC> ref_C1;
TensorRef<ElementC, LayoutC> ref_D1; TensorRef<ElementC, LayoutC> ref_D1;
int64_t batch_stride_A0;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C1;
int64_t batch_stride_D1;
int64_t batch_stride_Bias0;
int64_t batch_stride_Scale0;
typename EpilogueOutputOp0::Params epilogue0; typename EpilogueOutputOp0::Params epilogue0;
typename EpilogueOutputOp1::Params epilogue1; typename EpilogueOutputOp1::Params epilogue1;
int split_k_slices; int batch_count;
// //
// Methods // Methods
@ -216,13 +220,14 @@ class B2bGemm {
/// Default ctor /// Default ctor
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) { Arguments(): mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {
} }
/// Constructs an Arguments structure /// Constructs an Arguments structure
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Arguments( Arguments(
GemmUniversalMode mode_,
GemmCoord problem_size_0_, GemmCoord problem_size_0_,
GemmCoord problem_size_1_, GemmCoord problem_size_1_,
TensorRef<ElementA const, LayoutA> ref_A0_, TensorRef<ElementA const, LayoutA> ref_A0_,
@ -233,12 +238,20 @@ class B2bGemm {
TensorRef<ElementB const, LayoutB> ref_B1_, TensorRef<ElementB const, LayoutB> ref_B1_,
TensorRef<ElementC const, LayoutC> ref_C1_, TensorRef<ElementC const, LayoutC> ref_C1_,
TensorRef<ElementC, LayoutC> ref_D1_, TensorRef<ElementC, LayoutC> ref_D1_,
int64_t batch_stride_A0_,
int64_t batch_stride_B0_,
int64_t batch_stride_B1_,
int64_t batch_stride_C1_,
int64_t batch_stride_D1_,
int64_t batch_stride_Bias0_,
int64_t batch_stride_Scale0_,
typename EpilogueOutputOp0::Params epilogue0_ = typename EpilogueOutputOp0::Params epilogue0_ =
typename EpilogueOutputOp0::Params(), typename EpilogueOutputOp0::Params(),
typename EpilogueOutputOp1::Params epilogue1_ = typename EpilogueOutputOp1::Params epilogue1_ =
typename EpilogueOutputOp1::Params(), typename EpilogueOutputOp1::Params(),
int split_k_slices_ = 1 int batch_count_ = 1
): ):
mode(mode_),
problem_size_0(problem_size_0_), problem_size_0(problem_size_0_),
problem_size_1(problem_size_1_), problem_size_1(problem_size_1_),
ref_A0(ref_A0_), ref_A0(ref_A0_),
@ -249,9 +262,16 @@ class B2bGemm {
ref_B1(ref_B1_), ref_B1(ref_B1_),
ref_C1(ref_C1_), ref_C1(ref_C1_),
ref_D1(ref_D1_), ref_D1(ref_D1_),
batch_stride_A0(batch_stride_A0_),
batch_stride_B0(batch_stride_B0_),
batch_stride_B1(batch_stride_B1_),
batch_stride_C1(batch_stride_C1_),
batch_stride_D1(batch_stride_D1_),
batch_stride_Bias0(batch_stride_Bias0_),
batch_stride_Scale0(batch_stride_Scale0_),
epilogue0(epilogue0_), epilogue0(epilogue0_),
epilogue1(epilogue1_), epilogue1(epilogue1_),
split_k_slices(split_k_slices_) { batch_count(batch_count_) {
} }
}; };
@ -269,10 +289,6 @@ public:
/// Determines whether the GEMM can execute the given problem. /// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) { static Status can_implement(Arguments const &args) {
if (!kSplitKSerial && args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
Status status = B2bGemmKernel::can_implement( Status status = B2bGemmKernel::can_implement(
args.problem_size_0, args.problem_size_0,
args.problem_size_1, args.problem_size_1,
@ -302,13 +318,7 @@ public:
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size_0, args.problem_size_0,
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
args.split_k_slices); args.batch_count);
if (kSplitKSerial && args.split_k_slices > 1) {
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
}
return bytes; return bytes;
} }
@ -322,36 +332,15 @@ public:
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size_0, args.problem_size_0,
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
args.split_k_slices); args.batch_count);
// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape( // cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
// args.problem_size_1, // args.problem_size_1,
// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK}, // {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
// args.split_k_slices); // args.batch_count);
if (kSplitKSerial) {
if (args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
size_t bytes = get_workspace_size(args);
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
}
else {
if (args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
}
// Initialize the Params structure // Initialize the Params structure
params_ = typename B2bGemmKernel::Params{ params_ = typename B2bGemmKernel::Params{
args.mode,
args.problem_size_0, args.problem_size_0,
args.problem_size_1, args.problem_size_1,
grid_shape, grid_shape,
@ -363,6 +352,13 @@ public:
args.ref_B1.non_const_ref(), args.ref_B1.non_const_ref(),
args.ref_C1.non_const_ref(), args.ref_C1.non_const_ref(),
args.ref_D1, args.ref_D1,
args.batch_stride_A0,
args.batch_stride_B0,
args.batch_stride_B1,
args.batch_stride_C1,
args.batch_stride_D1,
args.batch_stride_Bias0,
args.batch_stride_Scale0,
args.epilogue0, args.epilogue0,
args.epilogue1, args.epilogue1,
static_cast<int *>(workspace), static_cast<int *>(workspace),
@ -374,12 +370,6 @@ public:
/// Lightweight update given a subset of arguments /// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) { Status update(Arguments const &args, void *workspace = nullptr) {
if (kSplitKSerial && args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
}
params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
params_.ref_C0.reset(args.ref_C0.non_const_ref().data()); params_.ref_C0.reset(args.ref_C0.non_const_ref().data());

View File

@ -194,14 +194,21 @@ bool run_fused_gemm_s8_sm80_rf_res() {
SmemAccumulator, SmemAccumulator,
16, 16,
16, 16,
false,
cutlass::arch::OpMultiplyAddSaturate cutlass::arch::OpMultiplyAddSaturate
>; >;
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm; B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n"; std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n";
bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); bool passed = fusedGemm.run(
gemm_s8_sm80_problem_size_0,
gemm_s8_sm80_problem_size_1,
alpha0,
beta0,
alpha1,
beta1
);
if(passed) if(passed)
std::cout << "Pass\n"; std::cout << "Pass\n";
else else
@ -210,18 +217,123 @@ bool run_fused_gemm_s8_sm80_rf_res() {
return passed; return passed;
} }
bool run_fused_gemm_s8_sm80_rf_res_batch() {
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(256, 64, 128);
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(256, 128, 64);
using ElementOutput = int8_t;
using ElementAccumulator = int32_t;
using ElementCompute = float;
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
8 * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;
const bool SmemAccumulator = false;
using B2bGemm = cutlass::gemm::device::B2bGemm<
int8_t,
cutlass::layout::ColumnMajorInterleaved<32>,
int8_t,
cutlass::layout::RowMajorInterleaved<32>,
ElementOutput,
cutlass::layout::ColumnMajorInterleaved<32>,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
SmemAccumulator,
16,
16,
cutlass::arch::OpMultiplyAddSaturate
>;
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
int batch_count = 2;
int64_t batch_stride_A0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.k();
int64_t batch_stride_B0 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_C0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.n();
int64_t batch_stride_B1 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_C1 = gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_D1 = gemm_s8_sm80_problem_size_1.m() * gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_Bias0 = gemm_s8_sm80_problem_size_0.n();
int64_t batch_stride_Scale0 = 0;
std::cout << "Running Fused back-to-back INT8 NT interleaved Batched GEMMs with RF residency...\n";
bool passed = fusedGemm.run(
gemm_s8_sm80_problem_size_0,
gemm_s8_sm80_problem_size_1,
alpha0,
beta0,
alpha1,
beta1,
cutlass::gemm::GemmUniversalMode::kBatched,
batch_count,
batch_stride_A0,
batch_stride_B0,
batch_stride_C0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1,
batch_stride_Bias0,
batch_stride_Scale0
);
if(passed)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
return passed;
}
int main() { int main() {
std::vector<bool (*)()>funcs = { std::vector<bool (*)()>funcs = {
&run_nonfused_gemm_s8_sm80, &run_nonfused_gemm_s8_sm80,
&run_fused_gemm_s8_sm80_rf_res &run_fused_gemm_s8_sm80_rf_res,
&run_fused_gemm_s8_sm80_rf_res_batch
}; };
return testRun(80, funcs, "gemm int8 RF residency"); return testRun(80, funcs, "gemm int8 RF residency");
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

View File

@ -193,7 +193,6 @@ bool run_fused_gemm_s8_sm80_shmem() {
SmemAccumulator, SmemAccumulator,
16, 16,
16, 16,
false,
cutlass::arch::OpMultiplyAddSaturate cutlass::arch::OpMultiplyAddSaturate
>; >;

View File

@ -51,8 +51,7 @@ namespace kernel {
template < template <
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function typename ThreadblockSwizzle_ ///! Threadblock swizzling function
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
> >
struct B2bGemm { struct B2bGemm {
@ -61,7 +60,17 @@ struct B2bGemm {
using OutputOp0 = typename B2bMma::OutputOp; using OutputOp0 = typename B2bMma::OutputOp;
using OutputOp1 = typename Epilogue::OutputOp; using OutputOp1 = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_; using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
using ElementA0 = typename B2bMma::IteratorA0::Element;
using LayoutA0 = typename B2bMma::IteratorA0::Layout;
using ElementB0 = typename B2bMma::IteratorB0::Element;
using LayoutB0 = typename B2bMma::IteratorB0::Layout;
using ElementB1 = typename B2bMma::IteratorB1::Element;
using LayoutB1 = typename B2bMma::IteratorB1::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
/// Warp count (concept: GemmShape) /// Warp count (concept: GemmShape)
using WarpCount0 = typename B2bMma::WarpCount0; using WarpCount0 = typename B2bMma::WarpCount0;
@ -69,6 +78,7 @@ struct B2bGemm {
/// Parameters structure /// Parameters structure
struct Params { struct Params {
cutlass::gemm::GemmUniversalMode mode;
cutlass::gemm::GemmCoord problem_size_0; cutlass::gemm::GemmCoord problem_size_0;
cutlass::gemm::GemmCoord problem_size_1; cutlass::gemm::GemmCoord problem_size_1;
cutlass::gemm::GemmCoord grid_tiled_shape; cutlass::gemm::GemmCoord grid_tiled_shape;
@ -89,6 +99,13 @@ struct B2bGemm {
typename Epilogue::OutputTileIterator::TensorRef ref_D1; typename Epilogue::OutputTileIterator::TensorRef ref_D1;
typename OutputOp0::Params output_op_0; typename OutputOp0::Params output_op_0;
typename OutputOp1::Params output_op_1; typename OutputOp1::Params output_op_1;
int64_t batch_stride_A0;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C1;
int64_t batch_stride_D1;
int64_t batch_stride_Bias0;
int64_t batch_stride_Scale0;
int *semaphore; int *semaphore;
int gemm_k_iterations_0; int gemm_k_iterations_0;
int gemm_k_size_0; int gemm_k_size_0;
@ -100,11 +117,12 @@ struct B2bGemm {
// //
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0), Params(): mode(mode), swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
gemm_k_iterations_1(0), gemm_k_size_1(0) { } gemm_k_iterations_1(0), gemm_k_size_1(0) { }
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Params( Params(
cutlass::gemm::GemmUniversalMode mode,
cutlass::gemm::GemmCoord const & problem_size_0, cutlass::gemm::GemmCoord const & problem_size_0,
cutlass::gemm::GemmCoord const & problem_size_1, cutlass::gemm::GemmCoord const & problem_size_1,
cutlass::gemm::GemmCoord const & grid_tiled_shape, cutlass::gemm::GemmCoord const & grid_tiled_shape,
@ -116,10 +134,18 @@ struct B2bGemm {
typename B2bMma::IteratorB1::TensorRef ref_B1, typename B2bMma::IteratorB1::TensorRef ref_B1,
typename Epilogue::OutputTileIterator::TensorRef ref_C1, typename Epilogue::OutputTileIterator::TensorRef ref_C1,
typename Epilogue::OutputTileIterator::TensorRef ref_D1, typename Epilogue::OutputTileIterator::TensorRef ref_D1,
int64_t batch_stride_A0,
int64_t batch_stride_B0,
int64_t batch_stride_B1,
int64_t batch_stride_C1,
int64_t batch_stride_D1,
int64_t batch_stride_Bias0,
int64_t batch_stride_Scale0,
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
int *workspace = nullptr int *workspace = nullptr
): ):
mode(mode),
problem_size_0(problem_size_0), problem_size_0(problem_size_0),
problem_size_1(problem_size_1), problem_size_1(problem_size_1),
grid_tiled_shape(grid_tiled_shape), grid_tiled_shape(grid_tiled_shape),
@ -138,6 +164,13 @@ struct B2bGemm {
ref_C1(ref_C1), ref_C1(ref_C1),
params_D1(ref_D1.layout()), params_D1(ref_D1.layout()),
ref_D1(ref_D1), ref_D1(ref_D1),
batch_stride_A0(batch_stride_A0),
batch_stride_B0(batch_stride_B0),
batch_stride_B1(batch_stride_B1),
batch_stride_C1(batch_stride_C1),
batch_stride_D1(batch_stride_D1),
batch_stride_Bias0(batch_stride_Bias0),
batch_stride_Scale0(batch_stride_Scale0),
output_op_0(output_op_0), output_op_0(output_op_0),
output_op_1(output_op_1) { output_op_1(output_op_1) {
@ -247,37 +280,64 @@ struct B2bGemm {
return; return;
} }
ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
int offset_k_0 = 0;
int offset_k_1 = 0;
int problem_size_k_0 = params.problem_size_0.k();
int problem_size_k_1 = params.problem_size_1.k();
if (params.mode == GemmUniversalMode::kGemm) {
// Problem size is a function of threadblock index in the K dimension
problem_size_k_0 = min(
problem_size_k_0,
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
// Problem size is a function of threadblock index in the K dimension
problem_size_k_1 = min(
problem_size_k_1,
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
}
// Compute initial location in logical coordinates // Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A0{ cutlass::MatrixCoord tb_offset_A0{
threadblock_tile_offset.m() * B2bMma::Shape0::kM, threadblock_tile_offset.m() * B2bMma::Shape0::kM,
threadblock_tile_offset.k() * params.gemm_k_size_0, offset_k_0,
}; };
cutlass::MatrixCoord tb_offset_B0{ cutlass::MatrixCoord tb_offset_B0{
threadblock_tile_offset.k() * params.gemm_k_size_0, offset_k_0,
threadblock_tile_offset.n() * B2bMma::Shape0::kN threadblock_tile_offset.n() * B2bMma::Shape0::kN
}; };
cutlass::MatrixCoord tb_offset_B1{ cutlass::MatrixCoord tb_offset_B1{
threadblock_tile_offset.k() * params.gemm_k_size_1, offset_k_1,
threadblock_tile_offset.n() * B2bMma::Shape1::kN threadblock_tile_offset.n() * B2bMma::Shape1::kN
}; };
// Problem size is a function of threadblock index in the K dimension
int problem_size_k_0 = min(
params.problem_size_0.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
// Compute threadblock-scoped matrix multiply-add // Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
// Problem size is a function of threadblock index in the K dimension
int problem_size_k_1 = min(
params.problem_size_1.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
// Compute threadblock-scoped matrix multiply-add // Compute threadblock-scoped matrix multiply-add
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; // int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
// Compute position within threadblock // Compute position within threadblock
@ -286,26 +346,25 @@ struct B2bGemm {
// Construct iterators to A and B operands // Construct iterators to A and B operands
typename B2bMma::IteratorA0 iterator_A0( typename B2bMma::IteratorA0 iterator_A0(
params.params_A0, params.params_A0,
params.ref_A0.data(), ptr_A0,
{params.problem_size_0.m(), problem_size_k_0}, {params.problem_size_0.m(), problem_size_k_0},
thread_idx, thread_idx,
tb_offset_A0); tb_offset_A0);
typename B2bMma::IteratorB0 iterator_B0( typename B2bMma::IteratorB0 iterator_B0(
params.params_B0, params.params_B0,
params.ref_B0.data(), ptr_B0,
{problem_size_k_0, params.problem_size_0.n()}, {problem_size_k_0, params.problem_size_0.n()},
thread_idx, thread_idx,
tb_offset_B0); tb_offset_B0);
typename B2bMma::IteratorB1 iterator_B1( typename B2bMma::IteratorB1 iterator_B1(
params.params_B1, params.params_B1,
params.ref_B1.data(), ptr_B1,
{problem_size_k_1, params.problem_size_1.n()}, {problem_size_k_1, params.problem_size_1.n()},
thread_idx, thread_idx,
tb_offset_B1); tb_offset_B1);
// Broadcast the warp_id computed by lane 0 to ensure dependent code // Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform. // is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
@ -313,7 +372,7 @@ struct B2bGemm {
// Construct iterators to accumulator scale/bias vector // Construct iterators to accumulator scale/bias vector
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0( typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
params.ref_Scale0.data(), ptr_Scale0,
{1, params.problem_size_0.n()}, {1, params.problem_size_0.n()},
thread_idx, thread_idx,
warp_idx, warp_idx,
@ -323,7 +382,7 @@ struct B2bGemm {
); );
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0( typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
params.ref_Bias0.data(), ptr_Bias0,
{1, params.problem_size_0.n()}, {1, params.problem_size_0.n()},
thread_idx, thread_idx,
warp_idx, warp_idx,
@ -349,11 +408,9 @@ struct B2bGemm {
src_accum.clear(); src_accum.clear();
accumulators.clear(); accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
// Compute threadblock-scoped matrix multiply-add // Compute threadblock-scoped matrix multiply-add
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
}
// //
// Epilogue // Epilogue
@ -376,23 +433,32 @@ struct B2bGemm {
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
// Construct the semaphore. // Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx); Semaphore semaphore(params.semaphore + block_idx, thread_idx);
if (params.mode == GemmUniversalMode::kGemm) {
// If performing a reduction via split-K, fetch the initial synchronization // If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
if (params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block. // Fetch the synchronization lock initially but do not block.
semaphore.fetch(); semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating // Indicate which position in a serial reduction the output operator is currently updating
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
} }
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1;
}
// Tile iterator loading from source tensor. // Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C1( typename Epilogue::OutputTileIterator iterator_C1(
params.params_C1, params.params_C1,
params.ref_C1.data(), ptr_C1,
params.problem_size_1.mn(), params.problem_size_1.mn(),
thread_idx, thread_idx,
threadblock_offset threadblock_offset
@ -401,7 +467,7 @@ struct B2bGemm {
// Tile iterator writing to destination tensor. // Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D1( typename Epilogue::OutputTileIterator iterator_D1(
params.params_D1, params.params_D1,
params.ref_D1.data(), ptr_D1,
params.problem_size_1.mn(), params.problem_size_1.mn(),
thread_idx, thread_idx,
threadblock_offset threadblock_offset
@ -414,7 +480,7 @@ struct B2bGemm {
lane_idx); lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction // Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor. // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) { if (threadblock_tile_offset.k()) {
@ -433,7 +499,7 @@ struct B2bGemm {
// Release the semaphore // Release the semaphore
// //
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
int lock = 0; int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
@ -457,4 +523,3 @@ struct B2bGemm {
} // namespace kernel } // namespace kernel
} // namespace gemm } // namespace gemm
} // namespace cutlass } // namespace cutlass
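
The kernel now branches on params.mode instead of a compile-time kSplitKSerial flag: under kGemm the threadblock's k index still carves out a slice of the K dimension (with the semaphore-guarded serial reduction in the epilogue), while under kBatched the same index selects a batch and the operand pointers are advanced by their batch strides. A standalone sketch of that dispatch, with simplified names and raw pointers standing in for the CUTLASS iterators (the real kernel applies the same logic to A0, B0, B1, Scale0, Bias0 and, in the epilogue, C1 and D1):

#include <algorithm>
#include <cstdint>

enum class Mode { kGemm, kBatched };

// K range assigned to one threadblock.
struct KExtent {
  int offset_k;        // first K index this threadblock consumes
  int problem_size_k;  // exclusive end of its K range
};

template <typename TA, typename TB>
KExtent select_work(Mode mode, int tile_k_idx, int gemm_k_size, int problem_k,
                    TA *&ptr_A, int64_t batch_stride_A,
                    TB *&ptr_B, int64_t batch_stride_B) {
  KExtent e{0, problem_k};

  if (mode == Mode::kGemm) {
    // Split-K: partition the K dimension across the grid's k index.
    e.problem_size_k = std::min(problem_k, (tile_k_idx + 1) * gemm_k_size);
    e.offset_k = tile_k_idx * gemm_k_size;
  } else {
    // Batched: the grid's k index picks the batch; keep the full K range and
    // point each operand at that batch's slice.
    ptr_A += int64_t(tile_k_idx) * batch_stride_A;
    ptr_B += int64_t(tile_k_idx) * batch_stride_B;
  }
  return e;
}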

View File

@ -114,8 +114,6 @@ template <
typename ThreadblockSwizzle, typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop /// Number of stages used in the pipelined mainloop
int Stages, int Stages,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM /// Operation performed by GEMM
typename Operator, typename Operator,
/// Stage accumulator in shared memory /// Stage accumulator in shared memory
@ -161,16 +159,13 @@ template <
typename ThreadblockSwizzle, typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop /// Number of stages used in the pipelined mainloop
int Stages, int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM /// Operation performed by GEMM
typename Operator> typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape0, ThreadblockShape1, arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape, WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial, EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator> { Operator> {
/// Define the threadblock-scoped matrix multiply-accumulate /// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
@ -188,7 +183,7 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
EpilogueOutputOp1::kCount>::Epilogue; EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator. /// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>; using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
}; };
@ -228,8 +223,6 @@ template <
typename EpilogueOutputOp1, typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator /// Threadblock-level swizzling operator
typename ThreadblockSwizzle, typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM /// Operation performed by GEMM
typename Operator typename Operator
> >
@ -249,7 +242,6 @@ struct DefaultB2bGemm<
EpilogueOutputOp1, EpilogueOutputOp1,
ThreadblockSwizzle, ThreadblockSwizzle,
2, 2,
SplitKSerial,
Operator Operator
> { > {
@ -287,7 +279,7 @@ struct DefaultB2bGemm<
>::Epilogue; >::Epilogue;
/// Define the kernel-level GEMM operator. /// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>; using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
}; };
@ -323,9 +315,6 @@ template <
int Stages, int Stages,
/// Number of Interleaved k /// Number of Interleaved k
int InterleavedK, int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM /// Operation performed by GEMM
typename Operator> typename Operator>
struct DefaultB2bGemm< struct DefaultB2bGemm<
@ -335,8 +324,7 @@ struct DefaultB2bGemm<
arch::OpClassTensorOp, arch::Sm80, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages, ThreadblockSwizzle, Stages, Operator> {
SplitKSerial, Operator> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>; using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>; using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>; using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -360,7 +348,7 @@ struct DefaultB2bGemm<
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue; 64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator. /// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>; using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
}; };
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -396,9 +384,6 @@ template <
typename ThreadblockSwizzle, typename ThreadblockSwizzle,
/// Number of Interleaved k /// Number of Interleaved k
int InterleavedK, int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM /// Operation performed by GEMM
typename Operator> typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>, struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
@ -408,7 +393,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75, int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, 2, SplitKSerial, Operator> { ThreadblockSwizzle, 2, Operator> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>; using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>; using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>; using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -430,7 +415,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue; 64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator. /// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>; using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
}; };
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

View File

@ -112,16 +112,13 @@ template <
typename ThreadblockSwizzle, typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop /// Number of stages used in the pipelined mainloop
int Stages, int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM /// Operation performed by GEMM
typename Operator> typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape0, ThreadblockShape1, arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape, WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial, EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator, true> { Operator, true> {
/// Define the threadblock-scoped matrix multiply-accumulate /// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
@@ -139,10 +136,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
EpilogueOutputOp1::kCount>::Epilogue; EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator. /// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>; using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
}; };
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Turing Architecture /// Partial specialization for Turing Architecture
@@ -179,8 +175,6 @@ template <
typename EpilogueOutputOp1, typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator /// Threadblock-level swizzling operator
typename ThreadblockSwizzle, typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM /// Operation performed by GEMM
typename Operator typename Operator
> >
@@ -200,7 +194,6 @@ struct DefaultB2bGemm<
EpilogueOutputOp1, EpilogueOutputOp1,
ThreadblockSwizzle, ThreadblockSwizzle,
2, 2,
SplitKSerial,
Operator, Operator,
true true
> { > {
@@ -241,7 +234,7 @@ struct DefaultB2bGemm<
>::Epilogue; >::Epilogue;
/// Define the kernel-level GEMM operator. /// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>; using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
}; };
@@ -277,9 +270,6 @@ template <
int Stages, int Stages,
/// Number of Interleaved k /// Number of Interleaved k
int InterleavedK, int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM /// Operation performed by GEMM
typename Operator> typename Operator>
struct DefaultB2bGemm< struct DefaultB2bGemm<
@@ -290,7 +280,7 @@ struct DefaultB2bGemm<
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages, ThreadblockSwizzle, Stages,
SplitKSerial, Operator, true> { Operator, true> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>; using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>; using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>; using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@@ -314,7 +304,7 @@ struct DefaultB2bGemm<
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue; 64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator. /// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>; using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
}; };
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@@ -350,9 +340,6 @@ template <
typename ThreadblockSwizzle, typename ThreadblockSwizzle,
/// Number of Interleaved k /// Number of Interleaved k
int InterleavedK, int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM /// Operation performed by GEMM
typename Operator> typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>, struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
@@ -362,7 +349,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75, int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, 2, SplitKSerial, Operator, true> { ThreadblockSwizzle, 2, Operator, true> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>; using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>; using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>; using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@@ -385,7 +372,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue; 64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator. /// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>; using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
}; };
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
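The shared-memory-accumulator specializations in this file follow the same pattern: with SplitKSerial removed, the stage count is followed directly by the Operator type. A hedged caller-side sketch of the resulting argument list is shown below; DefaultB2bGemm, its parameter names, and the nested B2bGemmKernel type come from the hunks above, while the include path, the cutlass::gemm::kernel namespace, and the meaning of the trailing true are assumptions.

// Hedged sketch of a caller-side alias after this change. The template parameters are
// placeholders for whatever concrete types the caller already defines; nothing is
// instantiated until PickB2bKernel<...>::Kernel is actually used.
#include "kernel/default_b2b_gemm.h"   // assumed example-local include path

#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/layout/matrix.h"

template <typename ElementA, typename LayoutA, int kAlignmentA,
          typename ElementB, typename LayoutB, int kAlignmentB,
          typename ElementC, typename ElementAccumulator,
          typename ThreadblockShape0, typename ThreadblockShape1,
          typename WarpShape0, typename WarpShape1, typename InstructionShape,
          typename EpilogueOutputOp0, typename EpilogueOutputOp1,
          typename ThreadblockSwizzle, int Stages, typename Operator>
struct PickB2bKernel {
  // Previously the list read: ..., ThreadblockSwizzle, Stages, SplitKSerial, Operator, true>.
  // Now Stages is followed directly by Operator; the trailing true still selects the
  // shared-memory-accumulator path (an assumption based on the specializations above).
  using Kernel = typename cutlass::gemm::kernel::DefaultB2bGemm<
      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
      ElementC, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape,
      EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
      Operator, true>::B2bGemmKernel;
};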

View File

@@ -99,6 +99,70 @@ __global__ void TensorScaleBiasGemm(
} }
} }
template <
typename TensorRefIn, ///< Input TensorRef Type
typename TensorRefOut, ///< Output TensorRef Type
typename ScalarType, ///< alpha Type
typename TensorRefScalar, ///< Scale/Bias TensorRef Type
typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
int kMblock = 4,
int kNblock = 4
>
__global__ void TensorScaleBiasGemmBatched(
gemm::GemmCoord problem_size,
TensorRefIn tensor_in, ///< input tensor
TensorRefOut tensor_out, ///< output tensor
ScalarType alpha, ///< alpha
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias, ///< bias tensor
int batch_count = 1,
int64_t batch_stride_tensor_in = 0,
int64_t batch_stride_tensor_out = 0,
int64_t batch_stride_tensor_scale = 0,
int64_t batch_stride_tensor_bias = 0
) {
ConvertOp convert_op;
int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
int batch_idx = blockIdx.z;
tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in);
tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out);
tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale);
tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias);
for (; batch_idx < batch_count; batch_idx += gridDim.z) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < kNblock; j++) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kMblock; i++) {
int row = row_block + i;
int col = col_block + j;
MatrixCoord coord = MatrixCoord(row, col);
if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
ScalarType scale = alpha;
if(tensor_scale.good())
scale = tensor_scale.at({0, coord.column()});
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
bias = tensor_bias.at({0, coord.column()});
tensor_out.at(coord) = convert_op(
scale * ScalarType(tensor_in.at(coord)) + bias);
}
}
}
tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z);
tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z);
tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z);
tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z);
}
}
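In the kernel added above, blockIdx.z picks the starting batch, the outer loop strides by gridDim.z, and every TensorRef is advanced by its per-batch stride after each iteration, so batch_count values larger than the launched grid's z extent are still covered. A minimal standalone sketch of the same pattern; batched_axpby and its parameters are illustrative names, not part of this commit.

#include <cstdint>
#include <cuda_runtime.h>

// Each z-slice of blocks starts at batch blockIdx.z and then walks forward by gridDim.z,
// offsetting its pointers by the per-batch stride, mirroring the kernel above.
__global__ void batched_axpby(float const *x, float *y, float alpha,
                              int n, int batch_count,
                              int64_t stride_x, int64_t stride_y) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int step = static_cast<int>(gridDim.z);

  // Offset to this block's starting batch.
  x += int64_t(blockIdx.z) * stride_x;
  y += int64_t(blockIdx.z) * stride_y;

  for (int b = blockIdx.z; b < batch_count; b += step) {
    if (i < n) {
      y[i] = alpha * x[i] + y[i];
    }
    // Advance to the next batch handled by this z-slice.
    x += stride_x * step;
    y += stride_y * step;
  }
}

The loop matters because a launch can use at most 65535 blocks in the z dimension, so the grid alone cannot always cover the batch count.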
template < template <
typename TensorRefIn, ///< Input TensorRef Type typename TensorRefIn, ///< Input TensorRef Type
typename TensorRefOut, ///< Output TensorRef Type typename TensorRefOut, ///< Output TensorRef Type
@@ -217,6 +281,62 @@ void TensorScaleBiasGemm(
); );
} }
/// Apply scale and bias on a tensor
template <
typename ElementIn, ///< Input Type
typename ElementOut, ///< Output Type
typename Layout, ///< Layout of input/output tensor
typename ScalarType, ///< alpha Type
typename LayoutScaleBias, ///< Layout of scale and bias
typename ConvertOp = NumericConverter<ElementOut, ScalarType>
>
void TensorScaleBiasGemmBatched(
gemm::GemmCoord problem_size,
TensorRef<ElementIn, Layout> tensor_in, ///< input tensor
TensorRef<ElementOut, Layout> tensor_out, ///< output tensor
ScalarType alpha, ///< alpha
TensorRef<ScalarType, LayoutScaleBias> tensor_scale, ///< scale tensor
TensorRef<ScalarType, LayoutScaleBias> tensor_bias, ///< bias tensor
int batch_count = 1,
int64_t batch_stride_tensor_in = 0,
int64_t batch_stride_tensor_out = 0,
int64_t batch_stride_tensor_scale = 0,
int64_t batch_stride_tensor_bias = 0
) {
int const kMblock = 4;
int const kNblock = 4;
dim3 block(16, 8);
dim3 grid(
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
batch_count % std::numeric_limits<uint16_t>::max()
);
kernel::TensorScaleBiasGemmBatched<
TensorRef<ElementIn, Layout>,
TensorRef<ElementOut, Layout>,
ScalarType,
TensorRef<ScalarType, LayoutScaleBias>,
ConvertOp,
kMblock,
kNblock
><<< grid, block >>> (
problem_size,
tensor_in,
tensor_out,
alpha,
tensor_scale,
tensor_bias,
batch_count,
batch_stride_tensor_in,
batch_stride_tensor_out,
batch_stride_tensor_scale,
batch_stride_tensor_bias
);
}
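The wrapper above derives grid.x and grid.y from the per-batch problem size and grid.z from batch_count taken modulo the 16-bit maximum (65535), leaving the kernel's batch loop to pick up any remaining batches. Below is a hedged usage sketch; the float element type, the row-major storage with batches packed along the row dimension, and the cutlass::reference::device namespace are assumptions made for illustration, not facts taken from this diff.

#include <cstdint>

#include "cutlass/gemm_coord.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/util/host_tensor.h"
#include "reference/device/tensor_scale_bias.h"   // assumed example-local include path

// Applies out = scale * in + bias per column, independently for each batch.
void reference_scale_bias_batched(int m, int n, int k, int batch_count) {
  cutlass::gemm::GemmCoord problem_size(m, n, k);   // per-batch problem size

  // Activations: one m x n block per batch, packed along rows (batch stride m * n).
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_in(
      cutlass::MatrixCoord(batch_count * m, n));
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_out(
      cutlass::MatrixCoord(batch_count * m, n));

  // Scale/bias: one 1 x n row per batch (batch stride n).
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_scale(
      cutlass::MatrixCoord(batch_count, n));
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_bias(
      cutlass::MatrixCoord(batch_count, n));

  // ... fill tensor_in / tensor_scale / tensor_bias on the host, then sync_device() ...

  cutlass::reference::device::TensorScaleBiasGemmBatched(
      problem_size,
      tensor_in.device_ref(),
      tensor_out.device_ref(),
      1.0f,                            // alpha: used only if the scale TensorRef is empty
      tensor_scale.device_ref(),
      tensor_bias.device_ref(),
      batch_count,
      /*batch_stride_tensor_in=*/    int64_t(m) * n,
      /*batch_stride_tensor_out=*/   int64_t(m) * n,
      /*batch_stride_tensor_scale=*/ int64_t(n),
      /*batch_stride_tensor_bias=*/  int64_t(n));
}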
/// Apply scale and bias on a tensor /// Apply scale and bias on a tensor
template < template <
typename ElementIn, ///< Input Type typename ElementIn, ///< Input Type