diff --git a/examples/13_two_tensor_op_fusion/README.md b/examples/13_two_tensor_op_fusion/README.md
index 008ed94b..fc2655ea 100644
--- a/examples/13_two_tensor_op_fusion/README.md
+++ b/examples/13_two_tensor_op_fusion/README.md
@@ -1,11 +1,11 @@
# Introduction
-This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.
+This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.

-When running two unfused GEMM/Conv operations, each operation loads one input
-activation matrix, one weight matrix (or filter matrix) from the memory and then
+When running two unfused GEMM/Conv operations, each operation loads one input
+activation matrix, one weight matrix (or filter matrix) from the memory and then
stores the result activation matrix back to the memory.
When the two GEMM/Conv operations are fused together, the mainloops of the two
@@ -27,10 +27,10 @@ In order to run two GEMM/Convs in a single kernel, the example requires the same
threadblocks are used across 2 GEMMs/Convs. This also ensures the same threadblock tile M across
2 GEMMs/Convs.
-In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
+In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
input activation, the example enforces the following two constraints:
-- thread_block_tile_N = problem_N
+- thread_block_tile_N = problem_N

@@ -39,7 +39,7 @@ addition to its own input activation tile. Therefore the input activation tile o
2nd GEMM/Conv only depends on the output activation tile of the 1st GEMM/Conv, and the
operation can be fully block-resident.
-- warp_tile_N = thread_block_tile_N
+- warp_tile_N = thread_block_tile_N

@@ -82,7 +82,7 @@ threadblock. Typically this requires the 2nd Convolution uses 1x1 filter without
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_shmem`
-
+
# Copyright
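
The two constraints restated in the README hunks above (thread_block_tile_N = problem_N and warp_tile_N = thread_block_tile_N) are what keep the 1st GEMM's accumulator tile resident inside one threadblock. A minimal sketch, assuming hypothetical `ThreadblockShape0`/`WarpShape0` tile types, of how a harness could assert them before instantiating a fused kernel (the helper is illustrative, not part of the example):

```cpp
#include <cassert>
#include "cutlass/gemm/gemm.h"

// Minimal sketch (not part of the example): check the two RF-residency constraints
// described in the README for a candidate tile configuration of the 1st GEMM.
template <typename ThreadblockShape0, typename WarpShape0>
void check_rf_resident_fusion(cutlass::gemm::GemmCoord const &problem_size_0) {
  // warp_tile_N must equal thread_block_tile_N, so every warp holds the whole
  // N extent of the intermediate accumulator and no cross-warp exchange is needed.
  static_assert(WarpShape0::kN == ThreadblockShape0::kN,
                "warp_tile_N must equal thread_block_tile_N");

  // thread_block_tile_N must equal problem_N of the 1st GEMM, so the 2nd GEMM's
  // input activation tile depends only on this threadblock's own output tile.
  assert(problem_size_0.n() == ThreadblockShape0::kN);
}

// Example use with the SM80 int8 shapes from this example:
//   check_rf_resident_fusion<cutlass::gemm::GemmShape<64, 64, 64>,
//                            cutlass::gemm::GemmShape<16, 64, 64>>({256, 64, 128});
```
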
diff --git a/examples/13_two_tensor_op_fusion/b2b_gemm_run.h b/examples/13_two_tensor_op_fusion/b2b_gemm_run.h
index b8b080cf..39ce488d 100644
--- a/examples/13_two_tensor_op_fusion/b2b_gemm_run.h
+++ b/examples/13_two_tensor_op_fusion/b2b_gemm_run.h
@@ -42,6 +42,7 @@
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
+#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h"
@@ -77,9 +78,9 @@ struct B2bNonFusedGemmRun
//
B2bNonFusedGemmRun(
- cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
- cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
- cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
@@ -88,7 +89,7 @@ struct B2bNonFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
-    cutlass::TensorView<Element, Layout> view,
+    cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@@ -96,7 +97,7 @@ struct B2bNonFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
- }
+ }
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
@@ -129,62 +130,62 @@ struct B2bNonFusedGemmRun
/// Executes one test
bool run(
- cutlass::gemm::GemmCoord problem_size_0,
- cutlass::gemm::GemmCoord problem_size_1,
- ElementCompute alpha0 = ElementCompute(1),
+ cutlass::gemm::GemmCoord problem_size_0,
+ cutlass::gemm::GemmCoord problem_size_1,
+ ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
- ElementCompute alpha1 = ElementCompute(1),
+ ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true,
int warm_ups = 1,
int runs = 100) {
-
+
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
- typename Gemm0::ElementA,
+ typename Gemm0::ElementA,
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::HostTensor<
- typename Gemm0::ElementB,
+ typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
cutlass::HostTensor<
- typename Gemm0::ElementC,
+ typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
- ElementCompute,
+ ElementCompute,
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
cutlass::HostTensor<
- typename Gemm0::ElementC,
+ typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
cutlass::HostTensor<
- typename Gemm0::ElementC,
+ typename Gemm0::ElementC,
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
cutlass::HostTensor<
- typename Gemm1::ElementB,
+ typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
cutlass::HostTensor<
- typename Gemm1::ElementC,
+ typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
- ElementCompute,
+ ElementCompute,
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
cutlass::HostTensor<
- typename Gemm1::ElementC,
+ typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
cutlass::HostTensor<
- typename Gemm1::ElementC,
+ typename Gemm1::ElementC,
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
@@ -270,13 +271,13 @@ struct B2bNonFusedGemmRun
for(int i = 0; i < runs; i++) {
status = gemm_op_0();
-
+
CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) {
status = gemm_op_1();
-
+
CUTLASS_CHECK(status);
}
@@ -312,32 +313,32 @@ struct B2bNonFusedGemmRun
reference_gemm_0(
problem_size_0,
- alpha0,
- tensor_A0.device_ref(),
- tensor_B0.device_ref(),
- beta0,
+ alpha0,
+ tensor_A0.device_ref(),
+ tensor_B0.device_ref(),
+ beta0,
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
reference_D0.device_ref()
);
if(relu) {
- cutlass::reference::device::TensorReLu(reference_D0.device_view());
+ cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size_1,
- alpha1,
- reference_D0.device_ref(),
- tensor_B1.device_ref(),
+ alpha1,
+ reference_D0.device_ref(),
+ tensor_B1.device_ref(),
beta1,
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
reference_D1.device_ref()
);
-
+
if(relu) {
- cutlass::reference::device::TensorReLu(reference_D1.device_view());
+ cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
-
+
// Wait for kernels to finish
cudaDeviceSynchronize();
reference_D0.sync_host();
@@ -349,7 +350,7 @@ struct B2bNonFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
- reference_D1.host_view(),
+ reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@@ -362,7 +363,7 @@ struct B2bNonFusedGemmRun
std::ofstream file(fname.str());
- file
+ file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
@@ -399,9 +400,9 @@ struct B2bFusedGemmRun
//
B2bFusedGemmRun(
- cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
- cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
- cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
@@ -412,7 +413,7 @@ struct B2bFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
-    cutlass::TensorView<Element, Layout> view,
+    cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@@ -420,11 +421,11 @@ struct B2bFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
- }
+ }
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
- }
+ }
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
@@ -453,70 +454,90 @@ struct B2bFusedGemmRun
/// Executes one test
bool run(
- cutlass::gemm::GemmCoord problem_size_0,
- cutlass::gemm::GemmCoord problem_size_1,
- ElementCompute alpha0 = ElementCompute(1),
+ cutlass::gemm::GemmCoord problem_size_0,
+ cutlass::gemm::GemmCoord problem_size_1,
+ ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
- ElementCompute alpha1 = ElementCompute(1),
+ ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
+ cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
+
+ // batch_count is used as split-k when mode is kGemm according
+ // to the GemmUniversal interface
+
+ int batch_count = 1,
+ int64_t batch_stride_A0 = 0,
+ int64_t batch_stride_B0 = 0,
+ int64_t batch_stride_C0 = 0,
+ int64_t batch_stride_B1 = 0,
+ int64_t batch_stride_C1 = 0,
+ int64_t batch_stride_D1 = 0,
+ int64_t batch_stride_Bias0 = 0,
+ int64_t batch_stride_Scale0 = 0,
bool relu = true,
int warm_ups = 1,
int runs = 100) {
-
+
//
// Allocate the GEMM workspace
//
- cutlass::HostTensor<
- typename B2bGemm::ElementA,
- typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
+ cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
+ cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
+ cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
+ cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
+ cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
cutlass::HostTensor<
- typename B2bGemm::ElementB,
- typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
+ typename B2bGemm::ElementA,
+ typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
cutlass::HostTensor<
- typename B2bGemm::ElementC,
- typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
+ typename B2bGemm::ElementB,
+ typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
cutlass::HostTensor<
- typename B2bGemm::ElementScaleBias,
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
if(alpha0 == ElementCompute(0)) //per-channel scale
- tensor_Scale0.resize({1, problem_size_0.n()});
+ tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
- typename B2bGemm::ElementScaleBias,
- typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
+ typename B2bGemm::ElementScaleBias,
+ typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
- ElementAccumulator,
- typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
+ ElementAccumulator,
+ typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
cutlass::HostTensor<
- typename B2bGemm::ElementC,
- typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
cutlass::HostTensor<
- typename B2bGemm::ElementB,
- typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
+ typename B2bGemm::ElementB,
+ typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
cutlass::HostTensor<
- typename B2bGemm::ElementC,
- typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
cutlass::HostTensor<
- ElementCompute,
- typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
cutlass::HostTensor<
- typename B2bGemm::ElementC,
- typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
cutlass::HostTensor<
- typename B2bGemm::ElementC,
- typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
@@ -554,6 +575,7 @@ struct B2bFusedGemmRun
//
typename B2bGemm::Arguments arguments{
+ mode,
problem_size_0,
problem_size_1,
tensor_A0.device_ref(),
@@ -564,8 +586,16 @@ struct B2bFusedGemmRun
tensor_B1.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(),
+ batch_stride_A0,
+ batch_stride_B0,
+ batch_stride_B1,
+ batch_stride_C1,
+ batch_stride_D1,
+ batch_stride_Bias0,
+ batch_stride_Scale0,
{alpha0, beta0},
{alpha1, beta1},
+ batch_count,
};
B2bGemm b2b_gemm_op;
@@ -618,32 +648,31 @@ struct B2bFusedGemmRun
// Verify
//
- cutlass::reference::device::Gemm<
- typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
- typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
- ElementAccumulator, typename B2bGemm::LayoutC,
- ElementAccumulator, ElementAccumulator>
- reference_gemm_0;
+ cutlass::reference::device::GemmComplex<
+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
+ ElementAccumulator, typename B2bGemm::LayoutC,
+ ElementAccumulator, ElementAccumulator
+ >(
- cutlass::reference::device::Gemm<
- typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
- typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
- typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
- ElementAccumulator, typename B2bGemm::Operator>
- reference_gemm_1;
-
- reference_gemm_0(
problem_size_0,
ElementAccumulator(1), //intermediate alpha=1
- tensor_A0.device_ref(),
- tensor_B0.device_ref(),
+ tensor_A0.device_ref(),
+ cutlass::ComplexTransform::kNone,
+ tensor_B0.device_ref(),
+ cutlass::ComplexTransform::kNone,
ElementAccumulator(0), //beta = 0
reference_Z0.device_ref(),
reference_Z0.device_ref(),
- ElementAccumulator(0)
+ ElementAccumulator(0),
+ int(batch_count),
+ batch_stride_A0,
+ batch_stride_B0,
+ batch_stride_C0,
+ batch_stride_C0
);
- cutlass::reference::device::TensorScaleBiasGemm<
+ cutlass::reference::device::TensorScaleBiasGemmBatched<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutScaleBias
> (
@@ -652,25 +681,45 @@ struct B2bFusedGemmRun
reference_D0.device_ref(),
alpha0,
tensor_Scale0.device_ref(),
- tensor_Bias0.device_ref()
+ tensor_Bias0.device_ref(),
+ int(batch_count),
+ batch_stride_C0,
+ batch_stride_C0,
+ batch_stride_Scale0,
+ batch_stride_Bias0
);
if(relu) {
- cutlass::reference::device::TensorReLu(reference_D0.device_view());
+ cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
- reference_gemm_1(
+ cutlass::reference::device::GemmComplex<
+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
+ typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
+ ElementCompute, ElementAccumulator
+ >(
problem_size_1,
- alpha1,
- reference_D0.device_ref(),
- tensor_B1.device_ref(),
- beta1,
+      alpha1,
+ reference_D0.device_ref(),
+ cutlass::ComplexTransform::kNone,
+ tensor_B1.device_ref(),
+ cutlass::ComplexTransform::kNone,
+      beta1,
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
- reference_D1.device_ref()
+ reference_D1.device_ref(),
+ ElementAccumulator(0),
+ int(batch_count),
+ batch_stride_C0,
+ batch_stride_B1,
+ batch_stride_C1,
+ batch_stride_D1
);
+
if(relu) {
- cutlass::reference::device::TensorReLu(reference_D1.device_view());
+ cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
+
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
@@ -680,7 +729,7 @@ struct B2bFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
- reference_D1.host_view(),
+ reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@@ -694,7 +743,7 @@ struct B2bFusedGemmRun
std::ofstream file(fname.str());
- file
+ file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
diff --git a/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h b/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h
index 51ff1bb7..c649fc36 100644
--- a/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h
+++ b/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h
@@ -43,6 +43,7 @@
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/host_reorder.h"
#include "cutlass/util/reference/device/gemm.h"
+#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h"
@@ -76,9 +77,9 @@ struct B2bInterleavedNonFusedGemmRun
//
B2bInterleavedNonFusedGemmRun(
- cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
- cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
- cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
@@ -87,7 +88,7 @@ struct B2bInterleavedNonFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
-    cutlass::TensorView<Element, Layout> view,
+    cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@@ -95,7 +96,7 @@ struct B2bInterleavedNonFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
- }
+ }
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
@@ -128,73 +129,72 @@ struct B2bInterleavedNonFusedGemmRun
/// Executes one test
bool run(
- cutlass::gemm::GemmCoord problem_size_0,
- cutlass::gemm::GemmCoord problem_size_1,
- ElementCompute alpha0 = ElementCompute(1),
+ cutlass::gemm::GemmCoord problem_size_0,
+ cutlass::gemm::GemmCoord problem_size_1,
+ ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
- ElementCompute alpha1 = ElementCompute(1),
+ ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true,
int warm_ups = 1,
int runs = 100) {
-
+
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
- typename Gemm0::ElementA,
+ typename Gemm0::ElementA,
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::HostTensor<
- typename Gemm0::ElementB,
+ typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
cutlass::HostTensor<
- typename Gemm0::ElementB,
+ typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn());
cutlass::HostTensor<
- typename Gemm0::ElementC,
+ typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
- typename Gemm0::ElementC,
+ typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
cutlass::HostTensor<
- typename Gemm0::ElementC,
+ typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
cutlass::HostTensor<
- typename Gemm0::ElementC,
+ typename Gemm0::ElementC,
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
cutlass::HostTensor<
- typename Gemm1::ElementB,
+ typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
cutlass::HostTensor<
- typename Gemm1::ElementB,
+ typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn());
cutlass::HostTensor<
- typename Gemm1::ElementC,
+ typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
- typename Gemm0::ElementC,
+ typename Gemm0::ElementC,
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
cutlass::HostTensor<
- typename Gemm1::ElementC,
+ typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
cutlass::HostTensor<
- typename Gemm1::ElementC,
+ typename Gemm1::ElementC,
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
-
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
@@ -285,13 +285,13 @@ struct B2bInterleavedNonFusedGemmRun
for(int i = 0; i < runs; i++) {
status = gemm_op_0();
-
+
CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) {
status = gemm_op_1();
-
+
CUTLASS_CHECK(status);
}
@@ -327,36 +327,36 @@ struct B2bInterleavedNonFusedGemmRun
reference_gemm_0(
problem_size_0,
- alpha0,
- tensor_A0.device_ref(),
- tensor_B0.device_ref(),
- beta0,
+ alpha0,
+ tensor_A0.device_ref(),
+ tensor_B0.device_ref(),
+ beta0,
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
reference_D0.device_ref()
);
if(relu) {
- cutlass::reference::device::TensorReLu(reference_D0.device_view());
+ cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size_1,
- alpha1,
- reference_D0.device_ref(),
- tensor_B1.device_ref(),
+ alpha1,
+ reference_D0.device_ref(),
+ tensor_B1.device_ref(),
beta1,
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
reference_D1.device_ref()
);
-
+
if(relu) {
- cutlass::reference::device::TensorReLu(reference_D1.device_view());
+ cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
// Wait for kernels to finish
cudaDeviceSynchronize();
- reference_D0.sync_host();
- reference_D1.sync_host();
+ reference_D0.sync_host();
+ reference_D1.sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
@@ -364,7 +364,7 @@ struct B2bInterleavedNonFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
- reference_D1.host_view(),
+ reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@@ -377,7 +377,7 @@ struct B2bInterleavedNonFusedGemmRun
std::ofstream file(fname.str());
- file
+ file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
@@ -416,9 +416,9 @@ struct B2bInterleavedFusedGemmRun
//
B2bInterleavedFusedGemmRun(
- cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
- cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
- cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
@@ -429,7 +429,7 @@ struct B2bInterleavedFusedGemmRun
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
-    cutlass::TensorView<Element, Layout> view,
+    cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
@@ -437,11 +437,11 @@ struct B2bInterleavedFusedGemmRun
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
- }
+ }
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
- }
+ }
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
@@ -470,78 +470,99 @@ struct B2bInterleavedFusedGemmRun
/// Executes one test
bool run(
- cutlass::gemm::GemmCoord problem_size_0,
- cutlass::gemm::GemmCoord problem_size_1,
- ElementCompute alpha0 = ElementCompute(1),
+ cutlass::gemm::GemmCoord problem_size_0,
+ cutlass::gemm::GemmCoord problem_size_1,
+ ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
- ElementCompute alpha1 = ElementCompute(1),
+ ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
+ cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
+
+ // batch_count is used as split-k when mode is kGemm according
+ // to the GemmUniversal interface
+
+ int batch_count = 1,
+
+ int64_t batch_stride_A0 = 0,
+ int64_t batch_stride_B0 = 0,
+ int64_t batch_stride_C0 = 0,
+ int64_t batch_stride_B1 = 0,
+ int64_t batch_stride_C1 = 0,
+ int64_t batch_stride_D1 = 0,
+ int64_t batch_stride_Bias0 = 0,
+ int64_t batch_stride_Scale0 = 0,
bool relu = true,
int warm_ups = 1,
int runs = 100) {
-
+
//
// Allocate the GEMM workspace
//
- cutlass::HostTensor<
- typename B2bGemm::ElementA,
- typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
+ cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
+ cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
+ cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
+ cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
+ cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
cutlass::HostTensor<
- typename B2bGemm::ElementB,
- typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
+ typename B2bGemm::ElementA,
+ typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
cutlass::HostTensor<
- typename B2bGemm::ElementB,
- typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn());
+ typename B2bGemm::ElementB,
+ typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
cutlass::HostTensor<
- typename B2bGemm::ElementC,
- typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
+ typename B2bGemm::ElementB,
+ typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn());
cutlass::HostTensor<
- typename B2bGemm::ElementScaleBias,
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
if(alpha0 == ElementCompute(0)) //per-channel scale
- tensor_Scale0.resize({1, problem_size_0.n()});
+ tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
- typename B2bGemm::ElementScaleBias,
- typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
+ typename B2bGemm::ElementScaleBias,
+ typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
cutlass::HostTensor<
- ElementAccumulator,
- typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
+ ElementAccumulator,
+ typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
cutlass::HostTensor<
- typename B2bGemm::ElementC,
- typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
cutlass::HostTensor<
- typename B2bGemm::ElementB,
- typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
+ typename B2bGemm::ElementB,
+ typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
cutlass::HostTensor<
- typename B2bGemm::ElementB,
- typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn());
+ typename B2bGemm::ElementB,
+ typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn());
cutlass::HostTensor<
- typename B2bGemm::ElementC,
- typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
cutlass::HostTensor<
- typename B2bGemm::ElementC,
- typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
cutlass::HostTensor<
- typename B2bGemm::ElementC,
- typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
cutlass::HostTensor<
- typename B2bGemm::ElementC,
- typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
@@ -556,9 +577,9 @@ struct B2bInterleavedFusedGemmRun
//Reorder B0
cutlass::reorder_column<16>(
- tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0);
cutlass::reorder_column<16>(
- tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1);
cutlass::reference::host::TensorFill(
tensor_D1.host_view());
@@ -581,12 +602,14 @@ struct B2bInterleavedFusedGemmRun
tensor_D1.sync_device();
reference_D0.sync_device();
reference_D1.sync_device();
+ // tensor_Bias0_batched.sync_device();
//
// Initialize the GEMM operator
//
typename B2bGemm::Arguments arguments{
+ mode,
problem_size_0,
problem_size_1,
tensor_A0.device_ref(),
@@ -597,8 +620,16 @@ struct B2bInterleavedFusedGemmRun
tensor_B1_reordered.device_ref(),
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(),
+ batch_stride_A0,
+ batch_stride_B0,
+ batch_stride_B1,
+ batch_stride_C1,
+ batch_stride_D1,
+ batch_stride_Bias0,
+ batch_stride_Scale0,
{alpha0, beta0},
{alpha1, beta1},
+ batch_count,
};
B2bGemm b2b_gemm_op;
@@ -651,32 +682,30 @@ struct B2bInterleavedFusedGemmRun
// Verify
//
- cutlass::reference::device::Gemm<
- typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
- typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
- ElementAccumulator, typename B2bGemm::LayoutC,
- ElementAccumulator, ElementAccumulator>
- reference_gemm_0;
-
- cutlass::reference::device::Gemm<
- typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
- typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
- typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
- ElementAccumulator, typename B2bGemm::Operator>
- reference_gemm_1;
-
- reference_gemm_0(
+ cutlass::reference::device::GemmComplex<
+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
+ ElementAccumulator, typename B2bGemm::LayoutC,
+ ElementAccumulator, ElementAccumulator
+ >(
problem_size_0,
ElementAccumulator(1), //intermediate alpha=1
- tensor_A0.device_ref(),
- tensor_B0.device_ref(),
+ tensor_A0.device_ref(),
+ cutlass::ComplexTransform::kNone,
+ tensor_B0.device_ref(),
+ cutlass::ComplexTransform::kNone,
ElementAccumulator(0), //beta = 0
reference_Z0.device_ref(),
reference_Z0.device_ref(),
- ElementAccumulator(0)
+ ElementAccumulator(0),
+ int(batch_count),
+ batch_stride_A0,
+ batch_stride_B0,
+ batch_stride_C0,
+ batch_stride_C0
);
- cutlass::reference::device::TensorScaleBiasGemm<
+ cutlass::reference::device::TensorScaleBiasGemmBatched<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutScaleBias
> (
@@ -685,25 +714,45 @@ struct B2bInterleavedFusedGemmRun
reference_D0.device_ref(),
alpha0,
tensor_Scale0.device_ref(),
- tensor_Bias0.device_ref()
+ tensor_Bias0.device_ref(),
+ int(batch_count),
+ batch_stride_C0,
+ batch_stride_C0,
+ batch_stride_Scale0,
+ batch_stride_Bias0
);
if(relu) {
- cutlass::reference::device::TensorReLu(reference_D0.device_view());
+ cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
- reference_gemm_1(
+ cutlass::reference::device::GemmComplex<
+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
+ typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
+ ElementCompute, ElementAccumulator
+ >(
problem_size_1,
- alpha1,
- reference_D0.device_ref(),
- tensor_B1.device_ref(),
- beta1,
+      alpha1,
+ reference_D0.device_ref(),
+ cutlass::ComplexTransform::kNone,
+ tensor_B1.device_ref(),
+ cutlass::ComplexTransform::kNone,
+      beta1,
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
- reference_D1.device_ref()
+ reference_D1.device_ref(),
+ ElementAccumulator(0),
+ int(batch_count),
+ batch_stride_C0,
+ batch_stride_B1,
+ batch_stride_C1,
+ batch_stride_D1
);
+
if(relu) {
- cutlass::reference::device::TensorReLu(reference_D1.device_view());
+ cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
+
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
@@ -713,7 +762,7 @@ struct B2bInterleavedFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
- reference_D1.host_view(),
+ reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
@@ -727,7 +776,7 @@ struct B2bInterleavedFusedGemmRun
std::ofstream file(fname.str());
- file
+ file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
diff --git a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h
index f365b236..0fbc930b 100644
--- a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h
+++ b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h
@@ -119,8 +119,6 @@ template <
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, ElementAccumulator_>::kAlignmentB,
- /// If true, kernel supports split-K with serial reduction
- bool SplitKSerial = false,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
@@ -154,7 +152,6 @@ class B2bGemm {
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp1::kCount;
- static bool const kSplitKSerial = SplitKSerial;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
@@ -184,7 +181,6 @@ class B2bGemm {
EpilogueOutputOp1,
ThreadblockSwizzle,
kStages,
- kSplitKSerial,
Operator,
SmemAccumulator
>::B2bGemmKernel;
@@ -196,6 +192,7 @@ class B2bGemm {
// Data members
//
+ GemmUniversalMode mode;
GemmCoord problem_size_0;
GemmCoord problem_size_1;
TensorRef ref_A0;
@@ -206,9 +203,16 @@ class B2bGemm {
TensorRef ref_B1;
TensorRef ref_C1;
TensorRef ref_D1;
+ int64_t batch_stride_A0;
+ int64_t batch_stride_B0;
+ int64_t batch_stride_B1;
+ int64_t batch_stride_C1;
+ int64_t batch_stride_D1;
+ int64_t batch_stride_Bias0;
+ int64_t batch_stride_Scale0;
typename EpilogueOutputOp0::Params epilogue0;
typename EpilogueOutputOp1::Params epilogue1;
- int split_k_slices;
+ int batch_count;
//
// Methods
@@ -216,13 +220,14 @@ class B2bGemm {
/// Default ctor
CUTLASS_HOST_DEVICE
- Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) {
+    Arguments(): mode(GemmUniversalMode::kGemm), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {
}
- /// Constructs an Arguments structure
+ /// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
+ GemmUniversalMode mode_,
GemmCoord problem_size_0_,
GemmCoord problem_size_1_,
TensorRef ref_A0_,
@@ -233,12 +238,20 @@ class B2bGemm {
TensorRef ref_B1_,
TensorRef ref_C1_,
TensorRef ref_D1_,
- typename EpilogueOutputOp0::Params epilogue0_ =
+ int64_t batch_stride_A0_,
+ int64_t batch_stride_B0_,
+ int64_t batch_stride_B1_,
+ int64_t batch_stride_C1_,
+ int64_t batch_stride_D1_,
+ int64_t batch_stride_Bias0_,
+ int64_t batch_stride_Scale0_,
+ typename EpilogueOutputOp0::Params epilogue0_ =
typename EpilogueOutputOp0::Params(),
- typename EpilogueOutputOp1::Params epilogue1_ =
+ typename EpilogueOutputOp1::Params epilogue1_ =
typename EpilogueOutputOp1::Params(),
- int split_k_slices_ = 1
+ int batch_count_ = 1
):
+ mode(mode_),
problem_size_0(problem_size_0_),
problem_size_1(problem_size_1_),
ref_A0(ref_A0_),
@@ -249,9 +262,16 @@ class B2bGemm {
ref_B1(ref_B1_),
ref_C1(ref_C1_),
ref_D1(ref_D1_),
+ batch_stride_A0(batch_stride_A0_),
+ batch_stride_B0(batch_stride_B0_),
+ batch_stride_B1(batch_stride_B1_),
+ batch_stride_C1(batch_stride_C1_),
+ batch_stride_D1(batch_stride_D1_),
+ batch_stride_Bias0(batch_stride_Bias0_),
+ batch_stride_Scale0(batch_stride_Scale0_),
epilogue0(epilogue0_),
epilogue1(epilogue1_),
- split_k_slices(split_k_slices_) {
+ batch_count(batch_count_) {
}
};
@@ -269,10 +289,6 @@ public:
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
- if (!kSplitKSerial && args.split_k_slices > 1) {
- return Status::kErrorInvalidProblem;
- }
-
Status status = B2bGemmKernel::can_implement(
args.problem_size_0,
args.problem_size_1,
@@ -295,20 +311,14 @@ public:
static size_t get_workspace_size(Arguments const &args) {
size_t bytes = 0;
-
+
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
- args.problem_size_0,
+ args.problem_size_0,
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
- args.split_k_slices);
-
- if (kSplitKSerial && args.split_k_slices > 1) {
-
-
- bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
- }
+ args.batch_count);
return bytes;
}
@@ -320,38 +330,17 @@ public:
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
- args.problem_size_0,
+ args.problem_size_0,
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
- args.split_k_slices);
+ args.batch_count);
// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
-// args.problem_size_1,
+// args.problem_size_1,
// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
-// args.split_k_slices);
-
- if (kSplitKSerial) {
- if (args.split_k_slices > 1) {
- if (!workspace) {
- return Status::kErrorWorkspaceNull;
- }
-
- size_t bytes = get_workspace_size(args);
-
- cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
-
- if (result != cudaSuccess) {
- return Status::kErrorInternal;
- }
- }
- }
- else {
-
- if (args.split_k_slices > 1) {
- return Status::kErrorInvalidProblem;
- }
- }
+// args.batch_count);
// Initialize the Params structure
params_ = typename B2bGemmKernel::Params{
+ args.mode,
args.problem_size_0,
args.problem_size_1,
grid_shape,
@@ -363,6 +352,13 @@ public:
args.ref_B1.non_const_ref(),
args.ref_C1.non_const_ref(),
args.ref_D1,
+ args.batch_stride_A0,
+ args.batch_stride_B0,
+ args.batch_stride_B1,
+ args.batch_stride_C1,
+ args.batch_stride_D1,
+ args.batch_stride_Bias0,
+ args.batch_stride_Scale0,
args.epilogue0,
args.epilogue1,
static_cast<int *>(workspace),
@@ -373,12 +369,6 @@ public:
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
-
- if (kSplitKSerial && args.split_k_slices > 1) {
- if (!workspace) {
- return Status::kErrorWorkspaceNull;
- }
- }
params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
@@ -430,12 +420,12 @@ public:
/// Runs the kernel using initialized state.
Status operator()(
- Arguments const &args,
- void *workspace = nullptr,
+ Arguments const &args,
+ void *workspace = nullptr,
cudaStream_t stream = nullptr) {
-
+
Status status = initialize(args, workspace, stream);
-
+
if (status == Status::kSuccess) {
status = run(stream);
}
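
With SplitKSerial removed at the device level, `batch_count` is the single knob that sizes the third grid dimension: it is the split-K slice count in kGemm mode and the batch count in kBatched mode. A condensed sketch of the launcher-side computation, following the `get_tiled_shape` call visible in the hunks above (the grid-shape step depends on the swizzle and is an assumption here):

```cpp
// Sketch: how initialize() derives the grid from Arguments (names as in this file).
ThreadblockSwizzle threadblock_swizzle;

cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
    args.problem_size_0,
    {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
    args.batch_count);            // split-K slices (kGemm) or batches (kBatched)

dim3 grid = threadblock_swizzle.get_grid_shape(tiled_shape);
// The kernel reads the batch / split-K index back as threadblock_tile_offset.k()
// and dispatches on params.mode (see kernel/b2b_gemm.h below).
```
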
diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu
index 60f9adb1..bf025b92 100644
--- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu
+++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu
@@ -152,7 +152,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
- using EpilogueOutputOp0 =
+ using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
8 * InstructionShape::kN / 32,
@@ -161,7 +161,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
- using EpilogueOutputOp1 =
+ using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
@@ -194,14 +194,21 @@ bool run_fused_gemm_s8_sm80_rf_res() {
SmemAccumulator,
16,
16,
- false,
cutlass::arch::OpMultiplyAddSaturate
>;
B2bInterleavedFusedGemmRun<B2bGemm> fusedGemm;
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n";
- bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
+ bool passed = fusedGemm.run(
+ gemm_s8_sm80_problem_size_0,
+ gemm_s8_sm80_problem_size_1,
+ alpha0,
+ beta0,
+ alpha1,
+ beta1
+ );
+
if(passed)
std::cout << "Pass\n";
else
@@ -210,18 +217,123 @@ bool run_fused_gemm_s8_sm80_rf_res() {
return passed;
}
+bool run_fused_gemm_s8_sm80_rf_res_batch() {
+
+
+ cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(256, 64, 128);
+ cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(256, 128, 64);
+
+ using ElementOutput = int8_t;
+ using ElementAccumulator = int32_t;
+ using ElementCompute = float;
+
+ ElementCompute alpha0 = ElementCompute(1);
+ //Fused kernel has built-in bias, setting beta=0
+ ElementCompute beta0 = ElementCompute(0);
+ ElementCompute alpha1 = ElementCompute(1);
+ ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
+
+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
+ using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
+
+ using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+
+ using EpilogueOutputOp0 =
+ cutlass::epilogue::thread::LinearCombinationRelu<
+ ElementOutput,
+ 8 * InstructionShape::kN / 32,
+ ElementAccumulator,
+ ElementCompute,
+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+ >;
+
+ using EpilogueOutputOp1 =
+ cutlass::epilogue::thread::LinearCombinationRelu<
+ ElementOutput,
+      64 / cutlass::sizeof_bits<ElementOutput>::value,
+ ElementAccumulator,
+ ElementCompute,
+ cutlass::epilogue::thread::ScaleType::NoBetaScaling
+ >;
+
+ const bool SmemAccumulator = false;
+
+ using B2bGemm = cutlass::gemm::device::B2bGemm<
+ int8_t,
+ cutlass::layout::ColumnMajorInterleaved<32>,
+ int8_t,
+ cutlass::layout::RowMajorInterleaved<32>,
+ ElementOutput,
+ cutlass::layout::ColumnMajorInterleaved<32>,
+ ElementAccumulator,
+ cutlass::arch::OpClassTensorOp,
+ cutlass::arch::Sm80,
+ ThreadblockShape0,
+ ThreadblockShape1,
+ WarpShape0,
+ WarpShape1,
+ InstructionShape,
+ EpilogueOutputOp0,
+ EpilogueOutputOp1,
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+ 3,
+ SmemAccumulator,
+ 16,
+ 16,
+ cutlass::arch::OpMultiplyAddSaturate
+ >;
+
+  B2bInterleavedFusedGemmRun<B2bGemm> fusedGemm;
+
+ int batch_count = 2;
+ int64_t batch_stride_A0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.k();
+  int64_t batch_stride_B0 = gemm_s8_sm80_problem_size_0.k() * gemm_s8_sm80_problem_size_0.n();
+ int64_t batch_stride_C0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.n();
+ int64_t batch_stride_B1 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
+ int64_t batch_stride_C1 = gemm_s8_sm80_problem_size_1.n();
+ int64_t batch_stride_D1 = gemm_s8_sm80_problem_size_1.m() * gemm_s8_sm80_problem_size_1.n();
+ int64_t batch_stride_Bias0 = gemm_s8_sm80_problem_size_0.n();
+ int64_t batch_stride_Scale0 = 0;
+
+ std::cout << "Running Fused back-to-back INT8 NT interleaved Batched GEMMs with RF residency...\n";
+ bool passed = fusedGemm.run(
+ gemm_s8_sm80_problem_size_0,
+ gemm_s8_sm80_problem_size_1,
+ alpha0,
+ beta0,
+ alpha1,
+ beta1,
+ cutlass::gemm::GemmUniversalMode::kBatched,
+ batch_count,
+ batch_stride_A0,
+ batch_stride_B0,
+ batch_stride_C0,
+ batch_stride_B1,
+ batch_stride_C1,
+ batch_stride_D1,
+ batch_stride_Bias0,
+ batch_stride_Scale0
+ );
+
+ if(passed)
+ std::cout << "Pass\n";
+ else
+ std::cout << "Fail\n";
+
+ return passed;
+}
int main() {
std::vector<bool (*)()> funcs = {
&run_nonfused_gemm_s8_sm80,
- &run_fused_gemm_s8_sm80_rf_res
+ &run_fused_gemm_s8_sm80_rf_res,
+ &run_fused_gemm_s8_sm80_rf_res_batch
};
return testRun(80, funcs, "gemm int8 RF residency");
-
-
}
-
////////////////////////////////////////////////////////////////////////////////
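
For the problem sizes used by `run_fused_gemm_s8_sm80_rf_res_batch` above, the batch strides work out to the following element counts (a worked check of the arithmetic, not code from the example):

```cpp
// problem_size_0 = (M=256, N=64, K=128), problem_size_1 = (M=256, N=128, K=64), batch_count = 2
int64_t batch_stride_A0     = 256 * 128;  // 32768: one full A0 per batch
int64_t batch_stride_B0     = 128 *  64;  //  8192: one full B0 per batch
int64_t batch_stride_C0     = 256 *  64;  // 16384: intermediate Z0/D0 (reference check only)
int64_t batch_stride_B1     =  64 * 128;  //  8192: one full B1 per batch
int64_t batch_stride_C1     =       128;  // C1 is a broadcast bias row of length N1
int64_t batch_stride_D1     = 256 * 128;  // 32768: one full D1 per batch
int64_t batch_stride_Bias0  =        64;  // one bias row of length N0 per batch
int64_t batch_stride_Scale0 =         0;  // per-channel scale disabled since alpha0 != 0
```
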
diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu
index 64788e05..8a284fd0 100644
--- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu
+++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu
@@ -151,7 +151,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
- using EpilogueOutputOp0 =
+ using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
8 * InstructionShape::kN / 32,
@@ -160,7 +160,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
- using EpilogueOutputOp1 =
+ using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits::value,
@@ -168,7 +168,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;
-
+
const bool SmemAccumulator = true;
using B2bGemm = cutlass::gemm::device::B2bGemm<
@@ -193,7 +193,6 @@ bool run_fused_gemm_s8_sm80_shmem() {
SmemAccumulator,
16,
16,
- false,
cutlass::arch::OpMultiplyAddSaturate
>;
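
The only functional change in this file is dropping the boolean SplitKSerial template argument from the B2bGemm instantiation, mirroring its removal in device/b2b_gemm.h; schematically:

```cpp
// Before (flag present):
//   using B2bGemm = cutlass::gemm::device::B2bGemm<
//       ..., SmemAccumulator, 16, 16, false, cutlass::arch::OpMultiplyAddSaturate>;
//
// After (flag removed; split-K vs. batching is now chosen at runtime via
// GemmUniversalMode and batch_count in B2bGemm::Arguments):
//   using B2bGemm = cutlass::gemm::device::B2bGemm<
//       ..., SmemAccumulator, 16, 16, cutlass::arch::OpMultiplyAddSaturate>;
```
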
diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h
index 1ccf902e..f4794c57 100644
--- a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h
+++ b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h
@@ -49,10 +49,9 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
- typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
+ typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
- typename ThreadblockSwizzle_, ///! Threadblock swizzling function
- bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
+ typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct B2bGemm {
@@ -61,7 +60,17 @@ struct B2bGemm {
using OutputOp0 = typename B2bMma::OutputOp;
using OutputOp1 = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
- static bool const kSplitKSerial = SplitKSerial;
+
+ using ElementA0 = typename B2bMma::IteratorA0::Element;
+ using LayoutA0 = typename B2bMma::IteratorA0::Layout;
+ using ElementB0 = typename B2bMma::IteratorB0::Element;
+ using LayoutB0 = typename B2bMma::IteratorB0::Layout;
+ using ElementB1 = typename B2bMma::IteratorB1::Element;
+ using LayoutB1 = typename B2bMma::IteratorB1::Layout;
+ using ElementC = typename Epilogue::OutputTileIterator::Element;
+ using LayoutC = typename Epilogue::OutputTileIterator::Layout;
+
+ using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
/// Warp count (concept: GemmShape)
using WarpCount0 = typename B2bMma::WarpCount0;
@@ -69,6 +78,7 @@ struct B2bGemm {
/// Parameters structure
struct Params {
+ cutlass::gemm::GemmUniversalMode mode;
cutlass::gemm::GemmCoord problem_size_0;
cutlass::gemm::GemmCoord problem_size_1;
cutlass::gemm::GemmCoord grid_tiled_shape;
@@ -89,6 +99,13 @@ struct B2bGemm {
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
typename OutputOp0::Params output_op_0;
typename OutputOp1::Params output_op_1;
+ int64_t batch_stride_A0;
+ int64_t batch_stride_B0;
+ int64_t batch_stride_B1;
+ int64_t batch_stride_C1;
+ int64_t batch_stride_D1;
+ int64_t batch_stride_Bias0;
+ int64_t batch_stride_Scale0;
int *semaphore;
int gemm_k_iterations_0;
int gemm_k_size_0;
@@ -100,11 +117,12 @@ struct B2bGemm {
//
CUTLASS_HOST_DEVICE
- Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
+    Params(): mode(GemmUniversalMode::kGemm), swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
CUTLASS_HOST_DEVICE
Params(
+ cutlass::gemm::GemmUniversalMode mode,
cutlass::gemm::GemmCoord const & problem_size_0,
cutlass::gemm::GemmCoord const & problem_size_1,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
@@ -116,10 +134,18 @@ struct B2bGemm {
typename B2bMma::IteratorB1::TensorRef ref_B1,
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
+ int64_t batch_stride_A0,
+ int64_t batch_stride_B0,
+ int64_t batch_stride_B1,
+ int64_t batch_stride_C1,
+ int64_t batch_stride_D1,
+ int64_t batch_stride_Bias0,
+ int64_t batch_stride_Scale0,
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
int *workspace = nullptr
):
+ mode(mode),
problem_size_0(problem_size_0),
problem_size_1(problem_size_1),
grid_tiled_shape(grid_tiled_shape),
@@ -138,6 +164,13 @@ struct B2bGemm {
ref_C1(ref_C1),
params_D1(ref_D1.layout()),
ref_D1(ref_D1),
+ batch_stride_A0(batch_stride_A0),
+ batch_stride_B0(batch_stride_B0),
+ batch_stride_B1(batch_stride_B1),
+ batch_stride_C1(batch_stride_C1),
+ batch_stride_D1(batch_stride_D1),
+ batch_stride_Bias0(batch_stride_Bias0),
+ batch_stride_Scale0(batch_stride_Scale0),
output_op_0(output_op_0),
output_op_1(output_op_1) {
@@ -163,7 +196,7 @@ struct B2bGemm {
//
CUTLASS_HOST_DEVICE
- B2bGemm() { }
+ B2bGemm() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
@@ -223,7 +256,7 @@ struct B2bGemm {
if(problem_size_0.n() > B2bMma::Shape0::kN)
return Status::kErrorInvalidProblem;
-
+
if(problem_size_1.n() > B2bMma::Shape1::kN)
return Status::kErrorInvalidProblem;
@@ -247,37 +280,64 @@ struct B2bGemm {
return;
}
+    ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
+    ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
+    ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
+
+    ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
+    ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
+
+ int offset_k_0 = 0;
+ int offset_k_1 = 0;
+
+ int problem_size_k_0 = params.problem_size_0.k();
+ int problem_size_k_1 = params.problem_size_1.k();
+
+ if (params.mode == GemmUniversalMode::kGemm) {
+
+ // Problem size is a function of threadblock index in the K dimension
+ problem_size_k_0 = min(
+ problem_size_k_0,
+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
+
+ // Problem size is a function of threadblock index in the K dimension
+ problem_size_k_1 = min(
+ problem_size_k_1,
+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
+
+ offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
+ offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
+ }
+
+ else if (params.mode == GemmUniversalMode::kBatched) {
+ ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
+ ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
+ ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
+ ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
+ ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
+ }
+
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A0{
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
- threadblock_tile_offset.k() * params.gemm_k_size_0,
+ offset_k_0,
};
cutlass::MatrixCoord tb_offset_B0{
- threadblock_tile_offset.k() * params.gemm_k_size_0,
+ offset_k_0,
threadblock_tile_offset.n() * B2bMma::Shape0::kN
};
cutlass::MatrixCoord tb_offset_B1{
- threadblock_tile_offset.k() * params.gemm_k_size_1,
+ offset_k_1,
threadblock_tile_offset.n() * B2bMma::Shape1::kN
};
- // Problem size is a function of threadblock index in the K dimension
- int problem_size_k_0 = min(
- params.problem_size_0.k(),
- (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
-
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
- // Problem size is a function of threadblock index in the K dimension
- int problem_size_k_1 = min(
- params.problem_size_1.k(),
- (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
-
// Compute threadblock-scoped matrix multiply-add
-// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
+ // int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
// Compute position within threadblock
@@ -286,26 +346,25 @@ struct B2bGemm {
// Construct iterators to A and B operands
typename B2bMma::IteratorA0 iterator_A0(
params.params_A0,
- params.ref_A0.data(),
+ ptr_A0,
{params.problem_size_0.m(), problem_size_k_0},
thread_idx,
tb_offset_A0);
typename B2bMma::IteratorB0 iterator_B0(
params.params_B0,
- params.ref_B0.data(),
+ ptr_B0,
{problem_size_k_0, params.problem_size_0.n()},
thread_idx,
tb_offset_B0);
typename B2bMma::IteratorB1 iterator_B1(
params.params_B1,
- params.ref_B1.data(),
+ ptr_B1,
{problem_size_k_1, params.problem_size_1.n()},
thread_idx,
tb_offset_B1);
-
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
@@ -313,7 +372,7 @@ struct B2bGemm {
// Construct iterators to accumulator scale/bias vector
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
- params.ref_Scale0.data(),
+ ptr_Scale0,
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
@@ -323,7 +382,7 @@ struct B2bGemm {
);
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
- params.ref_Bias0.data(),
+ ptr_Bias0,
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
@@ -349,11 +408,9 @@ struct B2bGemm {
src_accum.clear();
accumulators.clear();
- if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
- // Compute threadblock-scoped matrix multiply-add
- b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
- iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
- }
+ // Compute threadblock-scoped matrix multiply-add
+ b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
+ iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
//
// Epilogue
@@ -376,23 +433,32 @@ struct B2bGemm {
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
+    ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
+    ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
+
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
- // If performing a reduction via split-K, fetch the initial synchronization
- if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
-
- // Fetch the synchronization lock initially but do not block.
- semaphore.fetch();
+ if (params.mode == GemmUniversalMode::kGemm) {
+ // If performing a reduction via split-K, fetch the initial synchronization
- // Indicate which position in a serial reduction the output operator is currently updating
- output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
+ if (params.grid_tiled_shape.k() > 1) {
+ // Fetch the synchronization lock initially but do not block.
+ semaphore.fetch();
+
+ // Indicate which position in a serial reduction the output operator is currently updating
+ output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
+ }
+ }
+ else if (params.mode == GemmUniversalMode::kBatched) {
+ ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
+ ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1;
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C1(
params.params_C1,
- params.ref_C1.data(),
+ ptr_C1,
params.problem_size_1.mn(),
thread_idx,
threadblock_offset
@@ -401,21 +467,21 @@ struct B2bGemm {
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D1(
params.params_D1,
- params.ref_D1.data(),
+ ptr_D1,
params.problem_size_1.mn(),
thread_idx,
threadblock_offset
);
Epilogue epilogue(
- shared_storage.epilogue,
- thread_idx,
- warp_idx,
+ shared_storage.epilogue,
+ thread_idx,
+ warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
- if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
-
+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
+
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C1 = iterator_D1;
@@ -427,14 +493,14 @@ struct B2bGemm {
}
// Execute the epilogue operator to update the destination tensor.
- epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
-
+ epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
+
//
// Release the semaphore
//
- if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
-
+ if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
+
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
@@ -457,4 +523,3 @@ struct B2bGemm {
} // namespace kernel
} // namespace gemm
} // namespace cutlass
-
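
The kernel changes above replace the compile-time `kSplitKSerial` guard with a runtime check on `params.mode` and add per-batch pointer arithmetic for `GemmUniversalMode::kBatched`. The following is a minimal sketch of that pointer pattern only; `Mode`, `KernelParams`, and `batched_pointer_demo` are hypothetical names, not part of the example.

```cpp
#include <cstdint>
#include <cuda_runtime.h>

enum class Mode { kGemm, kBatched };

struct KernelParams {
  Mode mode;
  float const *ptr_C;      // source operand, batches laid out back to back
  float *ptr_D;            // destination operand
  int64_t batch_stride_C;
  int64_t batch_stride_D;
  int m, n;
};

// blockIdx.z plays the role of threadblock_tile_offset.k() in the fused kernel:
// in batched mode it selects which batch this threadblock operates on.
__global__ void batched_pointer_demo(KernelParams params) {
  float const *ptr_C = params.ptr_C;
  float *ptr_D = params.ptr_D;

  if (params.mode == Mode::kBatched) {
    ptr_C += blockIdx.z * params.batch_stride_C;
    ptr_D += blockIdx.z * params.batch_stride_D;
  }

  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (row < params.m && col < params.n) {
    // Stand-in for the epilogue: copy C into D for this batch.
    ptr_D[row * params.n + col] = ptr_C[row * params.n + col];
  }
}
```
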
diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h
index 05c3f4e2..3f54e1da 100644
--- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h
+++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h
@@ -30,10 +30,10 @@
**************************************************************************************************/
/*! \file
- \brief
+ \brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
-
+
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
@@ -114,8 +114,6 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
- /// If true, kernel is configured to support serial reduction in the epilogue
- bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Stage accumulator in shared memory
@@ -161,22 +159,19 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
- /// If true, kernel is configured to support serial reduction in the
- /// epilogue
- bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
- ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@@ -188,7 +183,7 @@ struct DefaultB2bGemm::Epilogue;
/// Define the kernel-level GEMM operator.
- using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+ using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
@@ -228,8 +223,6 @@ template <
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
- /// If true, kernel is configured to support serial reduction in the epilogue
- bool SplitKSerial,
/// Operation performed by GEMM
typename Operator
>
@@ -249,7 +242,6 @@ struct DefaultB2bGemm<
EpilogueOutputOp1,
ThreadblockSwizzle,
2,
- SplitKSerial,
Operator
> {
@@ -274,7 +266,7 @@ struct DefaultB2bGemm<
Operator,
EpilogueOutputOp0
>::ThreadblockB2bMma;
-
+
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
@@ -287,7 +279,7 @@ struct DefaultB2bGemm<
>::Epilogue;
/// Define the kernel-level GEMM operator.
- using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+ using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
@@ -323,20 +315,16 @@ template <
int Stages,
/// Number of Interleaved k
int InterleavedK,
- /// If true, kernel is configured to support serial reduction in the
- /// epilogue
- bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
- ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
+ ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
- ThreadblockSwizzle, Stages,
- SplitKSerial, Operator> {
+ ThreadblockSwizzle, Stages, Operator> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@@ -360,7 +348,7 @@ struct DefaultB2bGemm<
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
- using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+ using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
@@ -396,19 +384,16 @@ template <
typename ThreadblockSwizzle,
/// Number of Interleaved k
int InterleavedK,
- /// If true, kernel is configured to support serial reduction in the
- /// epilogue
- bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
kAlignmentA, ElementB,
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
- int32_t, arch::OpClassTensorOp, arch::Sm75,
+ int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
- ThreadblockSwizzle, 2, SplitKSerial, Operator> {
+ ThreadblockSwizzle, 2, Operator> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@@ -418,7 +403,7 @@ struct DefaultB2bGemm,
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
- arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
+ arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@@ -430,7 +415,7 @@ struct DefaultB2bGemm,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
- using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+ using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
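
In both default-kernel headers, the `bool SplitKSerial` template parameter is dropped from every `DefaultB2bGemm` specialization and from the `kernel::B2bGemm` instantiation; serial split-K becomes a runtime decision keyed off the universal mode and the tiled K extent. The sketch below is schematic only, using placeholder trait names rather than the CUTLASS types, to illustrate the shape of the refactor.

```cpp
enum class Mode { kGemm, kBatched };   // stand-in for cutlass::gemm::GemmUniversalMode

// Before: the compile-time flag rode along on every trait and kernel instantiation.
template <typename B2bMma, typename Epilogue, typename Swizzle, bool SplitKSerial>
struct KernelWithFlag {
  static constexpr bool kSplitKSerial = SplitKSerial;
};

// After: the flag is gone; the kernel decides at runtime, as in the guards
//   if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { ... }
template <typename B2bMma, typename Epilogue, typename Swizzle>
struct KernelWithoutFlag {
  static bool uses_serial_splitk(Mode mode, int grid_k) {
    return mode == Mode::kGemm && grid_k > 1;
  }
};
```
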
diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h
index 23717c61..6849d6e8 100644
--- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h
+++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h
@@ -30,10 +30,10 @@
**************************************************************************************************/
/*! \file
- \brief
+ \brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
-
+
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
@@ -112,22 +112,19 @@ template <
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
- /// If true, kernel is configured to support serial reduction in the
- /// epilogue
- bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
- ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@@ -139,10 +136,9 @@ struct DefaultB2bGemm::Epilogue;
/// Define the kernel-level GEMM operator.
- using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+ using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
-
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Turing Architecture
@@ -179,8 +175,6 @@ template <
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
- /// If true, kernel is configured to support serial reduction in the epilogue
- bool SplitKSerial,
/// Operation performed by GEMM
typename Operator
>
@@ -200,7 +194,6 @@ struct DefaultB2bGemm<
EpilogueOutputOp1,
ThreadblockSwizzle,
2,
- SplitKSerial,
Operator,
true
> {
@@ -228,7 +221,7 @@ struct DefaultB2bGemm<
false,
true
>::ThreadblockB2bMma;
-
+
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
@@ -241,7 +234,7 @@ struct DefaultB2bGemm<
>::Epilogue;
/// Define the kernel-level GEMM operator.
- using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+ using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
@@ -277,20 +270,17 @@ template <
int Stages,
/// Number of Interleaved k
int InterleavedK,
- /// If true, kernel is configured to support serial reduction in the
- /// epilogue
- bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
- ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
+ ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, Stages,
- SplitKSerial, Operator, true> {
+ Operator, true> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@@ -314,7 +304,7 @@ struct DefaultB2bGemm<
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
- using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+ using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
@@ -350,19 +340,16 @@ template <
typename ThreadblockSwizzle,
/// Number of Interleaved k
int InterleavedK,
- /// If true, kernel is configured to support serial reduction in the
- /// epilogue
- bool SplitKSerial,
/// Operation performed by GEMM
typename Operator>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
kAlignmentA, ElementB,
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
- int32_t, arch::OpClassTensorOp, arch::Sm75,
+ int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
- ThreadblockSwizzle, 2, SplitKSerial, Operator, true> {
+ ThreadblockSwizzle, 2, Operator, true> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@@ -371,9 +358,9 @@ struct DefaultB2bGemm,
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
- ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
- ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
- ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
+ ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
@@ -385,7 +372,7 @@ struct DefaultB2bGemm,
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
/// Define the kernel-level GEMM operator.
- using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+ using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
};
////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h b/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h
index eef9d9a1..67de37ed 100644
--- a/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h
+++ b/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h
@@ -69,7 +69,7 @@ __global__ void TensorScaleBiasGemm(
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias ///< bias tensor
) {
-
+
ConvertOp convert_op;
MatrixCoord output_coord(
@@ -89,7 +89,7 @@ __global__ void TensorScaleBiasGemm(
ScalarType bias = ScalarType(0);
- if(tensor_bias.good())
+ if(tensor_bias.good())
bias = tensor_bias.at({0, coord.column()});
tensor_out.at(coord) = convert_op(
@@ -99,6 +99,70 @@ __global__ void TensorScaleBiasGemm(
}
}
+template <
+ typename TensorRefIn, ///< Input TensorRef Type
+ typename TensorRefOut, ///< Output TensorRef Type
+ typename ScalarType, ///< alpha Type
+ typename TensorRefScalar, ///< Scale/Bias TensorRef Type
+ typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
+ int kMblock = 4,
+ int kNblock = 4
+>
+__global__ void TensorScaleBiasGemmBatched(
+ gemm::GemmCoord problem_size,
+ TensorRefIn tensor_in, ///< input tensor
+ TensorRefOut tensor_out, ///< output tensor
+ ScalarType alpha, ///< alpha
+ TensorRefScalar tensor_scale, ///< scale tensor
+ TensorRefScalar tensor_bias, ///< bias tensor
+ int batch_count = 1,
+ int64_t batch_stride_tensor_in = 0,
+ int64_t batch_stride_tensor_out = 0,
+ int64_t batch_stride_tensor_scale = 0,
+ int64_t batch_stride_tensor_bias = 0
+) {
+
+ ConvertOp convert_op;
+ int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
+ int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
+ int batch_idx = blockIdx.z;
+
+ tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in);
+ tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out);
+ tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale);
+ tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias);
+
+ for (; batch_idx < batch_count; batch_idx += gridDim.z) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int j = 0; j < kNblock; j++) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < kMblock; i++) {
+ int row = row_block + i;
+ int col = col_block + j;
+ MatrixCoord coord = MatrixCoord(row, col);
+ if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
+
+ ScalarType scale = alpha;
+ if(tensor_scale.good())
+ scale = tensor_scale.at({0, coord.column()});
+
+ ScalarType bias = ScalarType(0);
+
+ if(tensor_bias.good())
+ bias = tensor_bias.at({0, coord.column()});
+
+ tensor_out.at(coord) = convert_op(
+ scale * ScalarType(tensor_in.at(coord)) + bias);
+ }
+ }
+ }
+ tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z);
+ tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z);
+ tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z);
+ tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z);
+ }
+}
+
template <
typename TensorRefIn, ///< Input TensorRef Type
typename TensorRefOut, ///< Output TensorRef Type
@@ -118,7 +182,7 @@ __global__ void TensorScaleBiasConv2d(
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias ///< bias tensor
) {
-
+
ConvertOp convert_op;
int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
@@ -137,7 +201,7 @@ __global__ void TensorScaleBiasConv2d(
int64_t npq = npq_start + m;
thread_n[m] = int(npq / PQ);
-
+
int64_t residual = npq % PQ;
thread_p[m] = int(residual / problem_size.Q);
thread_q[m] = int(residual % problem_size.Q);
@@ -155,17 +219,17 @@ __global__ void TensorScaleBiasConv2d(
ScalarType scale = alpha;
if(tensor_scale.good())
scale = tensor_scale.at({0, thread_k});
-
+
ScalarType bias = ScalarType(0);
- if(tensor_bias.good())
+ if(tensor_bias.good())
bias = tensor_bias.at({0, thread_k});
-
+
tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
scale * ScalarType(
tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})
) + bias);
}
- }
+ }
}
}
@@ -217,6 +281,62 @@ void TensorScaleBiasGemm(
);
}
+/// Apply scale and bias on a tensor
+template <
+ typename ElementIn, ///< Input Type
+ typename ElementOut, ///< Output Type
+ typename Layout, ///< Layout of input/output tensor
+ typename ScalarType, ///< alpha Type
+ typename LayoutScaleBias, ///< Layout of scale and bias
+ typename ConvertOp = NumericConverter<ElementOut, ScalarType>
+>
+void TensorScaleBiasGemmBatched(
+ gemm::GemmCoord problem_size,
+ TensorRef<ElementIn, Layout> tensor_in, ///< input tensor
+ TensorRef<ElementOut, Layout> tensor_out, ///< output tensor
+ ScalarType alpha, ///< alpha
+ TensorRef<ScalarType, LayoutScaleBias> tensor_scale, ///< scale tensor
+ TensorRef<ScalarType, LayoutScaleBias> tensor_bias, ///< bias tensor
+ int batch_count = 1,
+ int64_t batch_stride_tensor_in = 0,
+ int64_t batch_stride_tensor_out = 0,
+ int64_t batch_stride_tensor_scale = 0,
+ int64_t batch_stride_tensor_bias = 0
+) {
+
+ int const kMblock = 4;
+ int const kNblock = 4;
+
+ dim3 block(16, 8);
+ dim3 grid(
+ (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
+ (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
+ batch_count % std::numeric_limits<uint16_t>::max()
+ );
+
+ kernel::TensorScaleBiasGemmBatched<
+ TensorRef<ElementIn, Layout>,
+ TensorRef<ElementOut, Layout>,
+ ScalarType,
+ TensorRef<ScalarType, LayoutScaleBias>,
+ ConvertOp,
+ kMblock,
+ kNblock
+ ><<< grid, block >>> (
+ problem_size,
+ tensor_in,
+ tensor_out,
+ alpha,
+ tensor_scale,
+ tensor_bias,
+ batch_count,
+ batch_stride_tensor_in,
+ batch_stride_tensor_out,
+ batch_stride_tensor_scale,
+ batch_stride_tensor_bias
+ );
+}
+
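
The host-side wrapper above mirrors the batched reference GEMMs in cutlass/util: it keeps grid.z within the 16-bit grid-dimension bound and lets each z-block stride over the remaining batches through the gridDim.z loop in the kernel. Below is a minimal usage sketch. It assumes the wrapper sits in the cutlass::reference::device namespace of this example's header and that the template and argument order reconstructed above is correct, so treat it as illustrative rather than authoritative.

```cpp
#include "cutlass/util/host_tensor.h"
#include "reference/device/tensor_scale_bias.h"   // this example's reference header

void run_batched_scale_bias_reference() {
  using Element = float;
  using Layout  = cutlass::layout::RowMajor;

  int batch = 3;
  cutlass::gemm::GemmCoord problem(128, 64, 1);   // only M and N matter to the scale/bias kernel

  // Batches laid out back to back inside single allocations.
  cutlass::HostTensor<Element, Layout> in({problem.m() * batch, problem.n()});
  cutlass::HostTensor<Element, Layout> out({problem.m() * batch, problem.n()});
  cutlass::HostTensor<Element, Layout> scale({1, problem.n() * batch});
  cutlass::HostTensor<Element, Layout> bias({1, problem.n() * batch});

  int64_t stride_mn = int64_t(problem.m()) * problem.n();   // per-batch stride for in/out
  int64_t stride_n  = problem.n();                          // per-batch stride for scale/bias

  cutlass::reference::device::TensorScaleBiasGemmBatched<Element, Element, Layout, Element, Layout>(
      problem,
      in.device_ref(), out.device_ref(), Element(1),
      scale.device_ref(), bias.device_ref(),
      batch, stride_mn, stride_mn, stride_n, stride_n);
}
```
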
/// Apply scale and bias on a tensor
template <
typename ElementIn, ///< Input Type