diff --git a/examples/13_two_tensor_op_fusion/README.md b/examples/13_two_tensor_op_fusion/README.md
index 008ed94b..fc2655ea 100644
--- a/examples/13_two_tensor_op_fusion/README.md
+++ b/examples/13_two_tensor_op_fusion/README.md
@@ -1,11 +1,11 @@
 # Introduction
-This example shows fusing two back-to-back GEMMs/Convolutions into one kernel. 
+This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.

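The computation being fused is the one the host-side verification in this patch evaluates in two steps: a first GEMM whose scaled, bias-added, optionally ReLU'd output becomes the A operand of a second GEMM with its own bias and ReLU. A minimal sketch of that math in plain C++ (row-major float buffers and hypothetical names, not the CUTLASS types the example actually uses):

```c++
#include <algorithm>
#include <vector>

// D0 = ReLU(alpha0 * (A0 x B0) + Bias0);  D1 = ReLU(alpha1 * (D0 x B1) + Bias1).
// The second GEMM's K dimension equals the first GEMM's N dimension (N0).
void b2b_gemm_reference(int M, int N0, int K0, int N1,
                        std::vector<float> const &A0, std::vector<float> const &B0,
                        std::vector<float> const &Bias0,
                        std::vector<float> const &B1, std::vector<float> const &Bias1,
                        float alpha0, float alpha1, std::vector<float> &D1) {
  std::vector<float> D0(size_t(M) * N0);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N0; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K0; ++k) acc += A0[m * K0 + k] * B0[k * N0 + n];
      D0[m * N0 + n] = std::max(0.f, alpha0 * acc + Bias0[n]);  // epilogue of GEMM 0
    }
  D1.assign(size_t(M) * N1, 0.f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N1; ++n) {
      float acc = 0.f;
      for (int k = 0; k < N0; ++k) acc += D0[m * N0 + k] * B1[k * N1 + n];
      D1[m * N1 + n] = std::max(0.f, alpha1 * acc + Bias1[n]);  // epilogue of GEMM 1
    }
}
```

The unfused baseline writes D0 to global memory between the two loops; the fused kernel keeps it in the register file or shared memory of the producing threadblock.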
-When running two unfused GEMM/Conv operations, each operation loads one input 
-activation matrix, one weight matrix (or filter matrix) from the memory and then 
+When running two unfused GEMM/Conv operations, each operation loads one input
+activation matrix, one weight matrix (or filter matrix) from the memory and then
 stores the result activation matrix back to the memory.
 
 When the two GEMM/Conv operations are fused together, the mainloops of the two
@@ -27,10 +27,10 @@ In order to run two GEMM/Convs in a single kernel, the example requires the same
 threadblocks are used across 2 GEMMs/Convs. This also ensures the same threadblock
 tile M across 2 GEMMs/Convs.
 
-In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the 
+In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
 input activation, the example enforces the following two constraints:
 
-- thread_block_tile_N = problem_N 
+- thread_block_tile_N = problem_N

@@ -39,7 +39,7 @@ addition to its own input activation tile. Therefore the input activation tile o
 2nd GEMM/Conv only depends on the output activation tile of the 1st GEMM/Conv, and
 the operation can be fully block-resident.
 
-- warp_tile_N = thread_block_tile_N 
+- warp_tile_N = thread_block_tile_N

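The SM80 int8 configurations used later in this patch show how both constraints are met in practice. The aliases and static_asserts below are an illustrative restatement of those shapes, not code from the example itself:

```c++
#include "cutlass/gemm/gemm.h"   // cutlass::gemm::GemmShape

// Problem sizes used by the SM80 int8 tests in this patch:
// GEMM0 is 256x64x128 and GEMM1 is 256x128x64 (GEMM1's K equals GEMM0's N).
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;   // N = 64  == problem_0 N
using WarpShape0        = cutlass::gemm::GemmShape<16, 64, 64>;   // warp N == threadblock N
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;  // N = 128 == problem_1 N
using WarpShape1        = cutlass::gemm::GemmShape<16, 128, 64>;  // warp N == threadblock N

// thread_block_tile_N = problem_N: each threadblock computes a full row slab of the
// intermediate, so the 2nd GEMM never reads another threadblock's output.
static_assert(ThreadblockShape0::kN == 64 && ThreadblockShape1::kN == 128,
              "threadblock tile N must cover the entire problem N");

// warp_tile_N = thread_block_tile_N: each warp owns whole rows of the accumulator,
// so the 2nd GEMM's A operand can stay in that warp's register file.
static_assert(WarpShape0::kN == ThreadblockShape0::kN &&
              WarpShape1::kN == ThreadblockShape1::kN,
              "warp tile N must equal threadblock tile N for RF residency");
```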
@@ -82,7 +82,7 @@ threadblock. Typically this requires the 2nd Convolution uses 1x1 filter without - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_shmem` - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_rf` - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_shmem` - + # Copyright diff --git a/examples/13_two_tensor_op_fusion/b2b_gemm_run.h b/examples/13_two_tensor_op_fusion/b2b_gemm_run.h index b8b080cf..39ce488d 100644 --- a/examples/13_two_tensor_op_fusion/b2b_gemm_run.h +++ b/examples/13_two_tensor_op_fusion/b2b_gemm_run.h @@ -42,6 +42,7 @@ #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_relu.h" #include "reference/device/tensor_scale_bias.h" @@ -77,9 +78,9 @@ struct B2bNonFusedGemmRun // B2bNonFusedGemmRun( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 ): @@ -88,7 +89,7 @@ struct B2bNonFusedGemmRun /// Helper to initialize a tensor view template bool initialize_tensor( - cutlass::TensorView view, + cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) { @@ -96,7 +97,7 @@ struct B2bNonFusedGemmRun cutlass::reference::host::TensorFillRandomUniform( view, seed, 2, -2, 0); - } + } else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); @@ -129,62 +130,62 @@ struct B2bNonFusedGemmRun /// Executes one test bool run( - cutlass::gemm::GemmCoord problem_size_0, - cutlass::gemm::GemmCoord problem_size_1, - ElementCompute alpha0 = ElementCompute(1), + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), ElementCompute beta0 = ElementCompute(0), - ElementCompute alpha1 = ElementCompute(1), + ElementCompute alpha1 = ElementCompute(1), ElementCompute beta1 = ElementCompute(0), bool relu = true, int warm_ups = 1, int runs = 100) { - + // // Allocate the GEMM workspace // cutlass::HostTensor< - typename Gemm0::ElementA, + typename Gemm0::ElementA, typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); cutlass::HostTensor< - typename Gemm0::ElementB, + typename Gemm0::ElementB, typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); cutlass::HostTensor< - ElementCompute, + ElementCompute, typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); cutlass::HostTensor< - typename Gemm1::ElementB, + typename Gemm1::ElementB, typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); cutlass::HostTensor< - typename 
Gemm1::ElementC, + typename Gemm1::ElementC, typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); cutlass::HostTensor< - ElementCompute, + ElementCompute, typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); cutlass::HostTensor< - typename Gemm1::ElementC, + typename Gemm1::ElementC, typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); cutlass::HostTensor< - typename Gemm1::ElementC, + typename Gemm1::ElementC, typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); @@ -270,13 +271,13 @@ struct B2bNonFusedGemmRun for(int i = 0; i < runs; i++) { status = gemm_op_0(); - + CUTLASS_CHECK(status); } cudaEventRecord(stop1); for(int i = 0; i < runs; i++) { status = gemm_op_1(); - + CUTLASS_CHECK(status); } @@ -312,32 +313,32 @@ struct B2bNonFusedGemmRun reference_gemm_0( problem_size_0, - alpha0, - tensor_A0.device_ref(), - tensor_B0.device_ref(), - beta0, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, reference_D0.device_ref() ); if(relu) { - cutlass::reference::device::TensorReLu(reference_D0.device_view()); + cutlass::reference::device::TensorReLu(reference_D0.device_view()); } reference_gemm_1( problem_size_1, - alpha1, - reference_D0.device_ref(), - tensor_B1.device_ref(), + alpha1, + reference_D0.device_ref(), + tensor_B1.device_ref(), beta1, {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, reference_D1.device_ref() ); - + if(relu) { - cutlass::reference::device::TensorReLu(reference_D1.device_view()); + cutlass::reference::device::TensorReLu(reference_D1.device_view()); } - + // Wait for kernels to finish cudaDeviceSynchronize(); reference_D0.sync_host(); @@ -349,7 +350,7 @@ struct B2bNonFusedGemmRun CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); bool passed = cutlass::reference::host::TensorEquals( - reference_D1.host_view(), + reference_D1.host_view(), tensor_D1.host_view()); CHECK_TRUE(passed); @@ -362,7 +363,7 @@ struct B2bNonFusedGemmRun std::ofstream file(fname.str()); - file + file << "A0 =\n" << tensor_A0.host_view() << "\nB0 =\n" << tensor_B0.host_view() << "\nC0 =\n" << tensor_C0.host_view() @@ -399,9 +400,9 @@ struct B2bFusedGemmRun // B2bFusedGemmRun( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 @@ -412,7 +413,7 @@ struct B2bFusedGemmRun /// Helper to initialize a tensor view template bool initialize_tensor( - cutlass::TensorView view, + cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) { @@ -420,11 +421,11 @@ struct B2bFusedGemmRun cutlass::reference::host::TensorFillRandomUniform( view, seed, 2, -2, 0); - } + } else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); - } + } else if (dist_kind == cutlass::Distribution::Gaussian) { cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); @@ -453,70 +454,90 @@ struct B2bFusedGemmRun /// Executes one test bool run( - 
cutlass::gemm::GemmCoord problem_size_0, - cutlass::gemm::GemmCoord problem_size_1, - ElementCompute alpha0 = ElementCompute(1), + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), ElementCompute beta0 = ElementCompute(0), - ElementCompute alpha1 = ElementCompute(1), + ElementCompute alpha1 = ElementCompute(1), ElementCompute beta1 = ElementCompute(0), + cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm, + + // batch_count is used as split-k when mode is kGemm according + // to the GemmUniversal interface + + int batch_count = 1, + int64_t batch_stride_A0 = 0, + int64_t batch_stride_B0 = 0, + int64_t batch_stride_C0 = 0, + int64_t batch_stride_B1 = 0, + int64_t batch_stride_C1 = 0, + int64_t batch_stride_D1 = 0, + int64_t batch_stride_Bias0 = 0, + int64_t batch_stride_Scale0 = 0, bool relu = true, int warm_ups = 1, int runs = 100) { - + // // Allocate the GEMM workspace // - cutlass::HostTensor< - typename B2bGemm::ElementA, - typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk()); + cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k()); + cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k()); cutlass::HostTensor< - typename B2bGemm::ElementB, - typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn()); + typename B2bGemm::ElementA, + typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn()); cutlass::HostTensor< - typename B2bGemm::ElementScaleBias, + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, typename B2bGemm::LayoutScaleBias> tensor_Scale0; if(alpha0 == ElementCompute(0)) //per-channel scale - tensor_Scale0.resize({1, problem_size_0.n()}); + tensor_Scale0.resize({1, batch_count * problem_size_0.n()}); cutlass::HostTensor< - typename B2bGemm::ElementScaleBias, - typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()}); + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()}); cutlass::HostTensor< - ElementAccumulator, - typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn()); + ElementAccumulator, + typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D0(CoordC0.mn()); cutlass::HostTensor< - typename B2bGemm::ElementB, - typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn()); + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn()); cutlass::HostTensor< - ElementCompute, - 
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()}); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()}); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D1(CoordC1.mn()); CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); @@ -554,6 +575,7 @@ struct B2bFusedGemmRun // typename B2bGemm::Arguments arguments{ + mode, problem_size_0, problem_size_1, tensor_A0.device_ref(), @@ -564,8 +586,16 @@ struct B2bFusedGemmRun tensor_B1.device_ref(), {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, tensor_D1.device_ref(), + batch_stride_A0, + batch_stride_B0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1, + batch_stride_Bias0, + batch_stride_Scale0, {alpha0, beta0}, {alpha1, beta1}, + batch_count, }; B2bGemm b2b_gemm_op; @@ -618,32 +648,31 @@ struct B2bFusedGemmRun // Verify // - cutlass::reference::device::Gemm< - typename B2bGemm::ElementA, typename B2bGemm::LayoutA, - typename B2bGemm::ElementB, typename B2bGemm::LayoutB, - ElementAccumulator, typename B2bGemm::LayoutC, - ElementAccumulator, ElementAccumulator> - reference_gemm_0; + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + ElementAccumulator, typename B2bGemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( - cutlass::reference::device::Gemm< - typename B2bGemm::ElementA, typename B2bGemm::LayoutA, - typename B2bGemm::ElementB, typename B2bGemm::LayoutB, - typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, - ElementAccumulator, typename B2bGemm::Operator> - reference_gemm_1; - - reference_gemm_0( problem_size_0, ElementAccumulator(1), //intermediate alpha=1 - tensor_A0.device_ref(), - tensor_B0.device_ref(), + tensor_A0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B0.device_ref(), + cutlass::ComplexTransform::kNone, ElementAccumulator(0), //beta = 0 reference_Z0.device_ref(), reference_Z0.device_ref(), - ElementAccumulator(0) + ElementAccumulator(0), + int(batch_count), + batch_stride_A0, + batch_stride_B0, + batch_stride_C0, + batch_stride_C0 ); - cutlass::reference::device::TensorScaleBiasGemm< + cutlass::reference::device::TensorScaleBiasGemmBatched< ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, typename B2bGemm::LayoutScaleBias > ( @@ -652,25 +681,45 @@ struct B2bFusedGemmRun reference_D0.device_ref(), alpha0, tensor_Scale0.device_ref(), - tensor_Bias0.device_ref() + tensor_Bias0.device_ref(), + int(batch_count), + batch_stride_C0, + batch_stride_C0, + batch_stride_Scale0, + batch_stride_Bias0 ); if(relu) { - cutlass::reference::device::TensorReLu(reference_D0.device_view()); + cutlass::reference::device::TensorReLu(reference_D0.device_view()); } - reference_gemm_1( + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, ElementAccumulator + >( problem_size_1, - alpha1, - reference_D0.device_ref(), - 
tensor_B1.device_ref(), - beta1, + alpha1, //intermediate alpha=1 + reference_D0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B1.device_ref(), + cutlass::ComplexTransform::kNone, + beta1, //beta = 0 {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, - reference_D1.device_ref() + reference_D1.device_ref(), + ElementAccumulator(0), + int(batch_count), + batch_stride_C0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1 ); + if(relu) { - cutlass::reference::device::TensorReLu(reference_D1.device_view()); + cutlass::reference::device::TensorReLu(reference_D1.device_view()); } + cudaDeviceSynchronize(); reference_D0.sync_host(); reference_D1.sync_host(); @@ -680,7 +729,7 @@ struct B2bFusedGemmRun CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); bool passed = cutlass::reference::host::TensorEquals( - reference_D1.host_view(), + reference_D1.host_view(), tensor_D1.host_view()); CHECK_TRUE(passed); @@ -694,7 +743,7 @@ struct B2bFusedGemmRun std::ofstream file(fname.str()); - file + file << "A0 =\n" << tensor_A0.host_view() << "\nB0 =\n" << tensor_B0.host_view() << "\nC0 =\n" << tensor_C0.host_view() diff --git a/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h b/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h index 51ff1bb7..c649fc36 100644 --- a/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h +++ b/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h @@ -43,6 +43,7 @@ #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/host_reorder.h" #include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_relu.h" #include "reference/device/tensor_scale_bias.h" @@ -76,9 +77,9 @@ struct B2bInterleavedNonFusedGemmRun // B2bInterleavedNonFusedGemmRun( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 ): @@ -87,7 +88,7 @@ struct B2bInterleavedNonFusedGemmRun /// Helper to initialize a tensor view template bool initialize_tensor( - cutlass::TensorView view, + cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) { @@ -95,7 +96,7 @@ struct B2bInterleavedNonFusedGemmRun cutlass::reference::host::TensorFillRandomUniform( view, seed, 2, -2, 0); - } + } else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); @@ -128,73 +129,72 @@ struct B2bInterleavedNonFusedGemmRun /// Executes one test bool run( - cutlass::gemm::GemmCoord problem_size_0, - cutlass::gemm::GemmCoord problem_size_1, - ElementCompute alpha0 = ElementCompute(1), + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), ElementCompute beta0 = ElementCompute(0), - ElementCompute alpha1 = ElementCompute(1), + ElementCompute alpha1 = ElementCompute(1), ElementCompute beta1 = ElementCompute(0), bool relu = true, int warm_ups = 1, int runs = 100) { - + // // Allocate the GEMM workspace // cutlass::HostTensor< - typename 
Gemm0::ElementA, + typename Gemm0::ElementA, typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); cutlass::HostTensor< - typename Gemm0::ElementB, + typename Gemm0::ElementB, typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); cutlass::HostTensor< - typename Gemm0::ElementB, + typename Gemm0::ElementB, typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn()); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); cutlass::HostTensor< - typename Gemm1::ElementB, + typename Gemm1::ElementB, typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); cutlass::HostTensor< - typename Gemm1::ElementB, + typename Gemm1::ElementB, typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn()); cutlass::HostTensor< - typename Gemm1::ElementC, + typename Gemm1::ElementC, typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); cutlass::HostTensor< - typename Gemm1::ElementC, + typename Gemm1::ElementC, typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); cutlass::HostTensor< - typename Gemm1::ElementC, + typename Gemm1::ElementC, typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); - CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); @@ -285,13 +285,13 @@ struct B2bInterleavedNonFusedGemmRun for(int i = 0; i < runs; i++) { status = gemm_op_0(); - + CUTLASS_CHECK(status); } cudaEventRecord(stop1); for(int i = 0; i < runs; i++) { status = gemm_op_1(); - + CUTLASS_CHECK(status); } @@ -327,36 +327,36 @@ struct B2bInterleavedNonFusedGemmRun reference_gemm_0( problem_size_0, - alpha0, - tensor_A0.device_ref(), - tensor_B0.device_ref(), - beta0, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, reference_D0.device_ref() ); if(relu) { - cutlass::reference::device::TensorReLu(reference_D0.device_view()); + cutlass::reference::device::TensorReLu(reference_D0.device_view()); } reference_gemm_1( problem_size_1, - alpha1, - reference_D0.device_ref(), - tensor_B1.device_ref(), + alpha1, + reference_D0.device_ref(), + tensor_B1.device_ref(), beta1, {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, reference_D1.device_ref() ); - + if(relu) { - cutlass::reference::device::TensorReLu(reference_D1.device_view()); + cutlass::reference::device::TensorReLu(reference_D1.device_view()); } // Wait for kernels to finish cudaDeviceSynchronize(); - reference_D0.sync_host(); - reference_D1.sync_host(); + reference_D0.sync_host(); + reference_D1.sync_host(); CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); @@ -364,7 +364,7 @@ struct B2bInterleavedNonFusedGemmRun 
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); bool passed = cutlass::reference::host::TensorEquals( - reference_D1.host_view(), + reference_D1.host_view(), tensor_D1.host_view()); CHECK_TRUE(passed); @@ -377,7 +377,7 @@ struct B2bInterleavedNonFusedGemmRun std::ofstream file(fname.str()); - file + file << "A0 =\n" << tensor_A0.host_view() << "\nB0 =\n" << tensor_B0.host_view() << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() @@ -416,9 +416,9 @@ struct B2bInterleavedFusedGemmRun // B2bInterleavedFusedGemmRun( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 @@ -429,7 +429,7 @@ struct B2bInterleavedFusedGemmRun /// Helper to initialize a tensor view template bool initialize_tensor( - cutlass::TensorView view, + cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) { @@ -437,11 +437,11 @@ struct B2bInterleavedFusedGemmRun cutlass::reference::host::TensorFillRandomUniform( view, seed, 2, -2, 0); - } + } else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); - } + } else if (dist_kind == cutlass::Distribution::Gaussian) { cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); @@ -470,78 +470,99 @@ struct B2bInterleavedFusedGemmRun /// Executes one test bool run( - cutlass::gemm::GemmCoord problem_size_0, - cutlass::gemm::GemmCoord problem_size_1, - ElementCompute alpha0 = ElementCompute(1), + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), ElementCompute beta0 = ElementCompute(0), - ElementCompute alpha1 = ElementCompute(1), + ElementCompute alpha1 = ElementCompute(1), ElementCompute beta1 = ElementCompute(0), + cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm, + + // batch_count is used as split-k when mode is kGemm according + // to the GemmUniversal interface + + int batch_count = 1, + + int64_t batch_stride_A0 = 0, + int64_t batch_stride_B0 = 0, + int64_t batch_stride_C0 = 0, + int64_t batch_stride_B1 = 0, + int64_t batch_stride_C1 = 0, + int64_t batch_stride_D1 = 0, + int64_t batch_stride_Bias0 = 0, + int64_t batch_stride_Scale0 = 0, bool relu = true, int warm_ups = 1, int runs = 100) { - + // // Allocate the GEMM workspace // - cutlass::HostTensor< - typename B2bGemm::ElementA, - typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk()); + cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k()); + cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k()); cutlass::HostTensor< - 
typename B2bGemm::ElementB, - typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn()); + typename B2bGemm::ElementA, + typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk()); cutlass::HostTensor< - typename B2bGemm::ElementB, - typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn()); + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn()); cutlass::HostTensor< - typename B2bGemm::ElementScaleBias, + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, typename B2bGemm::LayoutScaleBias> tensor_Scale0; if(alpha0 == ElementCompute(0)) //per-channel scale - tensor_Scale0.resize({1, problem_size_0.n()}); + tensor_Scale0.resize({1, batch_count * problem_size_0.n()}); cutlass::HostTensor< - typename B2bGemm::ElementScaleBias, - typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()}); + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()}); cutlass::HostTensor< - ElementAccumulator, - typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn()); + ElementAccumulator, + typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D0(CoordC0.mn()); cutlass::HostTensor< - typename B2bGemm::ElementB, - typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn()); + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn()); cutlass::HostTensor< - typename B2bGemm::ElementB, - typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn()); + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()}); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()}); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D1(CoordC1.mn()); CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); @@ -556,9 +577,9 @@ struct B2bInterleavedFusedGemmRun //Reorder B0 cutlass::reorder_column<16>( - tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0); + tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0); cutlass::reorder_column( - tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1); + tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1); cutlass::reference::host::TensorFill( tensor_D1.host_view()); @@ -581,12 +602,14 @@ struct B2bInterleavedFusedGemmRun tensor_D1.sync_device(); reference_D0.sync_device(); 
reference_D1.sync_device(); + // tensor_Bias0_batched.sync_device(); // // Initialize the GEMM operator // typename B2bGemm::Arguments arguments{ + mode, problem_size_0, problem_size_1, tensor_A0.device_ref(), @@ -597,8 +620,16 @@ struct B2bInterleavedFusedGemmRun tensor_B1_reordered.device_ref(), {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, tensor_D1.device_ref(), + batch_stride_A0, + batch_stride_B0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1, + batch_stride_Bias0, + batch_stride_Scale0, {alpha0, beta0}, {alpha1, beta1}, + batch_count, }; B2bGemm b2b_gemm_op; @@ -651,32 +682,30 @@ struct B2bInterleavedFusedGemmRun // Verify // - cutlass::reference::device::Gemm< - typename B2bGemm::ElementA, typename B2bGemm::LayoutA, - typename B2bGemm::ElementB, typename B2bGemm::LayoutB, - ElementAccumulator, typename B2bGemm::LayoutC, - ElementAccumulator, ElementAccumulator> - reference_gemm_0; - - cutlass::reference::device::Gemm< - typename B2bGemm::ElementA, typename B2bGemm::LayoutA, - typename B2bGemm::ElementB, typename B2bGemm::LayoutB, - typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, - ElementAccumulator, typename B2bGemm::Operator> - reference_gemm_1; - - reference_gemm_0( + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + ElementAccumulator, typename B2bGemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( problem_size_0, ElementAccumulator(1), //intermediate alpha=1 - tensor_A0.device_ref(), - tensor_B0.device_ref(), + tensor_A0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B0.device_ref(), + cutlass::ComplexTransform::kNone, ElementAccumulator(0), //beta = 0 reference_Z0.device_ref(), reference_Z0.device_ref(), - ElementAccumulator(0) + ElementAccumulator(0), + int(batch_count), + batch_stride_A0, + batch_stride_B0, + batch_stride_C0, + batch_stride_C0 ); - cutlass::reference::device::TensorScaleBiasGemm< + cutlass::reference::device::TensorScaleBiasGemmBatched< ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, typename B2bGemm::LayoutScaleBias > ( @@ -685,25 +714,45 @@ struct B2bInterleavedFusedGemmRun reference_D0.device_ref(), alpha0, tensor_Scale0.device_ref(), - tensor_Bias0.device_ref() + tensor_Bias0.device_ref(), + int(batch_count), + batch_stride_C0, + batch_stride_C0, + batch_stride_Scale0, + batch_stride_Bias0 ); if(relu) { - cutlass::reference::device::TensorReLu(reference_D0.device_view()); + cutlass::reference::device::TensorReLu(reference_D0.device_view()); } - reference_gemm_1( + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, ElementAccumulator + >( problem_size_1, - alpha1, - reference_D0.device_ref(), - tensor_B1.device_ref(), - beta1, + alpha1, //intermediate alpha=1 + reference_D0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B1.device_ref(), + cutlass::ComplexTransform::kNone, + beta1, //beta = 0 {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, - reference_D1.device_ref() + reference_D1.device_ref(), + ElementAccumulator(0), + int(batch_count), + batch_stride_C0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1 ); + if(relu) { - cutlass::reference::device::TensorReLu(reference_D1.device_view()); + 
cutlass::reference::device::TensorReLu(reference_D1.device_view()); } + cudaDeviceSynchronize(); reference_D0.sync_host(); reference_D1.sync_host(); @@ -713,7 +762,7 @@ struct B2bInterleavedFusedGemmRun CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); bool passed = cutlass::reference::host::TensorEquals( - reference_D1.host_view(), + reference_D1.host_view(), tensor_D1.host_view()); CHECK_TRUE(passed); @@ -727,7 +776,7 @@ struct B2bInterleavedFusedGemmRun std::ofstream file(fname.str()); - file + file << "A0 =\n" << tensor_A0.host_view() << "\nB0 =\n" << tensor_B0.host_view() << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() diff --git a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h index f365b236..0fbc930b 100644 --- a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h @@ -119,8 +119,6 @@ template < int AlignmentB = DefaultGemmConfiguration::kAlignmentB, - /// If true, kernel supports split-K with serial reduction - bool SplitKSerial = false, /// Operation performed by GEMM typename Operator_ = typename DefaultGemmConfiguration< OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, @@ -154,7 +152,6 @@ class B2bGemm { static int const kAlignmentA = AlignmentA; static int const kAlignmentB = AlignmentB; static int const kAlignmentC = EpilogueOutputOp1::kCount; - static bool const kSplitKSerial = SplitKSerial; static ComplexTransform const kTransformA = ComplexTransform::kNone; static ComplexTransform const kTransformB = ComplexTransform::kNone; @@ -184,7 +181,6 @@ class B2bGemm { EpilogueOutputOp1, ThreadblockSwizzle, kStages, - kSplitKSerial, Operator, SmemAccumulator >::B2bGemmKernel; @@ -196,6 +192,7 @@ class B2bGemm { // Data members // + GemmUniversalMode mode; GemmCoord problem_size_0; GemmCoord problem_size_1; TensorRef ref_A0; @@ -206,9 +203,16 @@ class B2bGemm { TensorRef ref_B1; TensorRef ref_C1; TensorRef ref_D1; + int64_t batch_stride_A0; + int64_t batch_stride_B0; + int64_t batch_stride_B1; + int64_t batch_stride_C1; + int64_t batch_stride_D1; + int64_t batch_stride_Bias0; + int64_t batch_stride_Scale0; typename EpilogueOutputOp0::Params epilogue0; typename EpilogueOutputOp1::Params epilogue1; - int split_k_slices; + int batch_count; // // Methods @@ -216,13 +220,14 @@ class B2bGemm { /// Default ctor CUTLASS_HOST_DEVICE - Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) { + Arguments(): mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) { } - /// Constructs an Arguments structure + /// Constructs an Arguments structure CUTLASS_HOST_DEVICE Arguments( + GemmUniversalMode mode_, GemmCoord problem_size_0_, GemmCoord problem_size_1_, TensorRef ref_A0_, @@ -233,12 +238,20 @@ class B2bGemm { TensorRef ref_B1_, TensorRef ref_C1_, TensorRef ref_D1_, - typename EpilogueOutputOp0::Params epilogue0_ = + int64_t batch_stride_A0_, + int64_t batch_stride_B0_, + int64_t batch_stride_B1_, + int64_t batch_stride_C1_, + int64_t batch_stride_D1_, + int64_t batch_stride_Bias0_, + int64_t batch_stride_Scale0_, + typename EpilogueOutputOp0::Params epilogue0_ = typename EpilogueOutputOp0::Params(), - typename EpilogueOutputOp1::Params epilogue1_ = + typename EpilogueOutputOp1::Params epilogue1_ = typename EpilogueOutputOp1::Params(), - int split_k_slices_ = 1 + int batch_count_ = 1 ): + mode(mode_), problem_size_0(problem_size_0_), problem_size_1(problem_size_1_), ref_A0(ref_A0_), @@ 
-249,9 +262,16 @@ class B2bGemm { ref_B1(ref_B1_), ref_C1(ref_C1_), ref_D1(ref_D1_), + batch_stride_A0(batch_stride_A0_), + batch_stride_B0(batch_stride_B0_), + batch_stride_B1(batch_stride_B1_), + batch_stride_C1(batch_stride_C1_), + batch_stride_D1(batch_stride_D1_), + batch_stride_Bias0(batch_stride_Bias0_), + batch_stride_Scale0(batch_stride_Scale0_), epilogue0(epilogue0_), epilogue1(epilogue1_), - split_k_slices(split_k_slices_) { + batch_count(batch_count_) { } }; @@ -269,10 +289,6 @@ public: /// Determines whether the GEMM can execute the given problem. static Status can_implement(Arguments const &args) { - if (!kSplitKSerial && args.split_k_slices > 1) { - return Status::kErrorInvalidProblem; - } - Status status = B2bGemmKernel::can_implement( args.problem_size_0, args.problem_size_1, @@ -295,20 +311,14 @@ public: static size_t get_workspace_size(Arguments const &args) { size_t bytes = 0; - + // Determine grid shape ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size_0, + args.problem_size_0, {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, - args.split_k_slices); - - if (kSplitKSerial && args.split_k_slices > 1) { - - - bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); - } + args.batch_count); return bytes; } @@ -320,38 +330,17 @@ public: ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size_0, + args.problem_size_0, {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, - args.split_k_slices); + args.batch_count); // cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape( -// args.problem_size_1, +// args.problem_size_1, // {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK}, -// args.split_k_slices); - - if (kSplitKSerial) { - if (args.split_k_slices > 1) { - if (!workspace) { - return Status::kErrorWorkspaceNull; - } - - size_t bytes = get_workspace_size(args); - - cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } - } - } - else { - - if (args.split_k_slices > 1) { - return Status::kErrorInvalidProblem; - } - } +// args.batch_count); // Initialize the Params structure params_ = typename B2bGemmKernel::Params{ + args.mode, args.problem_size_0, args.problem_size_1, grid_shape, @@ -363,6 +352,13 @@ public: args.ref_B1.non_const_ref(), args.ref_C1.non_const_ref(), args.ref_D1, + args.batch_stride_A0, + args.batch_stride_B0, + args.batch_stride_B1, + args.batch_stride_C1, + args.batch_stride_D1, + args.batch_stride_Bias0, + args.batch_stride_Scale0, args.epilogue0, args.epilogue1, static_cast(workspace), @@ -373,12 +369,6 @@ public: /// Lightweight update given a subset of arguments Status update(Arguments const &args, void *workspace = nullptr) { - - if (kSplitKSerial && args.split_k_slices > 1) { - if (!workspace) { - return Status::kErrorWorkspaceNull; - } - } params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); @@ -430,12 +420,12 @@ public: /// Runs the kernel using initialized state. 
Status operator()( - Arguments const &args, - void *workspace = nullptr, + Arguments const &args, + void *workspace = nullptr, cudaStream_t stream = nullptr) { - + Status status = initialize(args, workspace, stream); - + if (status == Status::kSuccess) { status = run(stream); } diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu index 60f9adb1..bf025b92 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu @@ -152,7 +152,7 @@ bool run_fused_gemm_s8_sm80_rf_res() { using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; - using EpilogueOutputOp0 = + using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombinationRelu< ElementOutput, 8 * InstructionShape::kN / 32, @@ -161,7 +161,7 @@ bool run_fused_gemm_s8_sm80_rf_res() { cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; - using EpilogueOutputOp1 = + using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombinationRelu< ElementOutput, 64 / cutlass::sizeof_bits::value, @@ -194,14 +194,21 @@ bool run_fused_gemm_s8_sm80_rf_res() { SmemAccumulator, 16, 16, - false, cutlass::arch::OpMultiplyAddSaturate >; B2bInterleavedFusedGemmRun fusedGemm; std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n"; - bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); + bool passed = fusedGemm.run( + gemm_s8_sm80_problem_size_0, + gemm_s8_sm80_problem_size_1, + alpha0, + beta0, + alpha1, + beta1 + ); + if(passed) std::cout << "Pass\n"; else @@ -210,18 +217,123 @@ bool run_fused_gemm_s8_sm80_rf_res() { return passed; } +bool run_fused_gemm_s8_sm80_rf_res_batch() { + + + cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(256, 64, 128); + cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(256, 128, 64); + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 + ElementCompute beta0 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias + + using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; + + using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + using EpilogueOutputOp0 = + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 8 * InstructionShape::kN / 32, + ElementAccumulator, + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + >; + + using EpilogueOutputOp1 = + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling + >; + + const bool SmemAccumulator = false; + + using B2bGemm = cutlass::gemm::device::B2bGemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape0, + ThreadblockShape1, + 
WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + SmemAccumulator, + 16, + 16, + cutlass::arch::OpMultiplyAddSaturate + >; + + B2bInterleavedFusedGemmRun fusedGemm; + + int batch_count = 2; + int64_t batch_stride_A0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.k(); + int64_t batch_stride_B0 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n(); + int64_t batch_stride_C0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.n(); + int64_t batch_stride_B1 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n(); + int64_t batch_stride_C1 = gemm_s8_sm80_problem_size_1.n(); + int64_t batch_stride_D1 = gemm_s8_sm80_problem_size_1.m() * gemm_s8_sm80_problem_size_1.n(); + int64_t batch_stride_Bias0 = gemm_s8_sm80_problem_size_0.n(); + int64_t batch_stride_Scale0 = 0; + + std::cout << "Running Fused back-to-back INT8 NT interleaved Batched GEMMs with RF residency...\n"; + bool passed = fusedGemm.run( + gemm_s8_sm80_problem_size_0, + gemm_s8_sm80_problem_size_1, + alpha0, + beta0, + alpha1, + beta1, + cutlass::gemm::GemmUniversalMode::kBatched, + batch_count, + batch_stride_A0, + batch_stride_B0, + batch_stride_C0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1, + batch_stride_Bias0, + batch_stride_Scale0 + ); + + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return passed; +} int main() { std::vectorfuncs = { &run_nonfused_gemm_s8_sm80, - &run_fused_gemm_s8_sm80_rf_res + &run_fused_gemm_s8_sm80_rf_res, + &run_fused_gemm_s8_sm80_rf_res_batch }; return testRun(80, funcs, "gemm int8 RF residency"); - - } - //////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu index 64788e05..8a284fd0 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu @@ -151,7 +151,7 @@ bool run_fused_gemm_s8_sm80_shmem() { using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; - using EpilogueOutputOp0 = + using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombinationRelu< ElementOutput, 8 * InstructionShape::kN / 32, @@ -160,7 +160,7 @@ bool run_fused_gemm_s8_sm80_shmem() { cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; - using EpilogueOutputOp1 = + using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombinationRelu< ElementOutput, 64 / cutlass::sizeof_bits::value, @@ -168,7 +168,7 @@ bool run_fused_gemm_s8_sm80_shmem() { ElementCompute, cutlass::epilogue::thread::ScaleType::NoBetaScaling >; - + const bool SmemAccumulator = true; using B2bGemm = cutlass::gemm::device::B2bGemm< @@ -193,7 +193,6 @@ bool run_fused_gemm_s8_sm80_shmem() { SmemAccumulator, 16, 16, - false, cutlass::arch::OpMultiplyAddSaturate >; diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h index 1ccf902e..f4794c57 100644 --- a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h @@ -49,10 +49,9 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// template < - typename B2bMma_, ///! 
Threadblock-scoped matrix multiply-accumulate + typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. + typename ThreadblockSwizzle_ ///! Threadblock swizzling function > struct B2bGemm { @@ -61,7 +60,17 @@ struct B2bGemm { using OutputOp0 = typename B2bMma::OutputOp; using OutputOp1 = typename Epilogue::OutputOp; using ThreadblockSwizzle = ThreadblockSwizzle_; - static bool const kSplitKSerial = SplitKSerial; + + using ElementA0 = typename B2bMma::IteratorA0::Element; + using LayoutA0 = typename B2bMma::IteratorA0::Layout; + using ElementB0 = typename B2bMma::IteratorB0::Element; + using LayoutB0 = typename B2bMma::IteratorB0::Layout; + using ElementB1 = typename B2bMma::IteratorB1::Element; + using LayoutB1 = typename B2bMma::IteratorB1::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element; /// Warp count (concept: GemmShape) using WarpCount0 = typename B2bMma::WarpCount0; @@ -69,6 +78,7 @@ struct B2bGemm { /// Parameters structure struct Params { + cutlass::gemm::GemmUniversalMode mode; cutlass::gemm::GemmCoord problem_size_0; cutlass::gemm::GemmCoord problem_size_1; cutlass::gemm::GemmCoord grid_tiled_shape; @@ -89,6 +99,13 @@ struct B2bGemm { typename Epilogue::OutputTileIterator::TensorRef ref_D1; typename OutputOp0::Params output_op_0; typename OutputOp1::Params output_op_1; + int64_t batch_stride_A0; + int64_t batch_stride_B0; + int64_t batch_stride_B1; + int64_t batch_stride_C1; + int64_t batch_stride_D1; + int64_t batch_stride_Bias0; + int64_t batch_stride_Scale0; int *semaphore; int gemm_k_iterations_0; int gemm_k_size_0; @@ -100,11 +117,12 @@ struct B2bGemm { // CUTLASS_HOST_DEVICE - Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0), + Params(): mode(mode), swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0), gemm_k_iterations_1(0), gemm_k_size_1(0) { } CUTLASS_HOST_DEVICE Params( + cutlass::gemm::GemmUniversalMode mode, cutlass::gemm::GemmCoord const & problem_size_0, cutlass::gemm::GemmCoord const & problem_size_1, cutlass::gemm::GemmCoord const & grid_tiled_shape, @@ -116,10 +134,18 @@ struct B2bGemm { typename B2bMma::IteratorB1::TensorRef ref_B1, typename Epilogue::OutputTileIterator::TensorRef ref_C1, typename Epilogue::OutputTileIterator::TensorRef ref_D1, + int64_t batch_stride_A0, + int64_t batch_stride_B0, + int64_t batch_stride_B1, + int64_t batch_stride_C1, + int64_t batch_stride_D1, + int64_t batch_stride_Bias0, + int64_t batch_stride_Scale0, typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), int *workspace = nullptr ): + mode(mode), problem_size_0(problem_size_0), problem_size_1(problem_size_1), grid_tiled_shape(grid_tiled_shape), @@ -138,6 +164,13 @@ struct B2bGemm { ref_C1(ref_C1), params_D1(ref_D1.layout()), ref_D1(ref_D1), + batch_stride_A0(batch_stride_A0), + batch_stride_B0(batch_stride_B0), + batch_stride_B1(batch_stride_B1), + batch_stride_C1(batch_stride_C1), + batch_stride_D1(batch_stride_D1), + batch_stride_Bias0(batch_stride_Bias0), + batch_stride_Scale0(batch_stride_Scale0), output_op_0(output_op_0), output_op_1(output_op_1) { @@ 
-163,7 +196,7 @@ struct B2bGemm { // CUTLASS_HOST_DEVICE - B2bGemm() { } + B2bGemm() { } /// Determines whether kernel satisfies alignment static Status can_implement( @@ -223,7 +256,7 @@ struct B2bGemm { if(problem_size_0.n() > B2bMma::Shape0::kN) return Status::kErrorInvalidProblem; - + if(problem_size_1.n() > B2bMma::Shape1::kN) return Status::kErrorInvalidProblem; @@ -247,37 +280,64 @@ struct B2bGemm { return; } + ElementA0 *ptr_A0 = static_cast(params.ref_A0.data()); + ElementB0 *ptr_B0 = static_cast(params.ref_B0.data()); + ElementB1 *ptr_B1 = static_cast(params.ref_B1.data()); + + ScaleBiasData *ptr_Bias0 = static_cast(params.ref_Bias0.data()); + ScaleBiasData *ptr_Scale0 = static_cast(params.ref_Scale0.data()); + + int offset_k_0 = 0; + int offset_k_1 = 0; + + int problem_size_k_0 = params.problem_size_0.k(); + int problem_size_k_1 = params.problem_size_1.k(); + + if (params.mode == GemmUniversalMode::kGemm) { + + // Problem size is a function of threadblock index in the K dimension + problem_size_k_0 = min( + problem_size_k_0, + (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0); + + // Problem size is a function of threadblock index in the K dimension + problem_size_k_1 = min( + problem_size_k_1, + (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1); + + offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0; + offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1; + } + + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0; + ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0; + ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1; + ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0; + ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0; + } + // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A0{ threadblock_tile_offset.m() * B2bMma::Shape0::kM, - threadblock_tile_offset.k() * params.gemm_k_size_0, + offset_k_0, }; cutlass::MatrixCoord tb_offset_B0{ - threadblock_tile_offset.k() * params.gemm_k_size_0, + offset_k_0, threadblock_tile_offset.n() * B2bMma::Shape0::kN }; cutlass::MatrixCoord tb_offset_B1{ - threadblock_tile_offset.k() * params.gemm_k_size_1, + offset_k_1, threadblock_tile_offset.n() * B2bMma::Shape1::kN }; - // Problem size is a function of threadblock index in the K dimension - int problem_size_k_0 = min( - params.problem_size_0.k(), - (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0); - // Compute threadblock-scoped matrix multiply-add int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; - // Problem size is a function of threadblock index in the K dimension - int problem_size_k_1 = min( - params.problem_size_1.k(), - (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1); - // Compute threadblock-scoped matrix multiply-add -// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; + // int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; // Compute position within threadblock @@ -286,26 +346,25 @@ struct B2bGemm { // Construct iterators to A and B operands typename B2bMma::IteratorA0 iterator_A0( params.params_A0, - params.ref_A0.data(), + ptr_A0, {params.problem_size_0.m(), problem_size_k_0}, thread_idx, tb_offset_A0); typename B2bMma::IteratorB0 iterator_B0( params.params_B0, - 
params.ref_B0.data(), + ptr_B0, {problem_size_k_0, params.problem_size_0.n()}, thread_idx, tb_offset_B0); typename B2bMma::IteratorB1 iterator_B1( params.params_B1, - params.ref_B1.data(), + ptr_B1, {problem_size_k_1, params.problem_size_1.n()}, thread_idx, tb_offset_B1); - // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); @@ -313,7 +372,7 @@ struct B2bGemm { // Construct iterators to accumulator scale/bias vector typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0( - params.ref_Scale0.data(), + ptr_Scale0, {1, params.problem_size_0.n()}, thread_idx, warp_idx, @@ -323,7 +382,7 @@ struct B2bGemm { ); typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0( - params.ref_Bias0.data(), + ptr_Bias0, {1, params.problem_size_0.n()}, thread_idx, warp_idx, @@ -349,11 +408,9 @@ struct B2bGemm { src_accum.clear(); accumulators.clear(); - if (!kSplitKSerial || gemm_k_iterations_0 > 0) { - // Compute threadblock-scoped matrix multiply-add - b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, - iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); - } + // Compute threadblock-scoped matrix multiply-add + b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, + iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); // // Epilogue @@ -376,23 +433,32 @@ struct B2bGemm { int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + ElementC *ptr_C1 = static_cast(params.ref_C1.data()); + ElementC *ptr_D1 = static_cast(params.ref_D1.data()); + // Construct the semaphore. Semaphore semaphore(params.semaphore + block_idx, thread_idx); - // If performing a reduction via split-K, fetch the initial synchronization - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); + if (params.mode == GemmUniversalMode::kGemm) { + // If performing a reduction via split-K, fetch the initial synchronization - // Indicate which position in a serial reduction the output operator is currently updating - output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + if (params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1; + ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1; } // Tile iterator loading from source tensor. typename Epilogue::OutputTileIterator iterator_C1( params.params_C1, - params.ref_C1.data(), + ptr_C1, params.problem_size_1.mn(), thread_idx, threadblock_offset @@ -401,21 +467,21 @@ struct B2bGemm { // Tile iterator writing to destination tensor. 
typename Epilogue::OutputTileIterator iterator_D1( params.params_D1, - params.ref_D1.data(), + ptr_D1, params.problem_size_1.mn(), thread_idx, threadblock_offset ); Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, + shared_storage.epilogue, + thread_idx, + warp_idx, lane_idx); // Wait on the semaphore - this latency may have been covered by iterator construction - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. if (threadblock_tile_offset.k()) { iterator_C1 = iterator_D1; @@ -427,14 +493,14 @@ struct B2bGemm { } // Execute the epilogue operator to update the destination tensor. - epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); - + epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); + // // Release the semaphore // - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + int lock = 0; if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { @@ -457,4 +523,3 @@ struct B2bGemm { } // namespace kernel } // namespace gemm } // namespace cutlass - diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h index 05c3f4e2..3f54e1da 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h @@ -30,10 +30,10 @@ **************************************************************************************************/ /*! \file - \brief + \brief Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropriate threadblock-scoped epilogue. - + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are accommodated by exchanging A and B operands and assuming transposed layouts. Partial specializations here choose 'device::GemmTransposed' to implement this functionality. @@ -114,8 +114,6 @@ template < typename ThreadblockSwizzle, /// Number of stages used in the pipelined mainloop int Stages, - /// If true, kernel is configured to support serial reduction in the epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator, /// Stage accumulator in shared memory @@ -161,22 +159,19 @@ template < typename ThreadblockSwizzle, /// Number of stages used in the pipelined mainloop int Stages, - /// If true, kernel is configured to support serial reduction in the - /// epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator> struct DefaultB2bGemm { /// Define the threadblock-scoped matrix multiply-accumulate using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, - ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma; static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; @@ -188,7 +183,7 @@ struct DefaultB2bGemm::Epilogue; /// Define the kernel-level GEMM operator. 
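(Aside on the semaphore logic that this patch now gates on `params.mode == GemmUniversalMode::kGemm`: the toy program below models only the synchronization shape of the serial split-K reduction. Each K partition waits for the lock to equal its own index, folds its partial result into the running output, and hands the lock to the next partition, with the last partition resetting it to 0. It uses ordinary `std::atomic` and `std::thread`, not the CUTLASS `Semaphore`.)

```cpp
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  constexpr int kPartitions = 4;
  std::atomic<int> lock{0};   // plays the role of the Semaphore
  float output = 0.f;         // plays the role of the D tensor tile

  std::vector<std::thread> blocks;
  for (int k = 0; k < kPartitions; ++k) {
    blocks.emplace_back([&, k] {
      float partial = 1.f * (k + 1);                          // this partition's accumulator
      while (lock.load(std::memory_order_acquire) != k) { }   // Semaphore wait(k)
      // For k > 0 the running sum is read back from the output ("iterator_C1 = iterator_D1").
      output = (k == 0 ? 0.f : output) + partial;
      int next = (k + 1 == kPartitions) ? 0 : k + 1;          // Semaphore release
      lock.store(next, std::memory_order_release);
    });
  }
  for (auto& t : blocks) t.join();
  std::printf("reduced output = %g (expected 10)\n", output);  // 1 + 2 + 3 + 4
  return 0;
}
```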
- using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; @@ -228,8 +223,6 @@ template < typename EpilogueOutputOp1, /// Threadblock-level swizzling operator typename ThreadblockSwizzle, - /// If true, kernel is configured to support serial reduction in the epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator > @@ -249,7 +242,6 @@ struct DefaultB2bGemm< EpilogueOutputOp1, ThreadblockSwizzle, 2, - SplitKSerial, Operator > { @@ -274,7 +266,7 @@ struct DefaultB2bGemm< Operator, EpilogueOutputOp0 >::ThreadblockB2bMma; - + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; /// Define the epilogue @@ -287,7 +279,7 @@ struct DefaultB2bGemm< >::Epilogue; /// Define the kernel-level GEMM operator. - using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; @@ -323,20 +315,16 @@ template < int Stages, /// Number of Interleaved k int InterleavedK, - /// If true, kernel is configured to support serial reduction in the - /// epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator> struct DefaultB2bGemm< ElementA, layout::ColumnMajorInterleaved, kAlignmentA, - ElementB, layout::RowMajorInterleaved, kAlignmentB, + ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, layout::ColumnMajorInterleaved, int32_t, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, - ThreadblockSwizzle, Stages, - SplitKSerial, Operator> { + ThreadblockSwizzle, Stages, Operator> { using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; @@ -360,7 +348,7 @@ struct DefaultB2bGemm< 64 / sizeof_bits::value, InterleavedK>::Epilogue; /// Define the kernel-level GEMM operator. 
- using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; //////////////////////////////////////////////////////////////////////////////// @@ -396,19 +384,16 @@ template < typename ThreadblockSwizzle, /// Number of Interleaved k int InterleavedK, - /// If true, kernel is configured to support serial reduction in the - /// epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator> struct DefaultB2bGemm, kAlignmentA, ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, layout::ColumnMajorInterleaved, - int32_t, arch::OpClassTensorOp, arch::Sm75, + int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, - ThreadblockSwizzle, 2, SplitKSerial, Operator> { + ThreadblockSwizzle, 2, Operator> { using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; @@ -418,7 +403,7 @@ struct DefaultB2bGemm, /// Define the threadblock-scoped matrix multiply-accumulate using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, - arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, + arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma; static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; @@ -430,7 +415,7 @@ struct DefaultB2bGemm, 64 / sizeof_bits::value, InterleavedK>::Epilogue; /// Define the kernel-level GEMM operator. - using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; //////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h index 23717c61..6849d6e8 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h +++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h @@ -30,10 +30,10 @@ **************************************************************************************************/ /*! \file - \brief + \brief Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropriate threadblock-scoped epilogue. - + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are accommodated by exchanging A and B operands and assuming transposed layouts. Partial specializations here choose 'device::GemmTransposed' to implement this functionality. 
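(Both files, default_b2b_gemm.h above and default_b2b_gemm_smem_accumulator.h whose hunks follow below, receive the same treatment: the `SplitKSerial` template parameter is deleted from every `DefaultB2bGemm` partial specialization, because serial split-K is now a run-time choice made through `GemmUniversalMode::kGemm` rather than a compile-time one. The angle-bracket arguments of the `B2bGemmKernel` aliases are not visible in this rendering of the patch; the compilable sketch below shows the presumed shape of the change using placeholder types, so the parameter names are assumptions, not the real CUTLASS definitions.)

```cpp
// Sketch only: placeholder types standing in for the real threadblock-level pieces.
struct B2bMma {};
struct Epilogue {};
struct ThreadblockSwizzle {};

namespace kernel {
// The trailing "bool SplitKSerial" parameter is the one this patch removes.
template <typename Mma_, typename Epilogue_, typename Swizzle_>
struct B2bGemm {};
}  // namespace kernel

// Before (assumed): kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, kSplitKSerial>
// After:
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;

int main() { return 0; }
```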
@@ -112,22 +112,19 @@ template < typename ThreadblockSwizzle, /// Number of stages used in the pipelined mainloop int Stages, - /// If true, kernel is configured to support serial reduction in the - /// epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator> struct DefaultB2bGemm { /// Define the threadblock-scoped matrix multiply-accumulate using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, - ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma; static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; @@ -139,10 +136,9 @@ struct DefaultB2bGemm::Epilogue; /// Define the kernel-level GEMM operator. - using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; - //////////////////////////////////////////////////////////////////////////////// /// Partial specialization for Turing Architecture @@ -179,8 +175,6 @@ template < typename EpilogueOutputOp1, /// Threadblock-level swizzling operator typename ThreadblockSwizzle, - /// If true, kernel is configured to support serial reduction in the epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator > @@ -200,7 +194,6 @@ struct DefaultB2bGemm< EpilogueOutputOp1, ThreadblockSwizzle, 2, - SplitKSerial, Operator, true > { @@ -228,7 +221,7 @@ struct DefaultB2bGemm< false, true >::ThreadblockB2bMma; - + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; /// Define the epilogue @@ -241,7 +234,7 @@ struct DefaultB2bGemm< >::Epilogue; /// Define the kernel-level GEMM operator. - using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; @@ -277,20 +270,17 @@ template < int Stages, /// Number of Interleaved k int InterleavedK, - /// If true, kernel is configured to support serial reduction in the - /// epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator> struct DefaultB2bGemm< ElementA, layout::ColumnMajorInterleaved, kAlignmentA, - ElementB, layout::RowMajorInterleaved, kAlignmentB, + ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, layout::ColumnMajorInterleaved, int32_t, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, - SplitKSerial, Operator, true> { + Operator, true> { using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; @@ -314,7 +304,7 @@ struct DefaultB2bGemm< 64 / sizeof_bits::value, InterleavedK>::Epilogue; /// Define the kernel-level GEMM operator. 
- using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; //////////////////////////////////////////////////////////////////////////////// @@ -350,19 +340,16 @@ template < typename ThreadblockSwizzle, /// Number of Interleaved k int InterleavedK, - /// If true, kernel is configured to support serial reduction in the - /// epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator> struct DefaultB2bGemm, kAlignmentA, ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, layout::ColumnMajorInterleaved, - int32_t, arch::OpClassTensorOp, arch::Sm75, + int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, - ThreadblockSwizzle, 2, SplitKSerial, Operator, true> { + ThreadblockSwizzle, 2, Operator, true> { using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; @@ -371,9 +358,9 @@ struct DefaultB2bGemm, /// Define the threadblock-scoped matrix multiply-accumulate using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< - ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, - ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75, - ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma; static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; @@ -385,7 +372,7 @@ struct DefaultB2bGemm, 64 / sizeof_bits::value, InterleavedK>::Epilogue; /// Define the kernel-level GEMM operator. 
- using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; //////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h b/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h index eef9d9a1..67de37ed 100644 --- a/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h +++ b/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h @@ -69,7 +69,7 @@ __global__ void TensorScaleBiasGemm( TensorRefScalar tensor_scale, ///< scale tensor TensorRefScalar tensor_bias ///< bias tensor ) { - + ConvertOp convert_op; MatrixCoord output_coord( @@ -89,7 +89,7 @@ __global__ void TensorScaleBiasGemm( ScalarType bias = ScalarType(0); - if(tensor_bias.good()) + if(tensor_bias.good()) bias = tensor_bias.at({0, coord.column()}); tensor_out.at(coord) = convert_op( @@ -99,6 +99,70 @@ __global__ void TensorScaleBiasGemm( } } +template < + typename TensorRefIn, ///< Input TensorRef Type + typename TensorRefOut, ///< Output TensorRef Type + typename ScalarType, ///< alpha Type + typename TensorRefScalar, ///< Scale/Bias TensorRef Type + typename ConvertOp = NumericConverter, + int kMblock = 4, + int kNblock = 4 +> +__global__ void TensorScaleBiasGemmBatched( + gemm::GemmCoord problem_size, + TensorRefIn tensor_in, ///< input tensor + TensorRefOut tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRefScalar tensor_scale, ///< scale tensor + TensorRefScalar tensor_bias, ///< bias tensor + int batch_count = 1, + int64_t batch_stride_tensor_in = 0, + int64_t batch_stride_tensor_out = 0, + int64_t batch_stride_tensor_scale = 0, + int64_t batch_stride_tensor_bias = 0 +) { + + ConvertOp convert_op; + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in); + tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out); + tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale); + tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + MatrixCoord coord = MatrixCoord(row, col); + if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { + + ScalarType scale = alpha; + if(tensor_scale.good()) + scale = tensor_scale.at({0, coord.column()}); + + ScalarType bias = ScalarType(0); + + if(tensor_bias.good()) + bias = tensor_bias.at({0, coord.column()}); + + tensor_out.at(coord) = convert_op( + scale * ScalarType(tensor_in.at(coord)) + bias); + } + } + } + tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z); + tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z); + tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z); + tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z); + } +} + template < typename TensorRefIn, ///< Input TensorRef Type typename TensorRefOut, ///< Output TensorRef Type @@ -118,7 +182,7 @@ __global__ void TensorScaleBiasConv2d( TensorRefScalar tensor_scale, ///< scale tensor TensorRefScalar tensor_bias ///< bias tensor ) { - + ConvertOp convert_op; int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM 
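(Aside on `TensorScaleBiasGemmBatched`, added above: it batches by starting at batch `blockIdx.z` and striding by `gridDim.z`, advancing every tensor by its batch stride on each step. The standalone CUDA sketch below reproduces only that looping pattern on raw row-major pointers, with one element per thread instead of the kMblock x kNblock tile; names and shapes are illustrative, and the device memory is left uninitialized since only the launch pattern matters here.)

```cu
#include <cstdio>
#include <cuda_runtime.h>

// out[b][m][n] = scale[b][n] * in[b][m][n] + bias[b][n], batched via a gridDim.z stride loop.
__global__ void scale_bias_batched(int M, int N, int batch_count,
                                   const float* in, float* out,
                                   const float* scale, const float* bias,
                                   long long stride_in, long long stride_out,
                                   long long stride_vec) {
  int m = blockIdx.x * blockDim.x + threadIdx.x;
  int n = blockIdx.y * blockDim.y + threadIdx.y;
  if (m >= M || n >= N) return;

  // Start at batch blockIdx.z and advance by gridDim.z batches per iteration,
  // mirroring the add_pointer_offset() calls in the kernel above.
  in    += blockIdx.z * stride_in;
  out   += blockIdx.z * stride_out;
  scale += blockIdx.z * stride_vec;
  bias  += blockIdx.z * stride_vec;

  for (int b = blockIdx.z; b < batch_count; b += gridDim.z) {
    out[m * N + n] = scale[n] * in[m * N + n] + bias[n];
    in    += (long long)gridDim.z * stride_in;
    out   += (long long)gridDim.z * stride_out;
    scale += (long long)gridDim.z * stride_vec;
    bias  += (long long)gridDim.z * stride_vec;
  }
}

int main() {
  const int M = 64, N = 128, batches = 10;
  float *in, *out, *scale, *bias;
  cudaMalloc(&in,    sizeof(float) * batches * M * N);
  cudaMalloc(&out,   sizeof(float) * batches * M * N);
  cudaMalloc(&scale, sizeof(float) * batches * N);
  cudaMalloc(&bias,  sizeof(float) * batches * N);

  dim3 block(16, 8);
  // grid.z < batches: the in-kernel stride loop covers the remaining batches.
  dim3 grid((M + block.x - 1) / block.x, (N + block.y - 1) / block.y, 4);
  scale_bias_batched<<<grid, block>>>(M, N, batches, in, out, scale, bias,
                                      (long long)M * N, (long long)M * N, N);
  cudaDeviceSynchronize();
  cudaFree(in); cudaFree(out); cudaFree(scale); cudaFree(bias);
  return 0;
}
```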
* kThreadM + threadIdx.x * kThreadM; @@ -137,7 +201,7 @@ __global__ void TensorScaleBiasConv2d( int64_t npq = npq_start + m; thread_n[m] = int(npq / PQ); - + int64_t residual = npq % PQ; thread_p[m] = int(residual / problem_size.Q); thread_q[m] = int(residual % problem_size.Q); @@ -155,17 +219,17 @@ __global__ void TensorScaleBiasConv2d( ScalarType scale = alpha; if(tensor_scale.good()) scale = tensor_scale.at({0, thread_k}); - + ScalarType bias = ScalarType(0); - if(tensor_bias.good()) + if(tensor_bias.good()) bias = tensor_bias.at({0, thread_k}); - + tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( scale * ScalarType( tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) ) + bias); } - } + } } } @@ -217,6 +281,62 @@ void TensorScaleBiasGemm( ); } +/// Apply scale and bias on a tensor +template < + typename ElementIn, ///< Input Type + typename ElementOut, ///< Output Type + typename Layout, ///< Layout of input/output tensor + typename ScalarType, ///< alpha Type + typename LayoutScaleBias, ///< Layout of scale and bias + typename ConvertOp = NumericConverter +> +void TensorScaleBiasGemmBatched( + gemm::GemmCoord problem_size, + TensorRef tensor_in, ///< input tensor + TensorRef tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRef tensor_scale, ///< scale tensor + TensorRef tensor_bias, ///< bias tensor + int batch_count = 1, + int64_t batch_stride_tensor_in = 0, + int64_t batch_stride_tensor_out = 0, + int64_t batch_stride_tensor_scale = 0, + int64_t batch_stride_tensor_bias = 0 +) { + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::TensorScaleBiasGemmBatched< + TensorRef, + TensorRef, + ScalarType, + TensorRef, + ConvertOp, + kMblock, + kNblock + ><<< grid, block >>> ( + problem_size, + tensor_in, + tensor_out, + alpha, + tensor_scale, + tensor_bias, + batch_count, + batch_stride_tensor_in, + batch_stride_tensor_out, + batch_stride_tensor_scale, + batch_stride_tensor_bias + ); +} + /// Apply scale and bias on a tensor template < typename ElementIn, ///< Input Type
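(The host-side launcher for `TensorScaleBiasGemmBatched`, also added above, sizes the grid so that each thread covers a kMblock x kNblock patch and bounds grid.z with a numeric-limits value whose template argument is stripped in this rendering, relying on the kernel's gridDim.z stride loop to cover any remaining batches. The model below mirrors that arithmetic; the 65535 cap is an assumption about the stripped limit, and std::min is used where the patch applies a modulo, with the same intent of keeping grid.z within the hardware bound for typical batch counts.)

```cpp
#include <algorithm>
#include <cstdio>

// Host-side model of the launch-shape computation in TensorScaleBiasGemmBatched above.
// kMblock/kNblock are the per-thread tile as in the reference kernel.
struct LaunchShape { unsigned gx, gy, gz; };

LaunchShape batched_scale_bias_grid(int m, int n, int batch_count,
                                    int block_x = 16, int block_y = 8,
                                    int kMblock = 4, int kNblock = 4) {
  LaunchShape g;
  g.gx = (m + block_x * kMblock - 1) / (block_x * kMblock);    // ceil(m / (block.x * kMblock))
  g.gy = (n + block_y * kNblock - 1) / (block_y * kNblock);    // ceil(n / (block.y * kNblock))
  g.gz = static_cast<unsigned>(std::min(batch_count, 65535));  // kernel strides over the rest
  return g;
}

int main() {
  LaunchShape g = batched_scale_bias_grid(512, 1024, 96);
  std::printf("grid = (%u, %u, %u)\n", g.gx, g.gy, g.gz);  // (8, 32, 96)
  return 0;
}
```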