Added support for b2b bmm (#849)

* added support for b2b bmm

* fixed the Arguments and Params structures

* added a batch_count argument

* removed SplitKSerial and added a new test case with b2b bmm

* fixed kBatched support and added a new test case with batch strides

* added batch support for bias and scale

* make test

* small changes

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Aleksandr Pivovar 2023-04-15 05:20:02 +02:00 committed by GitHub
parent d572cc1aab
commit 4a68cf748e
10 changed files with 765 additions and 409 deletions


@ -1,11 +1,11 @@
# Introduction
This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.
<p align="center"><img src=/media/images/13_example_fusion.png></p>
When running two unfused GEMM/Conv operations, each operation loads one input
activation matrix and one weight matrix (or filter matrix) from memory and then
stores the result activation matrix back to memory.
When the two GEMM/Conv operations are fused together, the mainloops of the two
@ -27,10 +27,10 @@ In order to run two GEMM/Convs in a single kernel, the example requires the same
threadblocks are used across 2 GEMMs/Convs. This also ensures the same threadblock tile M across
2 GEMMs/Convs.
In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
input activation, the example enforces the following two constraints:
- thread_block_tile_N = problem_N
<p align="center"><img src=/media/images/13_example_block_resident_fusion.png></p>
@ -39,7 +39,7 @@ addition to its own input activation tile. Therefore the input activation tile o
2nd GEMM/Conv only depends on the output activation tile of the 1st GEMM/Conv, and the
operation can be fully block-resident.
- warp_tile_N = thread_block_tile_N
<p align="center"><img src=/media/images/13_example_rf_resident_fusion.png></p>
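As a rough sketch (not part of this change), the two constraints map directly onto the tile shapes used to instantiate the fused kernel; the values below mirror the INT8 SM80 test cases that appear later in this diff:

```cpp
// thread_block_tile_N must equal each GEMM's problem_N,
// and warp_tile_N must equal thread_block_tile_N.
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;   // N = 64  == problem_N of GEMM 0
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;  // N = 128 == problem_N of GEMM 1
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;          // warp N == threadblock N (GEMM 0)
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;         // warp N == threadblock N (GEMM 1)
```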
@ -82,7 +82,7 @@ threadblock. Typically this requires the 2nd Convolution uses 1x1 filter without
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_shmem`
# Copyright


@ -42,6 +42,7 @@
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h"
@ -77,9 +78,9 @@ struct B2bNonFusedGemmRun
//
B2bNonFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
@ -88,7 +89,7 @@ struct B2bNonFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@ -96,7 +97,7 @@ struct B2bNonFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
@ -129,62 +130,62 @@ struct B2bNonFusedGemmRun
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename Gemm0::ElementA,
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::HostTensor<
typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
ElementCompute,
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
ElementCompute,
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
@ -270,13 +271,13 @@ struct B2bNonFusedGemmRun
for(int i = 0; i < runs; i++) {
status = gemm_op_0();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) {
status = gemm_op_1();
CUTLASS_CHECK(status);
}
@ -312,32 +313,32 @@ struct B2bNonFusedGemmRun
reference_gemm_0(
problem_size_0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
reference_D0.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
reference_D1.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
// Wait for kernels to finish
cudaDeviceSynchronize();
reference_D0.sync_host();
@ -349,7 +350,7 @@ struct B2bNonFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@ -362,7 +363,7 @@ struct B2bNonFusedGemmRun
std::ofstream file(fname.str());
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
@ -399,9 +400,9 @@ struct B2bFusedGemmRun
//
B2bFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
@ -412,7 +413,7 @@ struct B2bFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@ -420,11 +421,11 @@ struct B2bFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
@ -453,70 +454,90 @@ struct B2bFusedGemmRun
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
// batch_count is used as split-k when mode is kGemm according
// to the GemmUniversal interface
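// When mode is kBatched, batch_count is instead the number of independent
// GEMM problems, and the batch_stride_* arguments below give the element
// offset between consecutive batches of each operand.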
int batch_count = 1,
int64_t batch_stride_A0 = 0,
int64_t batch_stride_B0 = 0,
int64_t batch_stride_C0 = 0,
int64_t batch_stride_B1 = 0,
int64_t batch_stride_C1 = 0,
int64_t batch_stride_D1 = 0,
int64_t batch_stride_Bias0 = 0,
int64_t batch_stride_Scale0 = 0,
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
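// These coordinates only size the host-side allocations: folding the batch
// dimension into K (for A0/B0/B1) or N (for C/D and the scale/bias vectors)
// lets the .mk()/.kn()/.mn() extents below cover batch_count batches of each
// operand in a single tensor.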
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.n()});
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
cutlass::HostTensor<
ElementCompute,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
typename B2bGemm::ElementC,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
@ -554,6 +575,7 @@ struct B2bFusedGemmRun
//
typename B2bGemm::Arguments arguments{
mode,
problem_size_0,
problem_size_1,
tensor_A0.device_ref(),
@ -564,8 +586,16 @@ struct B2bFusedGemmRun
tensor_B1.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(),
batch_stride_A0,
batch_stride_B0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1,
batch_stride_Bias0,
batch_stride_Scale0,
{alpha0, beta0},
{alpha1, beta1},
batch_count,
};
B2bGemm b2b_gemm_op;
@ -618,32 +648,31 @@ struct B2bFusedGemmRun
// Verify
//
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator>
reference_gemm_0;
cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator
>(
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator, typename B2bGemm::Operator>
reference_gemm_1;
reference_gemm_0(
problem_size_0,
ElementAccumulator(1), //intermediate alpha=1
tensor_A0.device_ref(),
tensor_B0.device_ref(),
tensor_A0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B0.device_ref(),
cutlass::ComplexTransform::kNone,
ElementAccumulator(0), //beta = 0
reference_Z0.device_ref(),
ElementAccumulator(0)
ElementAccumulator(0),
int(batch_count),
batch_stride_A0,
batch_stride_B0,
batch_stride_C0,
batch_stride_C0
);
cutlass::reference::device::TensorScaleBiasGemm<
cutlass::reference::device::TensorScaleBiasGemmBatched<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutScaleBias
> (
@ -652,25 +681,45 @@ struct B2bFusedGemmRun
reference_D0.device_ref(),
alpha0,
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref()
tensor_Bias0.device_ref(),
int(batch_count),
batch_stride_C0,
batch_stride_C0,
batch_stride_Scale0,
batch_stride_Bias0
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, ElementAccumulator
>(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
alpha1,
reference_D0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B1.device_ref(),
cutlass::ComplexTransform::kNone,
beta1,
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
reference_D1.device_ref()
reference_D1.device_ref(),
ElementAccumulator(0),
int(batch_count),
batch_stride_C0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
@ -680,7 +729,7 @@ struct B2bFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@ -694,7 +743,7 @@ struct B2bFusedGemmRun
std::ofstream file(fname.str());
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()


@ -43,6 +43,7 @@
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/host_reorder.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h"
@ -76,9 +77,9 @@ struct B2bInterleavedNonFusedGemmRun
//
B2bInterleavedNonFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
@ -87,7 +88,7 @@ struct B2bInterleavedNonFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@ -95,7 +96,7 @@ struct B2bInterleavedNonFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
@ -128,73 +129,72 @@ struct B2bInterleavedNonFusedGemmRun
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename Gemm0::ElementA,
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::HostTensor<
typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
cutlass::HostTensor<
typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
cutlass::HostTensor<
typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
@ -285,13 +285,13 @@ struct B2bInterleavedNonFusedGemmRun
for(int i = 0; i < runs; i++) {
status = gemm_op_0();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) {
status = gemm_op_1();
CUTLASS_CHECK(status);
}
@ -327,36 +327,36 @@ struct B2bInterleavedNonFusedGemmRun
reference_gemm_0(
problem_size_0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
reference_D0.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
reference_D1.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
// Wait for kernels to finish
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
@ -364,7 +364,7 @@ struct B2bInterleavedNonFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@ -377,7 +377,7 @@ struct B2bInterleavedNonFusedGemmRun
std::ofstream file(fname.str());
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
@ -416,9 +416,9 @@ struct B2bInterleavedFusedGemmRun
//
B2bInterleavedFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
@ -429,7 +429,7 @@ struct B2bInterleavedFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@ -437,11 +437,11 @@ struct B2bInterleavedFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
@ -470,78 +470,99 @@ struct B2bInterleavedFusedGemmRun
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
// batch_count is used as split-k when mode is kGemm according
// to the GemmUniversal interface
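// When mode is kBatched, batch_count is instead the number of independent
// GEMM problems, and the batch_stride_* arguments below give the element
// offset between consecutive batches of each operand.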
int batch_count = 1,
int64_t batch_stride_A0 = 0,
int64_t batch_stride_B0 = 0,
int64_t batch_stride_C0 = 0,
int64_t batch_stride_B1 = 0,
int64_t batch_stride_C1 = 0,
int64_t batch_stride_D1 = 0,
int64_t batch_stride_Bias0 = 0,
int64_t batch_stride_Scale0 = 0,
bool relu = true,
int warm_ups = 1,
int runs = 100) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
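// These coordinates only size the host-side allocations: folding the batch
// dimension into K (for A0/B0/B1) or N (for C/D and the scale/bias vectors)
// lets the .mk()/.kn()/.mn() extents below cover batch_count batches of each
// operand in a single tensor.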
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.n()});
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn());
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
typename B2bGemm::ElementC,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
@ -556,9 +577,9 @@ struct B2bInterleavedFusedGemmRun
//Reorder B0
cutlass::reorder_column<16>(
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0);
cutlass::reorder_column<InterleavedK_>(
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1);
cutlass::reference::host::TensorFill(
tensor_D1.host_view());
@ -581,12 +602,14 @@ struct B2bInterleavedFusedGemmRun
tensor_D1.sync_device();
reference_D0.sync_device();
reference_D1.sync_device();
// tensor_Bias0_batched.sync_device();
//
// Initialize the GEMM operator
//
typename B2bGemm::Arguments arguments{
mode,
problem_size_0,
problem_size_1,
tensor_A0.device_ref(),
@ -597,8 +620,16 @@ struct B2bInterleavedFusedGemmRun
tensor_B1_reordered.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(),
batch_stride_A0,
batch_stride_B0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1,
batch_stride_Bias0,
batch_stride_Scale0,
{alpha0, beta0},
{alpha1, beta1},
batch_count,
};
B2bGemm b2b_gemm_op;
@ -651,32 +682,30 @@ struct B2bInterleavedFusedGemmRun
// Verify
//
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator>
reference_gemm_0;
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator, typename B2bGemm::Operator>
reference_gemm_1;
reference_gemm_0(
cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator
>(
problem_size_0,
ElementAccumulator(1), //intermediate alpha=1
tensor_A0.device_ref(),
tensor_B0.device_ref(),
tensor_A0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B0.device_ref(),
cutlass::ComplexTransform::kNone,
ElementAccumulator(0), //beta = 0
reference_Z0.device_ref(),
ElementAccumulator(0)
ElementAccumulator(0),
int(batch_count),
batch_stride_A0,
batch_stride_B0,
batch_stride_C0,
batch_stride_C0
);
cutlass::reference::device::TensorScaleBiasGemm<
cutlass::reference::device::TensorScaleBiasGemmBatched<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutScaleBias
> (
@ -685,25 +714,45 @@ struct B2bInterleavedFusedGemmRun
reference_D0.device_ref(),
alpha0,
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref()
tensor_Bias0.device_ref(),
int(batch_count),
batch_stride_C0,
batch_stride_C0,
batch_stride_Scale0,
batch_stride_Bias0
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
cutlass::reference::device::GemmComplex<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, ElementAccumulator
>(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
alpha1,
reference_D0.device_ref(),
cutlass::ComplexTransform::kNone,
tensor_B1.device_ref(),
cutlass::ComplexTransform::kNone,
beta1,
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
reference_D1.device_ref()
reference_D1.device_ref(),
ElementAccumulator(0),
int(batch_count),
batch_stride_C0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
@ -713,7 +762,7 @@ struct B2bInterleavedFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@ -727,7 +776,7 @@ struct B2bInterleavedFusedGemmRun
std::ofstream file(fname.str());
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()


@ -119,8 +119,6 @@ template <
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// If true, kernel supports split-K with serial reduction
bool SplitKSerial = false,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
@ -154,7 +152,6 @@ class B2bGemm {
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp1::kCount;
static bool const kSplitKSerial = SplitKSerial;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
@ -184,7 +181,6 @@ class B2bGemm {
EpilogueOutputOp1,
ThreadblockSwizzle,
kStages,
kSplitKSerial,
Operator,
SmemAccumulator
>::B2bGemmKernel;
@ -196,6 +192,7 @@ class B2bGemm {
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size_0;
GemmCoord problem_size_1;
TensorRef<ElementA const, LayoutA> ref_A0;
@ -206,9 +203,16 @@ class B2bGemm {
TensorRef<ElementB const, LayoutB> ref_B1;
TensorRef<ElementC const, LayoutC> ref_C1;
TensorRef<ElementC, LayoutC> ref_D1;
int64_t batch_stride_A0;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C1;
int64_t batch_stride_D1;
int64_t batch_stride_Bias0;
int64_t batch_stride_Scale0;
typename EpilogueOutputOp0::Params epilogue0;
typename EpilogueOutputOp1::Params epilogue1;
int split_k_slices;
int batch_count;
//
// Methods
@ -216,13 +220,14 @@ class B2bGemm {
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) {
Arguments(): mode(GemmUniversalMode::kGemm), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {
}
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmUniversalMode mode_,
GemmCoord problem_size_0_,
GemmCoord problem_size_1_,
TensorRef<ElementA const, LayoutA> ref_A0_,
@ -233,12 +238,20 @@ class B2bGemm {
TensorRef<ElementB const, LayoutB> ref_B1_,
TensorRef<ElementC const, LayoutC> ref_C1_,
TensorRef<ElementC, LayoutC> ref_D1_,
typename EpilogueOutputOp0::Params epilogue0_ =
int64_t batch_stride_A0_,
int64_t batch_stride_B0_,
int64_t batch_stride_B1_,
int64_t batch_stride_C1_,
int64_t batch_stride_D1_,
int64_t batch_stride_Bias0_,
int64_t batch_stride_Scale0_,
typename EpilogueOutputOp0::Params epilogue0_ =
typename EpilogueOutputOp0::Params(),
typename EpilogueOutputOp1::Params epilogue1_ =
typename EpilogueOutputOp1::Params(),
int split_k_slices_ = 1
int batch_count_ = 1
):
mode(mode_),
problem_size_0(problem_size_0_),
problem_size_1(problem_size_1_),
ref_A0(ref_A0_),
@ -249,9 +262,16 @@ class B2bGemm {
ref_B1(ref_B1_),
ref_C1(ref_C1_),
ref_D1(ref_D1_),
batch_stride_A0(batch_stride_A0_),
batch_stride_B0(batch_stride_B0_),
batch_stride_B1(batch_stride_B1_),
batch_stride_C1(batch_stride_C1_),
batch_stride_D1(batch_stride_D1_),
batch_stride_Bias0(batch_stride_Bias0_),
batch_stride_Scale0(batch_stride_Scale0_),
epilogue0(epilogue0_),
epilogue1(epilogue1_),
split_k_slices(split_k_slices_) {
batch_count(batch_count_) {
}
};
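With this change the former split_k_slices argument becomes a general batch_count whose meaning follows mode. A minimal, hedged sketch of the caller-side choice (the helper and variable names are illustrative, not part of this diff, and the CUTLASS headers already included by this file are assumed):

// Sketch only: how batch_count is interpreted by the two supported modes.
// - kGemm:    batch_count acts as the number of serial split-K slices over K.
// - kBatched: batch_count is the number of independent GEMM problems, and the
//             batch_stride_* members above must give per-batch element offsets.
inline GemmUniversalMode select_b2b_mode(bool run_batched) {
  return run_batched ? GemmUniversalMode::kBatched : GemmUniversalMode::kGemm;
}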
@ -269,10 +289,6 @@ public:
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
if (!kSplitKSerial && args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
Status status = B2bGemmKernel::can_implement(
args.problem_size_0,
args.problem_size_1,
@ -295,20 +311,14 @@ public:
static size_t get_workspace_size(Arguments const &args) {
size_t bytes = 0;
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size_0,
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
args.split_k_slices);
if (kSplitKSerial && args.split_k_slices > 1) {
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
}
args.batch_count);
return bytes;
}
@ -320,38 +330,17 @@ public:
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size_0,
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
args.split_k_slices);
args.batch_count);
// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
// args.problem_size_1,
// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
// args.split_k_slices);
if (kSplitKSerial) {
if (args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
size_t bytes = get_workspace_size(args);
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
}
else {
if (args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
}
// args.batch_count);
// Initialize the Params structure
params_ = typename B2bGemmKernel::Params{
args.mode,
args.problem_size_0,
args.problem_size_1,
grid_shape,
@ -363,6 +352,13 @@ public:
args.ref_B1.non_const_ref(),
args.ref_C1.non_const_ref(),
args.ref_D1,
args.batch_stride_A0,
args.batch_stride_B0,
args.batch_stride_B1,
args.batch_stride_C1,
args.batch_stride_D1,
args.batch_stride_Bias0,
args.batch_stride_Scale0,
args.epilogue0,
args.epilogue1,
static_cast<int *>(workspace),
@ -373,12 +369,6 @@ public:
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
if (kSplitKSerial && args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
}
params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
@ -430,12 +420,12 @@ public:
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}


@ -152,7 +152,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
8 * InstructionShape::kN / 32,
@ -161,7 +161,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
@ -194,14 +194,21 @@ bool run_fused_gemm_s8_sm80_rf_res() {
SmemAccumulator,
16,
16,
false,
cutlass::arch::OpMultiplyAddSaturate
>;
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n";
bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
bool passed = fusedGemm.run(
gemm_s8_sm80_problem_size_0,
gemm_s8_sm80_problem_size_1,
alpha0,
beta0,
alpha1,
beta1
);
if(passed)
std::cout << "Pass\n";
else
@ -210,18 +217,123 @@ bool run_fused_gemm_s8_sm80_rf_res() {
return passed;
}
bool run_fused_gemm_s8_sm80_rf_res_batch() {
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(256, 64, 128);
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(256, 128, 64);
using ElementOutput = int8_t;
using ElementAccumulator = int32_t;
using ElementCompute = float;
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
8 * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;
const bool SmemAccumulator = false;
using B2bGemm = cutlass::gemm::device::B2bGemm<
int8_t,
cutlass::layout::ColumnMajorInterleaved<32>,
int8_t,
cutlass::layout::RowMajorInterleaved<32>,
ElementOutput,
cutlass::layout::ColumnMajorInterleaved<32>,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
SmemAccumulator,
16,
16,
cutlass::arch::OpMultiplyAddSaturate
>;
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
int batch_count = 2;
int64_t batch_stride_A0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.k();
int64_t batch_stride_B0 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_C0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.n();
int64_t batch_stride_B1 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_C1 = gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_D1 = gemm_s8_sm80_problem_size_1.m() * gemm_s8_sm80_problem_size_1.n();
int64_t batch_stride_Bias0 = gemm_s8_sm80_problem_size_0.n();
int64_t batch_stride_Scale0 = 0;
std::cout << "Running Fused back-to-back INT8 NT interleaved Batched GEMMs with RF residency...\n";
bool passed = fusedGemm.run(
gemm_s8_sm80_problem_size_0,
gemm_s8_sm80_problem_size_1,
alpha0,
beta0,
alpha1,
beta1,
cutlass::gemm::GemmUniversalMode::kBatched,
batch_count,
batch_stride_A0,
batch_stride_B0,
batch_stride_C0,
batch_stride_B1,
batch_stride_C1,
batch_stride_D1,
batch_stride_Bias0,
batch_stride_Scale0
);
if(passed)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
return passed;
}
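The batch strides in the new test above follow a simple dense-packing convention. Below is a self-contained sketch of the same arithmetic; the struct and helper name are ours (illustrative only), not part of this change:

#include <cstdint>
#include "cutlass/gemm_coord.h"

// Illustrative helper: dense batch strides for the operands of a back-to-back
// GEMM pair, mirroring the values used by run_fused_gemm_s8_sm80_rf_res_batch().
struct B2bBatchStrides {
  int64_t A0, B0, B1, C1, D1, Bias0, Scale0;
};

inline B2bBatchStrides make_dense_b2b_strides(
    cutlass::gemm::GemmCoord const &problem_size_0,    // (M, N0, K0)
    cutlass::gemm::GemmCoord const &problem_size_1) {  // (M, N1, K1), with K1 == N0
  B2bBatchStrides s;
  s.A0 = int64_t(problem_size_0.m()) * problem_size_0.k();   // M  x K0 elements per batch
  s.B0 = int64_t(problem_size_0.k()) * problem_size_0.n();   // K0 x N0
  s.B1 = int64_t(problem_size_1.k()) * problem_size_1.n();   // K1 x N1
  s.C1 = problem_size_1.n();                                 // C1 holds a per-batch bias vector
  s.D1 = int64_t(problem_size_1.m()) * problem_size_1.n();   // M  x N1
  s.Bias0 = problem_size_0.n();                              // per-batch bias vector of GEMM 0
  s.Scale0 = 0;                                              // scale vector reused across batches
  return s;
}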
int main() {
std::vector<bool (*)()>funcs = {
&run_nonfused_gemm_s8_sm80,
&run_fused_gemm_s8_sm80_rf_res
&run_fused_gemm_s8_sm80_rf_res,
&run_fused_gemm_s8_sm80_rf_res_batch
};
return testRun(80, funcs, "gemm int8 RF residency");
}
////////////////////////////////////////////////////////////////////////////////


@ -151,7 +151,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
8 * InstructionShape::kN / 32,
@ -160,7 +160,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
@ -168,7 +168,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;
const bool SmemAccumulator = true;
using B2bGemm = cutlass::gemm::device::B2bGemm<
@ -193,7 +193,6 @@ bool run_fused_gemm_s8_sm80_shmem() {
SmemAccumulator,
16,
16,
false,
cutlass::arch::OpMultiplyAddSaturate
>;


@ -49,10 +49,9 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct B2bGemm {
@ -61,7 +60,17 @@ struct B2bGemm {
using OutputOp0 = typename B2bMma::OutputOp;
using OutputOp1 = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
using ElementA0 = typename B2bMma::IteratorA0::Element;
using LayoutA0 = typename B2bMma::IteratorA0::Layout;
using ElementB0 = typename B2bMma::IteratorB0::Element;
using LayoutB0 = typename B2bMma::IteratorB0::Layout;
using ElementB1 = typename B2bMma::IteratorB1::Element;
using LayoutB1 = typename B2bMma::IteratorB1::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
/// Warp count (concept: GemmShape)
using WarpCount0 = typename B2bMma::WarpCount0;
@ -69,6 +78,7 @@ struct B2bGemm {
/// Parameters structure
struct Params {
cutlass::gemm::GemmUniversalMode mode;
cutlass::gemm::GemmCoord problem_size_0;
cutlass::gemm::GemmCoord problem_size_1;
cutlass::gemm::GemmCoord grid_tiled_shape;
@ -89,6 +99,13 @@ struct B2bGemm {
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
typename OutputOp0::Params output_op_0;
typename OutputOp1::Params output_op_1;
int64_t batch_stride_A0;
int64_t batch_stride_B0;
int64_t batch_stride_B1;
int64_t batch_stride_C1;
int64_t batch_stride_D1;
int64_t batch_stride_Bias0;
int64_t batch_stride_Scale0;
int *semaphore;
int gemm_k_iterations_0;
int gemm_k_size_0;
@ -100,11 +117,12 @@ struct B2bGemm {
//
CUTLASS_HOST_DEVICE
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
Params(): mode(GemmUniversalMode::kGemm), swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmUniversalMode mode,
cutlass::gemm::GemmCoord const & problem_size_0,
cutlass::gemm::GemmCoord const & problem_size_1,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
@ -116,10 +134,18 @@ struct B2bGemm {
typename B2bMma::IteratorB1::TensorRef ref_B1,
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
int64_t batch_stride_A0,
int64_t batch_stride_B0,
int64_t batch_stride_B1,
int64_t batch_stride_C1,
int64_t batch_stride_D1,
int64_t batch_stride_Bias0,
int64_t batch_stride_Scale0,
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
int *workspace = nullptr
):
mode(mode),
problem_size_0(problem_size_0),
problem_size_1(problem_size_1),
grid_tiled_shape(grid_tiled_shape),
@ -138,6 +164,13 @@ struct B2bGemm {
ref_C1(ref_C1),
params_D1(ref_D1.layout()),
ref_D1(ref_D1),
batch_stride_A0(batch_stride_A0),
batch_stride_B0(batch_stride_B0),
batch_stride_B1(batch_stride_B1),
batch_stride_C1(batch_stride_C1),
batch_stride_D1(batch_stride_D1),
batch_stride_Bias0(batch_stride_Bias0),
batch_stride_Scale0(batch_stride_Scale0),
output_op_0(output_op_0),
output_op_1(output_op_1) {
@ -163,7 +196,7 @@ struct B2bGemm {
//
CUTLASS_HOST_DEVICE
B2bGemm() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
@ -223,7 +256,7 @@ struct B2bGemm {
if(problem_size_0.n() > B2bMma::Shape0::kN)
return Status::kErrorInvalidProblem;
if(problem_size_1.n() > B2bMma::Shape1::kN)
return Status::kErrorInvalidProblem;
@ -247,37 +280,64 @@ struct B2bGemm {
return;
}
ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
int offset_k_0 = 0;
int offset_k_1 = 0;
int problem_size_k_0 = params.problem_size_0.k();
int problem_size_k_1 = params.problem_size_1.k();
if (params.mode == GemmUniversalMode::kGemm) {
// Problem size is a function of threadblock index in the K dimension
problem_size_k_0 = min(
problem_size_k_0,
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
// Problem size is a function of threadblock index in the K dimension
problem_size_k_1 = min(
problem_size_k_1,
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
}
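// In kBatched mode the threadblock swizzle carries the batch index in
// threadblock_tile_offset.k() (the grid Z dimension), so advancing the operand
// pointers by batch_stride_* above selects this threadblock's batch; in kGemm
// mode the same coordinate indexes a split-K slice and the base pointers are
// left unchanged.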
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A0{
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
threadblock_tile_offset.k() * params.gemm_k_size_0,
offset_k_0,
};
cutlass::MatrixCoord tb_offset_B0{
threadblock_tile_offset.k() * params.gemm_k_size_0,
offset_k_0,
threadblock_tile_offset.n() * B2bMma::Shape0::kN
};
cutlass::MatrixCoord tb_offset_B1{
threadblock_tile_offset.k() * params.gemm_k_size_1,
offset_k_1,
threadblock_tile_offset.n() * B2bMma::Shape1::kN
};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k_0 = min(
params.problem_size_0.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
// Problem size is a function of threadblock index in the K dimension
int problem_size_k_1 = min(
params.problem_size_1.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
// Compute threadblock-scoped matrix multiply-add
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
// Compute position within threadblock
@ -286,26 +346,25 @@ struct B2bGemm {
// Construct iterators to A and B operands
typename B2bMma::IteratorA0 iterator_A0(
params.params_A0,
params.ref_A0.data(),
ptr_A0,
{params.problem_size_0.m(), problem_size_k_0},
thread_idx,
tb_offset_A0);
typename B2bMma::IteratorB0 iterator_B0(
params.params_B0,
params.ref_B0.data(),
ptr_B0,
{problem_size_k_0, params.problem_size_0.n()},
thread_idx,
tb_offset_B0);
typename B2bMma::IteratorB1 iterator_B1(
params.params_B1,
params.ref_B1.data(),
ptr_B1,
{problem_size_k_1, params.problem_size_1.n()},
thread_idx,
tb_offset_B1);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
@ -313,7 +372,7 @@ struct B2bGemm {
// Construct iterators to accumulator scale/bias vector
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
params.ref_Scale0.data(),
ptr_Scale0,
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
@ -323,7 +382,7 @@ struct B2bGemm {
);
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
params.ref_Bias0.data(),
ptr_Bias0,
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
@ -349,11 +408,9 @@ struct B2bGemm {
src_accum.clear();
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
// Compute threadblock-scoped matrix multiply-add
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
}
// Compute threadblock-scoped matrix multiply-add
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
//
// Epilogue
@ -376,23 +433,32 @@ struct B2bGemm {
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
if (params.mode == GemmUniversalMode::kGemm) {
// If performing a reduction via split-K, fetch the initial synchronization
// Indicate which position in a serial reduction the output operator is currently updating
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
if (params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
}
else if (params.mode == GemmUniversalMode::kBatched) {
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1;
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C1(
params.params_C1,
params.ref_C1.data(),
ptr_C1,
params.problem_size_1.mn(),
thread_idx,
threadblock_offset
@ -401,21 +467,21 @@ struct B2bGemm {
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D1(
params.params_D1,
params.ref_D1.data(),
ptr_D1,
params.problem_size_1.mn(),
thread_idx,
threadblock_offset
);
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C1 = iterator_D1;
@ -427,14 +493,14 @@ struct B2bGemm {
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
@ -457,4 +523,3 @@ struct B2bGemm {
} // namespace kernel
} // namespace gemm
} // namespace cutlass


@ -30,10 +30,10 @@
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
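An explanatory aside, not part of the commit: the operand exchange works because a kernel that only writes row-major output can produce a column-major $D$ by computing $D^{\mathsf{T}}$ from swapped, transposed operands,

$$D = A B \quad\Longrightarrow\quad D^{\mathsf{T}} = B^{\mathsf{T}} A^{\mathsf{T}}.$$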
@ -114,8 +114,6 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Stage accumulator in shared memory
@ -161,22 +159,19 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator> {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -188,7 +183,7 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
@ -228,8 +223,6 @@ template <
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator
>
@ -249,7 +242,6 @@ struct DefaultB2bGemm<
EpilogueOutputOp1,
ThreadblockSwizzle,
2,
SplitKSerial,
Operator
> {
@ -274,7 +266,7 @@ struct DefaultB2bGemm<
Operator,
EpilogueOutputOp0
>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
@ -287,7 +279,7 @@ struct DefaultB2bGemm<
>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
@ -323,20 +315,16 @@ template <
int Stages,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages,
SplitKSerial, Operator> {
ThreadblockSwizzle, Stages, Operator> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -360,7 +348,7 @@ struct DefaultB2bGemm<
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
@ -396,19 +384,16 @@ template <
typename ThreadblockSwizzle,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
kAlignmentA, ElementB,
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, 2, SplitKSerial, Operator> {
ThreadblockSwizzle, 2, Operator> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -418,7 +403,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -430,7 +415,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
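(Explanatory addition, not part of the commit.) The net structural effect across all of these partial specializations: the `SplitKSerial` template parameter disappears and `kernel::B2bGemm` is now composed from just the mainloop, the epilogue, and the swizzle; serial reduction versus batching becomes a run-time choice through `GemmUniversalMode`. A minimal sketch of that composition pattern, using placeholder types rather than the real CUTLASS templates:

```cpp
// Hypothetical, simplified stand-ins that only illustrate the composition pattern.
namespace sketch {

namespace kernel {
// The kernel template no longer carries a compile-time SplitKSerial flag.
template <typename B2bMma, typename Epilogue, typename ThreadblockSwizzle>
struct B2bGemm {};
} // namespace kernel

// Each DefaultB2bGemm-style trait wires the three pieces together and exposes
// the resulting kernel type.
template <typename B2bMma, typename Epilogue, typename ThreadblockSwizzle>
struct DefaultB2bGemm {
  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};

} // namespace sketch
```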

View File

@ -30,10 +30,10 @@
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
@ -112,22 +112,19 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
Operator, true> {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -139,10 +136,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
EpilogueOutputOp1::kCount>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Turing Architecture
@ -179,8 +175,6 @@ template <
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator
>
@ -200,7 +194,6 @@ struct DefaultB2bGemm<
EpilogueOutputOp1,
ThreadblockSwizzle,
2,
SplitKSerial,
Operator,
true
> {
@ -228,7 +221,7 @@ struct DefaultB2bGemm<
false,
true
>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
@ -241,7 +234,7 @@ struct DefaultB2bGemm<
>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
@ -277,20 +270,17 @@ template <
int Stages,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages,
SplitKSerial, Operator, true> {
Operator, true> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -314,7 +304,7 @@ struct DefaultB2bGemm<
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
@ -350,19 +340,16 @@ template <
typename ThreadblockSwizzle,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
kAlignmentA, ElementB,
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, 2, SplitKSerial, Operator, true> {
ThreadblockSwizzle, 2, Operator, true> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@ -371,9 +358,9 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@ -385,7 +372,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////

View File

@ -69,7 +69,7 @@ __global__ void TensorScaleBiasGemm(
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias ///< bias tensor
) {
ConvertOp convert_op;
MatrixCoord output_coord(
@ -89,7 +89,7 @@ __global__ void TensorScaleBiasGemm(
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
bias = tensor_bias.at({0, coord.column()});
tensor_out.at(coord) = convert_op(
@ -99,6 +99,70 @@ __global__ void TensorScaleBiasGemm(
}
}
template <
typename TensorRefIn, ///< Input TensorRef Type
typename TensorRefOut, ///< Output TensorRef Type
typename ScalarType, ///< alpha Type
typename TensorRefScalar, ///< Scale/Bias TensorRef Type
typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
int kMblock = 4,
int kNblock = 4
>
__global__ void TensorScaleBiasGemmBatched(
gemm::GemmCoord problem_size,
TensorRefIn tensor_in, ///< input tensor
TensorRefOut tensor_out, ///< output tensor
ScalarType alpha, ///< alpha
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias, ///< bias tensor
int batch_count = 1,
int64_t batch_stride_tensor_in = 0,
int64_t batch_stride_tensor_out = 0,
int64_t batch_stride_tensor_scale = 0,
int64_t batch_stride_tensor_bias = 0
) {
ConvertOp convert_op;
int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
int batch_idx = blockIdx.z;
tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in);
tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out);
tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale);
tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias);
for (; batch_idx < batch_count; batch_idx += gridDim.z) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < kNblock; j++) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kMblock; i++) {
int row = row_block + i;
int col = col_block + j;
MatrixCoord coord = MatrixCoord(row, col);
if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
ScalarType scale = alpha;
if(tensor_scale.good())
scale = tensor_scale.at({0, coord.column()});
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
bias = tensor_bias.at({0, coord.column()});
tensor_out.at(coord) = convert_op(
scale * ScalarType(tensor_in.at(coord)) + bias);
}
}
}
tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z);
tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z);
tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z);
tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z);
}
}
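A brief note on the batching scheme above (added for clarity, not part of the commit): `blockIdx.z` selects the starting batch and each block then strides by `gridDim.z`, since the host-side launch later in this diff caps `grid.z`. For batch index $b$, element $(i, j)$ of a tensor is addressed at pointer offset

$$\text{offset}(b, i, j) = b \cdot \text{batch\_stride} + \text{layout}(i, j),$$

so a block that starts at batch $z$ processes batches $z, z + \text{gridDim.z}, z + 2\,\text{gridDim.z}, \dots$ until batch_count is exhausted.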
template <
typename TensorRefIn, ///< Input TensorRef Type
typename TensorRefOut, ///< Output TensorRef Type
@ -118,7 +182,7 @@ __global__ void TensorScaleBiasConv2d(
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias ///< bias tensor
) {
ConvertOp convert_op;
int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
@ -137,7 +201,7 @@ __global__ void TensorScaleBiasConv2d(
int64_t npq = npq_start + m;
thread_n[m] = int(npq / PQ);
int64_t residual = npq % PQ;
thread_p[m] = int(residual / problem_size.Q);
thread_q[m] = int(residual % problem_size.Q);
@ -155,17 +219,17 @@ __global__ void TensorScaleBiasConv2d(
ScalarType scale = alpha;
if(tensor_scale.good())
scale = tensor_scale.at({0, thread_k});
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
bias = tensor_bias.at({0, thread_k});
tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
scale * ScalarType(
tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})
) + bias);
}
}
}
}
}
@ -217,6 +281,62 @@ void TensorScaleBiasGemm(
);
}
/// Apply scale and bias on a tensor
template <
typename ElementIn, ///< Input Type
typename ElementOut, ///< Output Type
typename Layout, ///< Layout of input/output tensor
typename ScalarType, ///< alpha Type
typename LayoutScaleBias, ///< Layout of scale and bias
typename ConvertOp = NumericConverter<ElementOut, ScalarType>
>
void TensorScaleBiasGemmBatched(
gemm::GemmCoord problem_size,
TensorRef<ElementIn, Layout> tensor_in, ///< input tensor
TensorRef<ElementOut, Layout> tensor_out, ///< output tensor
ScalarType alpha, ///< alpha
TensorRef<ScalarType, LayoutScaleBias> tensor_scale, ///< scale tensor
TensorRef<ScalarType, LayoutScaleBias> tensor_bias, ///< bias tensor
int batch_count = 1,
int64_t batch_stride_tensor_in = 0,
int64_t batch_stride_tensor_out = 0,
int64_t batch_stride_tensor_scale = 0,
int64_t batch_stride_tensor_bias = 0
) {
int const kMblock = 4;
int const kNblock = 4;
dim3 block(16, 8);
dim3 grid(
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
batch_count % std::numeric_limits<uint16_t>::max()
);
kernel::TensorScaleBiasGemmBatched<
TensorRef<ElementIn, Layout>,
TensorRef<ElementOut, Layout>,
ScalarType,
TensorRef<ScalarType, LayoutScaleBias>,
ConvertOp,
kMblock,
kNblock
><<< grid, block >>> (
problem_size,
tensor_in,
tensor_out,
alpha,
tensor_scale,
tensor_bias,
batch_count,
batch_stride_tensor_in,
batch_stride_tensor_out,
batch_stride_tensor_scale,
batch_stride_tensor_bias
);
}
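(Explanatory addition, not part of the commit.) The signature of `TensorScaleBiasGemmBatched` above is taken from this diff; everything else in the following host-side sketch is an assumption for illustration: the example-local include path, the `cutlass::reference::device` namespace, and the use of `cutlass::HostTensor` for allocation. Batches are laid out back to back, so the activation batch stride is `m * n` and the per-column scale/bias batch stride is `n`.

```cpp
// Hedged usage sketch: namespace, include path, and scaffolding are assumptions.
#include <cstdint>

#include "cutlass/gemm/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "reference/device/tensor_scale_bias.h"   // assumed example-local header

void run_reference_scale_bias_batched(int m, int n, int k, int batch_count) {
  cutlass::gemm::GemmCoord problem(m, n, k);   // only m and n are used here

  // (batch_count * m) x n activations, 1 x (batch_count * n) scale/bias vectors.
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_in(
      cutlass::MatrixCoord(batch_count * m, n));
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_out(
      cutlass::MatrixCoord(batch_count * m, n));
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_scale(
      cutlass::MatrixCoord(1, batch_count * n));
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_bias(
      cutlass::MatrixCoord(1, batch_count * n));

  cutlass::reference::host::TensorFill(tensor_in.host_view(), 1.0f);
  cutlass::reference::host::TensorFill(tensor_scale.host_view(), 2.0f);
  cutlass::reference::host::TensorFill(tensor_bias.host_view(), 0.5f);
  tensor_in.sync_device();
  tensor_scale.sync_device();
  tensor_bias.sync_device();

  // out[b] = scale[b] (per column) * in[b] + bias[b], for each batch b.
  cutlass::reference::device::TensorScaleBiasGemmBatched(
      problem,
      tensor_in.device_ref(),
      tensor_out.device_ref(),
      1.0f,                  // alpha, used only when no scale tensor is provided
      tensor_scale.device_ref(),
      tensor_bias.device_ref(),
      batch_count,
      int64_t(m) * n,        // batch_stride_tensor_in
      int64_t(m) * n,        // batch_stride_tensor_out
      int64_t(n),            // batch_stride_tensor_scale
      int64_t(n));           // batch_stride_tensor_bias

  tensor_out.sync_host();
}
```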
/// Apply scale and bias on a tensor
template <
typename ElementIn, ///< Input Type