From ec2b4fd85d2f0f53bc9f61791d6399d8bc6badb4 Mon Sep 17 00:00:00 2001 From: Haicheng Wu <57973641+hwu36@users.noreply.github.com> Date: Sat, 30 Apr 2022 07:16:15 -0400 Subject: [PATCH] b2b bias vector support (#482) * b2b bias vector support * add files Co-authored-by: Haicheng Wu --- examples/13_two_tensor_op_fusion/README.md | 23 ++ .../13_two_tensor_op_fusion/b2b_conv2d_run.h | 31 +- .../13_two_tensor_op_fusion/b2b_gemm_run.h | 135 +++++++-- .../b2b_interleaved_conv2d_run.h | 43 ++- .../b2b_interleaved_gemm_run.h | 155 +++++++--- .../13_two_tensor_op_fusion/device/b2b_gemm.h | 26 +- .../fused_two_convs_f16_sm75_rf.cu | 17 +- .../fused_two_convs_f16_sm75_shmem.cu | 19 +- .../fused_two_convs_f16_sm80_rf.cu | 17 +- .../fused_two_convs_f16_sm80_shmem.cu | 17 +- .../fused_two_convs_s8_sm75_rf.cu | 27 +- .../fused_two_convs_s8_sm75_shmem.cu | 23 +- .../fused_two_convs_s8_sm80_rf.cu | 27 +- .../fused_two_convs_s8_sm80_shmem.cu | 21 +- .../fused_two_gemms_f16_sm75_rf.cu | 23 +- .../fused_two_gemms_f16_sm75_shmem.cu | 25 +- .../fused_two_gemms_f16_sm80_rf.cu | 41 +-- .../fused_two_gemms_f16_sm80_shmem.cu | 25 +- .../fused_two_gemms_s8_sm75_rf.cu | 25 +- .../fused_two_gemms_s8_sm75_shmem.cu | 33 ++- .../fused_two_gemms_s8_sm80_rf.cu | 33 ++- .../fused_two_gemms_s8_sm80_shmem.cu | 27 +- .../13_two_tensor_op_fusion/kernel/b2b_gemm.h | 32 +- ...t_b2b_conv2d_fprop_smem_accumulator_sm80.h | 2 +- .../reference/device/tensor_scale_bias.h | 275 ++++++++++++++++++ .../b2b_implicit_gemm_multistage.h | 1 - .../threadblock/b2b_mma_multistage.h | 92 +++++- .../b2b_mma_multistage_smem_accumulator.h | 15 +- .../threadblock/b2b_mma_pipelined.h | 63 +++- .../b2b_mma_pipelined_smem_accumulator.h | 4 +- .../threadblock/default_b2b_mma.h | 76 ++++- .../default_b2b_mma_smem_accumulator.h | 2 +- .../epilogue/thread/linear_combination_relu.h | 4 + .../util/reference/device/convolution.h | 41 +-- 34 files changed, 1096 insertions(+), 324 deletions(-) create mode 100644 examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h diff --git a/examples/13_two_tensor_op_fusion/README.md b/examples/13_two_tensor_op_fusion/README.md index 4ff19e2c..134644a0 100644 --- a/examples/13_two_tensor_op_fusion/README.md +++ b/examples/13_two_tensor_op_fusion/README.md @@ -61,6 +61,29 @@ When applying the above constraint to convolutions, it is required that the 2nd kernel doesn't have halos such that data used by each threadblock doesn't depend on any other threadblock. Typically this requires the 2nd Convolution uses 1x1 filter without any paddings. 
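For reference, a second-convolution problem size that satisfies this constraint might look like the sketch below (the tensor sizes are illustrative placeholders; `Conv2dProblemSize` is CUTLASS's standard convolution problem descriptor):

```c++
#include "cutlass/conv/conv2d_problem_size.h"

// Illustrative second-convolution problem size: 1x1 filter, zero padding,
// unit stride, so each output tile depends only on the matching tile of the
// first convolution's output (no halo). Sizes are made up for illustration.
cutlass::conv::Conv2dProblemSize conv2d_problem_size_1(
    {32, 56, 56, 64},    // input size  (N, H, W, C) = output of the 1st conv
    {256, 1, 1, 64},     // filter size (K, R, S, C) -- 1x1 kernel
    {0, 0, 0, 0},        // padding     (must be zero)
    {1, 1},              // stride
    {1, 1},              // dilation
    {32, 56, 56, 256}    // output size (N, P, Q, K)
);
```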
+# Build and run + +- Run cmake at top-level CUTLASS +- `make 13_two_tensor_op_fusion` +- Run individual benchmarks + - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm75_rf` + - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm75_shmem` + - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm80_rf` + - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm80_shmem` + - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm75_rf` + - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm75_shmem` + - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm80_rf` + - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm80_shmem` + - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm75_rf` + - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm75_shmem` + - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm80_rf` + - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm80_shmem` + - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_rf` + - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_shmem` + - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_rf` + - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_shmem` + + # Copyright Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h b/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h index 7fa1a28b..b0509063 100644 --- a/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h +++ b/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h @@ -54,6 +54,7 @@ #include "cutlass/core_io.h" #include "cutlass/util/tensor_view_io.h" +#include "reference/device/tensor_scale_bias.h" #include "helper.h" #define CHECK_GT(val1, val2) \ @@ -153,6 +154,7 @@ public: cutlass::reference::host::TensorFill(view, Element(1)); } else { + std::cerr << "Not implemented\n"; } } @@ -407,6 +409,7 @@ public: cutlass::HostTensor tensor_C0; cutlass::HostTensor tensor_Scale0; cutlass::HostTensor tensor_Bias0; + cutlass::HostTensor tensor_Z0_reference; cutlass::HostTensor tensor_D0_reference; cutlass::HostTensor tensor_B1; @@ -487,6 +490,7 @@ public: if(alpha0 == ElementCompute(0)) //per-channel scale tensor_Scale0.resize({1, problem_size_0.K}); tensor_Bias0.resize({1, problem_size_0.K}); + tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); @@ -607,22 +611,35 @@ public: typename B2bConv2d::LayoutA, typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB, - typename B2bConv2d::ElementC, + ElementAccumulator, typename B2bConv2d::LayoutC, - ElementCompute, + ElementAccumulator, ElementAccumulator >( kConvolutionalOperator, problem_size_0, tensor_A0.device_ref(), tensor_B0.device_ref(), - tensor_C0.device_ref(), + tensor_Z0_reference.device_ref(), + tensor_Z0_reference.device_ref(), + ElementAccumulator(1), // intermediate alpha = 1 + ElementAccumulator(0) // beta = 0 + ); + + cutlass::reference::device::TensorScaleBiasConv2d< + ElementAccumulator, + typename B2bConv2d::ElementC, + typename B2bConv2d::LayoutC, + ElementCompute, + typename B2bConv2d::LayoutScaleBias + >( + problem_size_0, + tensor_Z0_reference.device_ref(), 
tensor_D0_reference.device_ref(), - alpha0, - beta0, - nullptr, // stream + alpha0, tensor_Scale0.device_ref(), - tensor_Bias0.device_ref()); + tensor_Bias0.device_ref() + ); if(relu) { cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); diff --git a/examples/13_two_tensor_op_fusion/b2b_gemm_run.h b/examples/13_two_tensor_op_fusion/b2b_gemm_run.h index bebc058f..6cc4ffd9 100644 --- a/examples/13_two_tensor_op_fusion/b2b_gemm_run.h +++ b/examples/13_two_tensor_op_fusion/b2b_gemm_run.h @@ -44,6 +44,7 @@ #include "cutlass/util/reference/device/gemm.h" #include "cutlass/util/reference/device/tensor_relu.h" +#include "reference/device/tensor_scale_bias.h" #include "helper.h" #define CHECK_GT(val1, val2) \ @@ -68,6 +69,7 @@ struct B2bNonFusedGemmRun cutlass::Distribution::Kind init_A; cutlass::Distribution::Kind init_B; cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; uint64_t seed; // @@ -78,9 +80,10 @@ struct B2bNonFusedGemmRun cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } /// Helper to initialize a tensor view template @@ -97,7 +100,7 @@ struct B2bNonFusedGemmRun else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); - } + } else if (dist_kind == cutlass::Distribution::Gaussian) { cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); @@ -106,9 +109,14 @@ struct B2bNonFusedGemmRun cutlass::reference::host::BlockFillSequential( view.data(), view.capacity()); - } + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } else { - // TODO: Implement the rest std::cerr << "Not implemented\n"; return false; } @@ -147,6 +155,10 @@ struct B2bNonFusedGemmRun typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); + cutlass::HostTensor< + ElementCompute, + typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); + cutlass::HostTensor< typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); @@ -163,6 +175,10 @@ struct B2bNonFusedGemmRun typename Gemm1::ElementC, typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); + cutlass::HostTensor< + ElementCompute, + typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); + cutlass::HostTensor< typename Gemm1::ElementC, typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); @@ -175,8 +191,10 @@ struct B2bNonFusedGemmRun CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); 
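// The two-step reference above -- a plain convolution into an accumulator
// tensor Z0 with alpha = 1 and beta = 0, followed by a separate
// scale-and-bias pass -- computes, per output channel n:
//   D0[..., n] = scale0[n] * Z0[..., n] + bias0[n]
// where scale0[n] is the per-channel scale when alpha0 == 0 and the scalar
// alpha0 otherwise. The host-side loop below is only a sketch of that math;
// the actual reference added in reference/device/tensor_scale_bias.h runs on
// the device and its interface differs.
template <typename ElementZ, typename ElementD, typename ElementCompute>
void scale_bias_reference_sketch(
    int M, int N,
    ElementZ const *Z,              // M x N accumulator matrix, row-major
    ElementD *D,                    // M x N output matrix, row-major
    ElementCompute alpha,           // used when no per-channel scale is given
    ElementCompute const *scale,    // per-channel scale vector, may be nullptr
    ElementCompute const *bias) {   // per-channel bias vector
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      ElementCompute s = scale ? scale[n] : alpha;
      D[m * N + n] = ElementD(s * ElementCompute(Z[m * N + n]) + bias[n]);
    }
  }
}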
cutlass::reference::host::TensorFill( tensor_D0.host_view()); @@ -190,9 +208,11 @@ struct B2bNonFusedGemmRun tensor_A0.sync_device(); tensor_B0.sync_device(); tensor_C0.sync_device(); + tensor_Bias0.sync_device(); tensor_D0.sync_device(); tensor_B1.sync_device(); tensor_C1.sync_device(); + tensor_Bias1.sync_device(); tensor_D1.sync_device(); reference_D0.sync_device(); reference_D1.sync_device(); @@ -205,7 +225,7 @@ struct B2bNonFusedGemmRun problem_size_0, tensor_A0.device_ref(), tensor_B0.device_ref(), - tensor_C0.device_ref(), + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, tensor_D0.device_ref(), {alpha0, beta0} }; @@ -214,7 +234,7 @@ struct B2bNonFusedGemmRun problem_size_1, tensor_D0.device_ref(), tensor_B1.device_ref(), - tensor_C1.device_ref(), + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, tensor_D1.device_ref(), {alpha1, beta1} }; @@ -241,7 +261,6 @@ struct B2bNonFusedGemmRun // // Run the GEMM // - cudaEvent_t start, stop1, stop2; cudaEventCreate(&start); cudaEventCreate(&stop1); @@ -256,7 +275,6 @@ struct B2bNonFusedGemmRun } cudaEventRecord(stop1); for(int i = 0; i < runs; i++) { - status = gemm_op_1(); CUTLASS_CHECK(status); @@ -298,7 +316,7 @@ struct B2bNonFusedGemmRun tensor_A0.device_ref(), tensor_B0.device_ref(), beta0, - tensor_C0.device_ref(), + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, reference_D0.device_ref() ); @@ -312,7 +330,7 @@ struct B2bNonFusedGemmRun reference_D0.device_ref(), tensor_B1.device_ref(), beta1, - tensor_C1.device_ref(), + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, reference_D1.device_ref() ); @@ -325,7 +343,6 @@ struct B2bNonFusedGemmRun reference_D0.sync_host(); reference_D1.sync_host(); - CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); @@ -349,13 +366,14 @@ struct B2bNonFusedGemmRun << "A0 =\n" << tensor_A0.host_view() << "\nB0 =\n" << tensor_B0.host_view() << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" << "\nD0 =\n" << tensor_D0.host_view() << "\nB1 =\n" << tensor_B1.host_view() << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" << "\n\nReference =\n" << reference_D1.host_view() << "\nComputed =\n" << tensor_D1.host_view(); } - return passed; } }; @@ -372,6 +390,8 @@ struct B2bFusedGemmRun cutlass::Distribution::Kind init_A; cutlass::Distribution::Kind init_B; cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; uint64_t seed; // @@ -382,9 +402,12 @@ struct B2bFusedGemmRun cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } /// Helper to initialize a tensor view template @@ -410,9 +433,14 @@ struct B2bFusedGemmRun cutlass::reference::host::BlockFillSequential( view.data(), view.capacity()); - } + } + else if (dist_kind == 
cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } else { - // TODO: Implement the rest std::cerr << "Not implemented\n"; return false; } @@ -451,6 +479,21 @@ struct B2bFusedGemmRun typename B2bGemm::ElementC, typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Scale0; + + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.resize({1, problem_size_0.n()}); + + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()}); + + cutlass::HostTensor< + ElementAccumulator, + typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn()); + cutlass::HostTensor< typename B2bGemm::ElementC, typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); @@ -463,6 +506,10 @@ struct B2bFusedGemmRun typename B2bGemm::ElementC, typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); + cutlass::HostTensor< + ElementCompute, + typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()}); + cutlass::HostTensor< typename B2bGemm::ElementC, typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); @@ -475,21 +522,29 @@ struct B2bFusedGemmRun CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + if(alpha0 == ElementCompute(0)) //per-channel scale + CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013)); CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); cutlass::reference::host::TensorFill( tensor_D1.host_view()); cutlass::reference::host::TensorFill( - reference_D0.host_view()); + reference_D0.host_view()); cutlass::reference::host::TensorFill( reference_D1.host_view()); tensor_A0.sync_device(); tensor_B0.sync_device(); tensor_C0.sync_device(); + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.sync_device(); + tensor_Bias0.sync_device(); tensor_B1.sync_device(); tensor_C1.sync_device(); + tensor_Bias1.sync_device(); tensor_D1.sync_device(); reference_D0.sync_device(); reference_D1.sync_device(); @@ -504,8 +559,10 @@ struct B2bFusedGemmRun tensor_A0.device_ref(), tensor_B0.device_ref(), tensor_C0.device_ref(), + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), tensor_B1.device_ref(), - tensor_C1.device_ref(), + {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, tensor_D1.device_ref(), {alpha0, beta0}, {alpha1, beta1}, @@ -524,7 +581,6 @@ struct B2bFusedGemmRun << " ThreadblockShape1::kN = problem_size_1.N" << std::endl; } - status = b2b_gemm_op.initialize(arguments); CUTLASS_CHECK(status); @@ -561,21 +617,42 @@ struct B2bFusedGemmRun // // Verify // + + cutlass::reference::device::Gemm< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + ElementAccumulator, typename B2bGemm::LayoutC, + ElementAccumulator, ElementAccumulator> + reference_gemm_0; + 
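#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/gemm.h"

// Caller-side sketch of the updated fused interface (the wrapper and its
// names are illustrative, not part of the patch): the first GEMM now takes
// per-channel Scale0/Bias0 vectors (Scale0 is only read when alpha0 == 0),
// and the second GEMM's C operand is a bias row broadcast via a zero row
// stride. The argument order follows B2bGemm::Arguments as changed here.
template <typename B2bGemm, typename ElementCompute>
cutlass::Status run_fused_b2b_gemm(
    cutlass::gemm::GemmCoord problem_size_0,
    cutlass::gemm::GemmCoord problem_size_1,
    cutlass::TensorRef<typename B2bGemm::ElementA, typename B2bGemm::LayoutA> ref_A0,
    cutlass::TensorRef<typename B2bGemm::ElementB, typename B2bGemm::LayoutB> ref_B0,
    cutlass::TensorRef<typename B2bGemm::ElementC, typename B2bGemm::LayoutC> ref_C0,
    cutlass::TensorRef<typename B2bGemm::ElementScaleBias, typename B2bGemm::LayoutScaleBias> ref_Scale0,
    cutlass::TensorRef<typename B2bGemm::ElementScaleBias, typename B2bGemm::LayoutScaleBias> ref_Bias0,
    cutlass::TensorRef<typename B2bGemm::ElementB, typename B2bGemm::LayoutB> ref_B1,
    typename B2bGemm::ElementC *bias1_ptr,   // length-N1 bias vector for the 2nd GEMM
    cutlass::TensorRef<typename B2bGemm::ElementC, typename B2bGemm::LayoutC> ref_D1,
    ElementCompute alpha0, ElementCompute beta0,
    ElementCompute alpha1, ElementCompute beta1) {

  typename B2bGemm::Arguments args{
    problem_size_0,
    problem_size_1,
    ref_A0,
    ref_B0,
    ref_C0,
    ref_Scale0,
    ref_Bias0,
    ref_B1,
    {bias1_ptr, typename B2bGemm::LayoutC::Stride(0)},  // C1 = broadcast bias row
    ref_D1,
    {alpha0, beta0},   // alpha0 == 0 selects the per-channel Scale0 vector
    {alpha1, beta1}
  };

  B2bGemm op;
  cutlass::Status status = op.initialize(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  return op();   // launch the fused kernel
}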
cutlass::reference::device::Gemm< typename B2bGemm::ElementA, typename B2bGemm::LayoutA, typename B2bGemm::ElementB, typename B2bGemm::LayoutB, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, ElementAccumulator, typename B2bGemm::Operator> - reference_gemm_0, reference_gemm_1; + reference_gemm_1; reference_gemm_0( problem_size_0, - alpha0, + ElementAccumulator(1), //intermediate alpha=1 tensor_A0.device_ref(), tensor_B0.device_ref(), - beta0, - tensor_C0.device_ref(), - reference_D0.device_ref() + ElementAccumulator(0), //beta = 0 + reference_Z0.device_ref(), + reference_Z0.device_ref(), + ElementAccumulator(0) + ); + + cutlass::reference::device::TensorScaleBiasGemm< + ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, typename B2bGemm::LayoutScaleBias + > ( + problem_size_0, + reference_Z0.device_ref(), + reference_D0.device_ref(), + alpha0, + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref() ); if(relu) { @@ -588,18 +665,15 @@ struct B2bFusedGemmRun reference_D0.device_ref(), tensor_B1.device_ref(), beta1, - tensor_C1.device_ref(), + {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, reference_D1.device_ref() ); - if(relu) { cutlass::reference::device::TensorReLu(reference_D1.device_view()); } - cudaDeviceSynchronize(); reference_D0.sync_host(); reference_D1.sync_host(); - CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); @@ -610,7 +684,8 @@ struct B2bFusedGemmRun tensor_D1.host_view()); CHECK_TRUE(passed); - if (!passed) { + if (!passed) + { std::stringstream fname; @@ -623,12 +698,14 @@ struct B2bFusedGemmRun << "A0 =\n" << tensor_A0.host_view() << "\nB0 =\n" << tensor_B0.host_view() << "\nC0 =\n" << tensor_C0.host_view() + << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" << "\nB1 =\n" << tensor_B1.host_view() << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" << "\n\nReference =\n" << reference_D1.host_view() << "\nComputed =\n" << tensor_D1.host_view(); } - return passed; } diff --git a/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h b/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h index fbdcc22b..f9905fa5 100644 --- a/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h +++ b/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h @@ -55,6 +55,7 @@ #include "cutlass/core_io.h" #include "cutlass/util/tensor_view_io.h" +#include "reference/device/tensor_scale_bias.h" #include "helper.h" #define CHECK_GT(val1, val2) \ @@ -91,14 +92,14 @@ public: cutlass::HostTensor tensor_B0; cutlass::HostTensor tensor_B0_reordered; cutlass::HostTensor tensor_C0; - cutlass::HostTensor tensor_Bias0; + cutlass::HostTensor tensor_Bias0; cutlass::HostTensor tensor_D0_computed; cutlass::HostTensor tensor_D0_reference; cutlass::HostTensor tensor_B1; cutlass::HostTensor tensor_B1_reordered; cutlass::HostTensor tensor_C1; - cutlass::HostTensor tensor_Bias1; + cutlass::HostTensor tensor_Bias1; cutlass::HostTensor tensor_D1_computed; cutlass::HostTensor tensor_D1_reference; @@ -379,11 +380,13 @@ public: << "\nB0:\n" << tensor_B0.host_view() << "\n" << "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n" << "\nC0:\n" << tensor_C0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" << "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n" << 
"\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n" << "\nB1:\n" << tensor_B1.host_view() << "\n" << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n" << "\nC1:\n" << tensor_C1.host_view() << "\n" + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" << "\nD1 computed:\n" << tensor_D1_computed.host_view(); @@ -421,12 +424,13 @@ public: cutlass::HostTensor tensor_C0; cutlass::HostTensor tensor_Scale0; cutlass::HostTensor tensor_Bias0; + cutlass::HostTensor tensor_Z0_reference; cutlass::HostTensor tensor_D0_reference; cutlass::HostTensor tensor_B1; cutlass::HostTensor tensor_B1_reordered; cutlass::HostTensor tensor_C1; - cutlass::HostTensor tensor_Bias1; + cutlass::HostTensor tensor_Bias1; cutlass::HostTensor tensor_D1_computed; cutlass::HostTensor tensor_D1_reference; @@ -503,6 +507,7 @@ public: if(alpha0 == ElementCompute(0)) //per-channel scale tensor_Scale0.resize({1, problem_size_0.K}); tensor_Bias0.resize({1, problem_size_0.K}); + tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); @@ -632,23 +637,36 @@ public: typename B2bConv2d::LayoutA, typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB, - typename B2bConv2d::ElementC, - typename B2bConv2d::LayoutC, - ElementCompute, ElementAccumulator, - cutlass::NumericConverterClamp + typename B2bConv2d::LayoutC, + ElementAccumulator, + ElementAccumulator >( kConvolutionalOperator, problem_size_0, tensor_A0.device_ref(), tensor_B0.device_ref(), - tensor_C0.device_ref(), + tensor_Z0_reference.device_ref(), + tensor_Z0_reference.device_ref(), + ElementAccumulator(1), // intermediate alpha = 1 + ElementAccumulator(0) // beta = 0 + ); + + cutlass::reference::device::TensorScaleBiasConv2d< + ElementAccumulator, + typename B2bConv2d::ElementC, + typename B2bConv2d::LayoutC, + ElementCompute, + typename B2bConv2d::LayoutScaleBias, + cutlass::NumericConverterClamp + >( + problem_size_0, + tensor_Z0_reference.device_ref(), tensor_D0_reference.device_ref(), - alpha0, - beta0, - nullptr, // stream + alpha0, tensor_Scale0.device_ref(), - tensor_Bias0.device_ref()); + tensor_Bias0.device_ref() + ); if(relu) { cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); @@ -716,6 +734,7 @@ public: << "\nB1:\n" << tensor_B1.host_view() << "\n" << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n" << "\nC1:\n" << tensor_C1.host_view() << "\n" + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" << "\nD1 computed:\n" << tensor_D1_computed.host_view(); diff --git a/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h b/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h index 693e252e..95c404d9 100644 --- a/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h +++ b/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h @@ -28,7 +28,6 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ - #pragma once #include @@ -46,6 +45,7 @@ #include "cutlass/util/reference/device/gemm.h" #include "cutlass/util/reference/device/tensor_relu.h" +#include "reference/device/tensor_scale_bias.h" #include "helper.h" #define CHECK_GT(val1, val2) \ @@ -68,6 +68,7 @@ struct B2bInterleavedNonFusedGemmRun cutlass::Distribution::Kind init_A; cutlass::Distribution::Kind init_B; cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; uint64_t seed; // @@ -78,9 +79,10 @@ struct B2bInterleavedNonFusedGemmRun cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } /// Helper to initialize a tensor view template @@ -97,14 +99,23 @@ struct B2bInterleavedNonFusedGemmRun else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); - } + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } else if (dist_kind == cutlass::Distribution::Sequential) { cutlass::reference::host::BlockFillSequential( view.data(), view.capacity()); - } + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } else { - // TODO: Implement the rest std::cerr << "Not implemented\n"; return false; } @@ -147,6 +158,10 @@ struct B2bInterleavedNonFusedGemmRun typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); + cutlass::HostTensor< typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); @@ -167,6 +182,10 @@ struct B2bInterleavedNonFusedGemmRun typename Gemm1::ElementC, typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); + cutlass::HostTensor< typename Gemm1::ElementC, typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); @@ -179,8 +198,10 @@ struct B2bInterleavedNonFusedGemmRun CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); //Reorder B0 and B1 cutlass::reorder_column( @@ -201,10 +222,12 @@ struct B2bInterleavedNonFusedGemmRun tensor_B0.sync_device(); tensor_B0_reordered.sync_device(); tensor_C0.sync_device(); + tensor_Bias0.sync_device(); tensor_D0.sync_device(); tensor_B1.sync_device(); 
tensor_B1_reordered.sync_device(); tensor_C1.sync_device(); + tensor_Bias1.sync_device(); tensor_D1.sync_device(); reference_D0.sync_device(); reference_D1.sync_device(); @@ -217,7 +240,7 @@ struct B2bInterleavedNonFusedGemmRun problem_size_0, tensor_A0.device_ref(), tensor_B0_reordered.device_ref(), - tensor_C0.device_ref(), + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, tensor_D0.device_ref(), {alpha0, beta0} }; @@ -226,7 +249,7 @@ struct B2bInterleavedNonFusedGemmRun problem_size_1, tensor_D0.device_ref(), tensor_B1_reordered.device_ref(), - tensor_C1.device_ref(), + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, tensor_D1.device_ref(), {alpha1, beta1} }; @@ -265,8 +288,7 @@ struct B2bInterleavedNonFusedGemmRun CUTLASS_CHECK(status); } - cudaEventRecord(stop1); - + cudaEventRecord(stop1); for(int i = 0; i < runs; i++) { status = gemm_op_1(); @@ -286,7 +308,6 @@ struct B2bInterleavedNonFusedGemmRun tensor_D0.sync_host(); tensor_D1.sync_host(); - bool passed = false; // // Verify // @@ -310,7 +331,7 @@ struct B2bInterleavedNonFusedGemmRun tensor_A0.device_ref(), tensor_B0.device_ref(), beta0, - tensor_C0.device_ref(), + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, reference_D0.device_ref() ); @@ -323,8 +344,8 @@ struct B2bInterleavedNonFusedGemmRun alpha1, reference_D0.device_ref(), tensor_B1.device_ref(), - beta1, - tensor_C1.device_ref(), + beta1, + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, reference_D1.device_ref() ); @@ -332,6 +353,7 @@ struct B2bInterleavedNonFusedGemmRun cutlass::reference::device::TensorReLu(reference_D1.device_view()); } + // Wait for kernels to finish cudaDeviceSynchronize(); reference_D0.sync_host(); reference_D1.sync_host(); @@ -341,7 +363,7 @@ struct B2bInterleavedNonFusedGemmRun CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); - passed = cutlass::reference::host::TensorEquals( + bool passed = cutlass::reference::host::TensorEquals( reference_D1.host_view(), tensor_D1.host_view()); @@ -360,10 +382,12 @@ struct B2bInterleavedNonFusedGemmRun << "\nB0 =\n" << tensor_B0.host_view() << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" << "\nD0 =\n" << tensor_D0.host_view() << "\nB1 =\n" << tensor_B1.host_view() << "\nB1_reordered =\n" << tensor_B1_reordered.host_view() << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" << "\n\nReference =\n" << reference_D1.host_view() << "\nComputed =\n" << tensor_D1.host_view(); } @@ -383,6 +407,8 @@ struct B2bInterleavedFusedGemmRun cutlass::Distribution::Kind init_A; cutlass::Distribution::Kind init_B; cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; uint64_t seed; // @@ -393,9 +419,12 @@ struct B2bInterleavedFusedGemmRun cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 ): - init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + init_A(init_A_), init_B(init_B_), init_C(init_C_), + 
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } /// Helper to initialize a tensor view template @@ -413,13 +442,22 @@ struct B2bInterleavedFusedGemmRun cutlass::reference::host::TensorFillIdentity(view); } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } else if (dist_kind == cutlass::Distribution::Sequential) { cutlass::reference::host::BlockFillSequential( view.data(), view.capacity()); - } + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } else { - // TODO: Implement the rest std::cerr << "Not implemented\n"; return false; } @@ -437,7 +475,7 @@ struct B2bInterleavedFusedGemmRun ElementCompute alpha0 = ElementCompute(1), ElementCompute beta0 = ElementCompute(0), ElementCompute alpha1 = ElementCompute(1), - ElementCompute beta1 = ElementCompute(0), + ElementCompute beta1 = ElementCompute(0), bool relu = true, int warm_ups = 1, int runs = 100) { @@ -462,6 +500,21 @@ struct B2bInterleavedFusedGemmRun typename B2bGemm::ElementC, typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Scale0; + + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.resize({1, problem_size_0.n()}); + + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()}); + + cutlass::HostTensor< + ElementAccumulator, + typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn()); + cutlass::HostTensor< typename B2bGemm::ElementC, typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); @@ -478,6 +531,10 @@ struct B2bInterleavedFusedGemmRun typename B2bGemm::ElementC, typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()}); + cutlass::HostTensor< typename B2bGemm::ElementC, typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); @@ -490,8 +547,12 @@ struct B2bInterleavedFusedGemmRun CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + if(alpha0 == ElementCompute(0)) //per-channel scale + CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013)); CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); //Reorder B0 cutlass::reorder_column<16>( @@ -510,9 +571,13 @@ struct B2bInterleavedFusedGemmRun tensor_B0.sync_device(); tensor_B0_reordered.sync_device(); tensor_C0.sync_device(); + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.sync_device(); + tensor_Bias0.sync_device(); tensor_B1.sync_device(); tensor_B1_reordered.sync_device(); tensor_C1.sync_device(); + tensor_Bias1.sync_device(); tensor_D1.sync_device(); reference_D0.sync_device(); reference_D1.sync_device(); @@ -527,12 +592,13 @@ struct B2bInterleavedFusedGemmRun 
tensor_A0.device_ref(), tensor_B0_reordered.device_ref(), tensor_C0.device_ref(), + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), tensor_B1_reordered.device_ref(), - tensor_C1.device_ref(), + {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, tensor_D1.device_ref(), {alpha0, beta0}, {alpha1, beta1}, - 1, /*threadblock_swizzle_k_tile*/ }; B2bGemm b2b_gemm_op; @@ -581,25 +647,45 @@ struct B2bInterleavedFusedGemmRun tensor_D1.sync_host(); - bool passed = false; // // Verify // + + cutlass::reference::device::Gemm< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + ElementAccumulator, typename B2bGemm::LayoutC, + ElementAccumulator, ElementAccumulator> + reference_gemm_0; + cutlass::reference::device::Gemm< typename B2bGemm::ElementA, typename B2bGemm::LayoutA, typename B2bGemm::ElementB, typename B2bGemm::LayoutB, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, ElementAccumulator, typename B2bGemm::Operator> - reference_gemm_0, reference_gemm_1; + reference_gemm_1; reference_gemm_0( problem_size_0, - alpha0, + ElementAccumulator(1), //intermediate alpha=1 tensor_A0.device_ref(), tensor_B0.device_ref(), - beta0, - tensor_C0.device_ref(), - reference_D0.device_ref() + ElementAccumulator(0), //beta = 0 + reference_Z0.device_ref(), + reference_Z0.device_ref(), + ElementAccumulator(0) + ); + + cutlass::reference::device::TensorScaleBiasGemm< + ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, typename B2bGemm::LayoutScaleBias + > ( + problem_size_0, + reference_Z0.device_ref(), + reference_D0.device_ref(), + alpha0, + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref() ); if(relu) { @@ -612,29 +698,27 @@ struct B2bInterleavedFusedGemmRun reference_D0.device_ref(), tensor_B1.device_ref(), beta1, - tensor_C1.device_ref(), + {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, reference_D1.device_ref() ); - - if(relu) { cutlass::reference::device::TensorReLu(reference_D1.device_view()); } - cudaDeviceSynchronize(); - reference_D0.sync_host(); - reference_D1.sync_host(); + reference_D0.sync_host(); + reference_D1.sync_host(); CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); - passed = cutlass::reference::host::TensorEquals( + bool passed = cutlass::reference::host::TensorEquals( reference_D1.host_view(), tensor_D1.host_view()); CHECK_TRUE(passed); - if (!passed) { + if (!passed) + { std::stringstream fname; @@ -648,9 +732,12 @@ struct B2bInterleavedFusedGemmRun << "\nB0 =\n" << tensor_B0.host_view() << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() << "\nC0 =\n" << tensor_C0.host_view() + << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" << "\nB1 =\n" << tensor_B1.host_view() << "\nB1_reordered =\n" << tensor_B1_reordered.host_view() << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" << "\n\nReference =\n" << reference_D1.host_view() << "\nComputed =\n" << tensor_D1.host_view(); } diff --git a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h index 54b58d3e..3751cc82 100644 --- a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h @@ 
-158,6 +158,10 @@ class B2bGemm { static ComplexTransform const kTransformA = ComplexTransform::kNone; static ComplexTransform const kTransformB = ComplexTransform::kNone; + /// Derived types + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; + /// Define the kernel using B2bGemmKernel = typename kernel::DefaultB2bGemm< ElementA, @@ -197,6 +201,8 @@ class B2bGemm { TensorRef ref_A0; TensorRef ref_B0; TensorRef ref_C0; + TensorRef ref_Scale0; + TensorRef ref_Bias0; TensorRef ref_B1; TensorRef ref_C1; TensorRef ref_D1; @@ -222,6 +228,8 @@ class B2bGemm { TensorRef ref_A0_, TensorRef ref_B0_, TensorRef ref_C0_, + TensorRef ref_Scale0_, + TensorRef ref_Bias0_, TensorRef ref_B1_, TensorRef ref_C1_, TensorRef ref_D1_, @@ -236,6 +244,8 @@ class B2bGemm { ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_), + ref_Scale0(ref_Scale0_), + ref_Bias0(ref_Bias0_), ref_B1(ref_B1_), ref_C1(ref_C1_), ref_D1(ref_D1_), @@ -348,6 +358,8 @@ public: args.ref_A0.non_const_ref(), args.ref_B0.non_const_ref(), args.ref_C0.non_const_ref(), + args.ref_Scale0.non_const_ref(), + args.ref_Bias0.non_const_ref(), args.ref_B1.non_const_ref(), args.ref_C1.non_const_ref(), args.ref_D1, @@ -368,12 +380,14 @@ public: } } - params_.ref_A0.reset(args.ref_A.non_const_ref().data()); - params_.ref_B0.reset(args.ref_B.non_const_ref().data()); - params_.ref_C0.reset(args.ref_C.non_const_ref().data()); - params_.ref_B1.reset(args.ref_B.non_const_ref().data()); - params_.ref_C1.reset(args.ref_C.non_const_ref().data()); - params_.ref_D1.reset(args.ref_D.data()); + params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); + params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); + params_.ref_C0.reset(args.ref_C0.non_const_ref().data()); + params_.ref_Scale0.reset(args.ref_Scale0.non_const_ref().data()); + params_.ref_Bias0.reset(args.ref_Bias0.non_const_ref().data()); + params_.ref_B1.reset(args.ref_B1.non_const_ref().data()); + params_.ref_C1.reset(args.ref_C1.non_const_ref().data()); + params_.ref_D1.reset(args.ref_D1.data()); params_.output_op_0 = args.epilogue0; params_.output_op_1 = args.epilogue1; params_.semaphore = static_cast(workspace); diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu index e5b0dd1f..d9c59db0 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu @@ -68,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() { using ElementCompute = cutlass::half_t; ElementCompute alpha0 = ElementCompute(1); - ElementCompute beta0 = ElementCompute(0); + ElementCompute beta0 = ElementCompute(1); //beta=1 for bias ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(1); //use beta for bias + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>; - using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; - using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>; + using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< @@ -93,7 +93,7 @@ bool 
run_nonfused_conv2d_fprop_optimized_f16_sm75() { 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2, @@ -151,14 +151,15 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75_rf_res() { using ElementCompute = cutlass::half_t; ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 ElementCompute beta0 = ElementCompute(0); ElementCompute alpha1 = ElementCompute(1); ElementCompute beta1 = ElementCompute(1); //use beta for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>; + using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>; using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; - using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; using EpilogueOutputOp0 = diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu index e549ccc3..54a13159 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu @@ -68,13 +68,13 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() { using ElementCompute = cutlass::half_t; ElementCompute alpha0 = ElementCompute(1); - ElementCompute beta0 = ElementCompute(0); + ElementCompute beta0 = ElementCompute(1); //beta=1 for bias ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(0); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias - using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; - using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; @@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() { 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2, @@ -118,7 +118,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() { 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2, @@ -151,9 +151,10 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75_shmem() { using ElementCompute = cutlass::half_t; ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 ElementCompute beta0 = ElementCompute(0); ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(0); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; @@ -176,7 +177,7 @@ bool 
run_fused_conv2d_fprop_optimized_f16_sm75_shmem() { 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu index 13a2a9d9..7a66f8e8 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu @@ -69,14 +69,14 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() { using ElementCompute = cutlass::half_t; ElementCompute alpha0 = ElementCompute(1); - ElementCompute beta0 = ElementCompute(0); + ElementCompute beta0 = ElementCompute(1); //beta=1 for bias ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(0); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; - using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< @@ -94,7 +94,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() { 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3, @@ -118,7 +118,8 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() { ElementC, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3, @@ -150,9 +151,10 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80_rf_res() { using ElementCompute = cutlass::half_t; ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 ElementCompute beta0 = ElementCompute(0); ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(0); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>; @@ -174,7 +176,8 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80_rf_res() { ElementC, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop< diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu index deacaa8e..5a607141 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu @@ -69,13 +69,13 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() { using ElementCompute = cutlass::half_t; ElementCompute alpha0 = ElementCompute(1); - ElementCompute beta0 = ElementCompute(0); + ElementCompute beta0 = ElementCompute(1); //beta=1 for bias ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(0); + 
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; - using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; @@ -94,7 +94,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() { 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3, @@ -118,7 +118,8 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() { ElementC, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3, @@ -151,9 +152,10 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80_shmem() { using ElementCompute = cutlass::half_t; ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 ElementCompute beta0 = ElementCompute(0); ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(0); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; @@ -175,7 +177,8 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80_shmem() { ElementC, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; const bool SmemAccumulator = true; diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu index 35f3f094..2481fbd8 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu @@ -68,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() { using ElementCompute = float; ElementCompute alpha0 = ElementCompute(1); - ElementCompute beta0 = ElementCompute(0); + ElementCompute beta0 = ElementCompute(1); //beta=1 for bias ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias - using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>; - using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>; - using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>; + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< @@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() { 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2, @@ -117,7 +117,8 @@ bool 
run_nonfused_conv2d_fprop_optimized_s8_sm75() { ElementC, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2, @@ -151,14 +152,15 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() { using ElementCompute = float; ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 ElementCompute beta0 = ElementCompute(0); ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>; + using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>; using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; - using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>; using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; using EpilogueOutputOp0 = @@ -175,7 +177,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() { ElementC, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu index e7babbe4..917ae930 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu @@ -68,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() { using ElementCompute = float; ElementCompute alpha0 = ElementCompute(1); - ElementCompute beta0 = ElementCompute(0); + ElementCompute beta0 = ElementCompute(1); //beta=1 for bias ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias - using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; - using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>; - using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< @@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() { 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2, @@ -117,7 +117,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() { ElementC, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2, @@ -150,9 +151,10 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() { using ElementCompute = float; ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 ElementCompute 
beta0 = ElementCompute(0); ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; @@ -174,7 +176,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() { ElementC, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu index ac193b99..a515f125 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu @@ -68,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() { using ElementCompute = float; ElementCompute alpha0 = ElementCompute(1); - ElementCompute beta0 = ElementCompute(0); + ElementCompute beta0 = ElementCompute(1); //beta=1 for bias ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(0); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias - using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; - using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; - using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>; + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop< @@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() { 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3, @@ -117,7 +117,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() { ElementC, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3, @@ -151,14 +152,15 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80_rf_res() { using ElementCompute = float; ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 ElementCompute beta0 = ElementCompute(0); ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(0); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>; using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; - using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>; + using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; using EpilogueOutputOp0 = @@ -175,7 +177,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80_rf_res() { ElementC, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + 
cutlass::epilogue::thread::ScaleType::NoBetaScaling >; diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu index 07e87369..9a5b2c1c 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu @@ -68,13 +68,13 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() { using ElementCompute = float; ElementCompute alpha0 = ElementCompute(1); - ElementCompute beta0 = ElementCompute(0); + ElementCompute beta0 = ElementCompute(1); //beta=1 for bias ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(0); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias - using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; - using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>; + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; @@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() { 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3, @@ -117,7 +117,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() { ElementC, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3, @@ -150,9 +151,10 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80_shmem() { using ElementCompute = float; ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 ElementCompute beta0 = ElementCompute(0); ElementCompute alpha1 = ElementCompute(1); - ElementCompute beta1 = ElementCompute(0); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; @@ -174,7 +176,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80_shmem() { ElementC, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; const bool SmemAccumulator = true; diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu index 9ee5bc0a..54c88355 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu @@ -55,10 +55,10 @@ bool run_nonfused_gemm_f16() { using ElementAccumulator = cutlass::half_t; using ElementCompute = cutlass::half_t; - ElementCompute alpha0 = ElementCompute(2); - ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(1); + ElementCompute alpha0 = ElementCompute(1); + ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias using ThreadblockShape0 = 
cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; @@ -84,7 +84,7 @@ bool run_nonfused_gemm_f16() { 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2 @@ -106,7 +106,8 @@ bool run_nonfused_gemm_f16() { ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2 @@ -131,10 +132,11 @@ bool run_fused_gemm_f16_rf_res() { using ElementAccumulator = cutlass::half_t; using ElementCompute = cutlass::half_t; - ElementCompute alpha0 = ElementCompute(2); + ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(1); + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>; @@ -156,7 +158,8 @@ bool run_fused_gemm_f16_rf_res() { ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; using B2bGemm = cutlass::gemm::device::B2bGemm< diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu index 0c9aa0e2..30ba2699 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu @@ -55,14 +55,14 @@ bool run_nonfused_gemm_f16() { using ElementAccumulator = cutlass::half_t; using ElementCompute = cutlass::half_t; - ElementCompute alpha0 = ElementCompute(2); - ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(1); + ElementCompute alpha0 = ElementCompute(1); + ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; - using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; @@ -84,7 +84,7 @@ bool run_nonfused_gemm_f16() { 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2 @@ -106,7 +106,8 @@ bool run_nonfused_gemm_f16() { ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2 @@ -130,10 +131,11 @@ bool run_fused_gemm_f16_shmem() { using ElementAccumulator = cutlass::half_t; using ElementCompute = cutlass::half_t; - ElementCompute alpha0 = 
ElementCompute(2); + ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(1); + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; @@ -155,7 +157,8 @@ bool run_fused_gemm_f16_shmem() { ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu index 6610737a..0c2239ac 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu @@ -55,15 +55,15 @@ bool run_nonfused_gemm_f16_sm80() { using ElementAccumulator = cutlass::half_t; using ElementCompute = cutlass::half_t; - ElementCompute alpha0 = ElementCompute(2); - ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(1); + ElementCompute alpha0 = ElementCompute(1); + ElementCompute beta0 = ElementCompute(1); //beta=1 for bias + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias - using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>; - using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; - using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>; + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; using Gemm0 = cutlass::gemm::device::Gemm< @@ -84,7 +84,7 @@ bool run_nonfused_gemm_f16_sm80() { 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3 @@ -106,7 +106,8 @@ bool run_nonfused_gemm_f16_sm80() { ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3 @@ -130,15 +131,16 @@ bool run_fused_gemm_f16_sm80_rf_res() { using ElementAccumulator = cutlass::half_t; using ElementCompute = cutlass::half_t; - ElementCompute alpha0 = ElementCompute(2); - ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(1); + ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 + ElementCompute beta0 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias - using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>; + using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>; 
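// Editor's sketch (illustrative annotation, not part of the patch): the epilogues above are
// switched to cutlass::epilogue::thread::ScaleType::NoBetaScaling, under which the source
// operand C (used in these examples to carry a per-channel bias) is added to the scaled
// accumulator without being multiplied by beta. That is why the non-fused baselines pass the
// bias through C with beta = 1, while the fused kernels set beta0 = 0 and consume the bias
// vector inside the mainloop instead. Roughly, per output element (placeholder names, not
// CUTLASS symbols):
float epilogue_no_beta_scaling(float accumulator, float bias_c, float alpha) {
  float x = alpha * accumulator + bias_c;  // beta is not applied to the source operand
  return x > 0.f ? x : 0.f;                // ReLU, as in the LinearCombinationRelu epilogues used here
}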
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; - using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; using EpilogueOutputOp0 = @@ -155,11 +157,10 @@ bool run_fused_gemm_f16_sm80_rf_res() { ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; - - using B2bGemm = cutlass::gemm::device::B2bGemm< cutlass::half_t, cutlass::layout::RowMajor, diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu index 48c31ae0..045e4a8e 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu @@ -55,10 +55,10 @@ bool run_nonfused_gemm_f16_sm80() { using ElementAccumulator = cutlass::half_t; using ElementCompute = cutlass::half_t; - ElementCompute alpha0 = ElementCompute(2); - ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(1); + ElementCompute alpha0 = ElementCompute(1); + ElementCompute beta0 = ElementCompute(1); //beta=1 for bias + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; @@ -84,7 +84,7 @@ bool run_nonfused_gemm_f16_sm80() { 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3 @@ -106,7 +106,8 @@ bool run_nonfused_gemm_f16_sm80() { ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3 @@ -130,10 +131,11 @@ bool run_fused_gemm_f16_sm80_shmem() { using ElementAccumulator = cutlass::half_t; using ElementCompute = cutlass::half_t; - ElementCompute alpha0 = ElementCompute(2); - ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(1); + ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 + ElementCompute beta0 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; @@ -155,7 +157,8 @@ bool run_fused_gemm_f16_sm80_shmem() { ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu index 2fff4d84..2c00eb86 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu @@ -55,10 +55,10 @@ bool run_nonfused_gemm_s8() { using ElementAccumulator = int32_t; using ElementCompute = float; - ElementCompute alpha0 = ElementCompute(2); - 
ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(1); + ElementCompute alpha0 = ElementCompute(1); + ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; @@ -84,7 +84,7 @@ bool run_nonfused_gemm_s8() { 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2 @@ -106,7 +106,8 @@ bool run_nonfused_gemm_s8() { ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2 @@ -131,10 +132,11 @@ bool run_fused_gemm_s8_rf_res() { using ElementAccumulator = int32_t; using ElementCompute = float; - ElementCompute alpha0 = ElementCompute(2); + ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(1); + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>; @@ -156,7 +158,8 @@ bool run_fused_gemm_s8_rf_res() { ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; using B2bGemm = cutlass::gemm::device::B2bGemm< @@ -200,7 +203,7 @@ int main() { &run_fused_gemm_s8_rf_res }; - return testRun(75, funcs, "gemm f16 RF residency"); + return testRun(75, funcs, "gemm int8 RF residency"); } diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu index 952688d4..10f4cb7b 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu @@ -55,15 +55,15 @@ bool run_nonfused_gemm_s8() { using ElementAccumulator = int32_t; using ElementCompute = float; - ElementCompute alpha0 = ElementCompute(2); - ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(1); + ElementCompute alpha0 = ElementCompute(1); + ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias - using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; - using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>; - using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; using Gemm0 = cutlass::gemm::device::Gemm< @@ -84,7 
+84,7 @@ bool run_nonfused_gemm_s8() { 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2 @@ -106,7 +106,8 @@ bool run_nonfused_gemm_s8() { ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 2 @@ -130,10 +131,11 @@ bool run_fused_gemm_s8_shmem() { using ElementAccumulator = int32_t; using ElementCompute = float; - ElementCompute alpha0 = ElementCompute(2); + ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(1); + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>; @@ -155,7 +157,8 @@ bool run_fused_gemm_s8_shmem() { ElementOutput, 64 / cutlass::sizeof_bits::value, ElementAccumulator, - ElementCompute + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; const bool SmemAccumulator = true; @@ -202,7 +205,7 @@ int main() { &run_fused_gemm_s8_shmem }; - return testRun(75, funcs, "gemm s8 shmem staing"); + return testRun(75, funcs, "gemm int8 shmem staing"); } diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu index a4405504..38845371 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu @@ -55,15 +55,15 @@ bool run_nonfused_gemm_s8_sm80() { using ElementAccumulator = int32_t; using ElementCompute = float; - ElementCompute alpha0 = ElementCompute(2); - ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(0); + ElementCompute alpha0 = ElementCompute(1); + ElementCompute beta0 = ElementCompute(1); //beta=1 for bias + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias - using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; - using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; - using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>; + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; using Gemm0 = cutlass::gemm::device::Gemm< @@ -84,7 +84,7 @@ bool run_nonfused_gemm_s8_sm80() { 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, @@ -111,7 +111,7 @@ bool run_nonfused_gemm_s8_sm80() { 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + 
cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, @@ -140,15 +140,16 @@ bool run_fused_gemm_s8_sm80_rf_res() { using ElementAccumulator = int32_t; using ElementCompute = float; - ElementCompute alpha0 = ElementCompute(2); + ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>; using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; - using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>; + using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; using EpilogueOutputOp0 = @@ -166,7 +167,7 @@ bool run_fused_gemm_s8_sm80_rf_res() { 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; const bool SmemAccumulator = false; diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu index c88d9df1..7afe4409 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu @@ -55,14 +55,14 @@ bool run_nonfused_gemm_s8_sm80() { using ElementAccumulator = int32_t; using ElementCompute = float; - ElementCompute alpha0 = ElementCompute(2); - ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(0); + ElementCompute alpha0 = ElementCompute(1); + ElementCompute beta0 = ElementCompute(1); //beta=1 for bias + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias - using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; - using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; - using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; @@ -84,7 +84,7 @@ bool run_nonfused_gemm_s8_sm80() { 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, @@ -111,7 +111,7 @@ bool run_nonfused_gemm_s8_sm80() { 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, @@ -139,10 +139,11 @@ bool run_fused_gemm_s8_sm80_shmem() { using ElementAccumulator = int32_t; using ElementCompute = float; - ElementCompute alpha0 = ElementCompute(2); + ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in 
bias, setting beta=0 ElementCompute beta0 = ElementCompute(0); - ElementCompute alpha1 = ElementCompute(2); - ElementCompute beta1 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>; @@ -165,7 +166,7 @@ bool run_fused_gemm_s8_sm80_shmem() { 64 / cutlass::sizeof_bits::value, ElementAccumulator, ElementCompute, - cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + cutlass::epilogue::thread::ScaleType::NoBetaScaling >; const bool SmemAccumulator = true; diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h index 53880b21..306e8cf4 100644 --- a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h @@ -79,6 +79,8 @@ struct B2bGemm { typename B2bMma::IteratorB0::TensorRef ref_B0; typename Epilogue::OutputTileIterator::Params params_C0; typename Epilogue::OutputTileIterator::TensorRef ref_C0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0; typename B2bMma::IteratorB1::Params params_B1; typename B2bMma::IteratorB1::TensorRef ref_B1; typename Epilogue::OutputTileIterator::Params params_C1; @@ -109,6 +111,8 @@ struct B2bGemm { typename B2bMma::IteratorA0::TensorRef ref_A0, typename B2bMma::IteratorB0::TensorRef ref_B0, typename Epilogue::OutputTileIterator::TensorRef ref_C0, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0, typename B2bMma::IteratorB1::TensorRef ref_B1, typename Epilogue::OutputTileIterator::TensorRef ref_C1, typename Epilogue::OutputTileIterator::TensorRef ref_D1, @@ -126,6 +130,8 @@ struct B2bGemm { ref_B0(ref_B0), params_C0(ref_C0.layout()), ref_C0(ref_C0), + ref_Scale0(ref_Scale0), + ref_Bias0(ref_Bias0), params_B1(ref_B1.layout()), ref_B1(ref_B1), params_C1(ref_C1.layout()), @@ -305,6 +311,29 @@ struct B2bGemm { int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; + // Construct iterators to accumulator scale/bias vector + typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0( + params.ref_Scale0.data(), + {1, params.problem_size_0.n()}, + thread_idx, + warp_idx, + MatrixCoord( + 0, threadblock_tile_offset.n() * B2bMma::Shape0::kN + ) + ); + + typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0( + params.ref_Bias0.data(), + {1, params.problem_size_0.n()}, + thread_idx, + warp_idx, + MatrixCoord( + 0, threadblock_tile_offset.n() * B2bMma::Shape0::kN + ) + ); + + + // // Main loop // @@ -322,7 +351,8 @@ struct B2bGemm { if (!kSplitKSerial || gemm_k_iterations_0 > 0) { // Compute threadblock-scoped matrix multiply-add - b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, iterator_B1, src_accum, output_op_0); + b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, + iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); } // diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h index 179c3120..55619134 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h +++ 
b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h @@ -338,7 +338,7 @@ struct DefaultB2bConv2dFprop < cutlass::transform::threadblock::VectorIterator< cutlass::transform::threadblock::PredicatedVectorAccessIterator< cutlass::MatrixShape, - cutlass::MatrixShape, + cutlass::MatrixShape, ElementScaleBias, LayoutScaleBias, kElementsPerAccess> >; diff --git a/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h b/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h new file mode 100644 index 00000000..cc33731d --- /dev/null +++ b/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h @@ -0,0 +1,275 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. 
+*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/tensor_view.h" + +#include "cutlass/gemm/gemm.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template < + typename TensorRefIn, ///< Input TensorRef Type + typename TensorRefOut, ///< Output TensorRef Type + typename ScalarType, ///< alpha Type + typename TensorRefScalar, ///< Scale/Bias TensorRef Type + typename OutputTile, + typename ConvertOp = NumericConverter +> +__global__ void TensorScaleBiasGemm( + gemm::GemmCoord problem_size, + TensorRefIn tensor_in, ///< input tensor + TensorRefOut tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRefScalar tensor_scale, ///< scale tensor + TensorRefScalar tensor_bias ///< bias tensor +) { + + ConvertOp convert_op; + + MatrixCoord output_coord( + MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), + MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) + ); + + // Update the output tensor + for (int j = 0; j < OutputTile::kRow; ++j) { + for (int i = 0; i < OutputTile::kColumn; ++i) { + MatrixCoord coord = output_coord + MatrixCoord(i, j); + if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { + + ScalarType scale = alpha; + if(tensor_scale.good()) + scale = tensor_scale.at({0, coord.column()}); + + ScalarType bias = ScalarType(0); + + if(tensor_bias.good()) + bias = tensor_bias.at({0, coord.column()}); + + tensor_out.at(coord) = convert_op( + scale * ScalarType(tensor_in.at(coord)) + bias); + } + } + } +} + +template < + typename TensorRefIn, ///< Input TensorRef Type + typename TensorRefOut, ///< Output TensorRef Type + typename ScalarType, ///< alpha Type + typename TensorRefScalar, ///< Scale/Bias TensorRef Type + typename ConvertOp = NumericConverter, + int kThreadM = 4, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void TensorScaleBiasConv2d( + conv::Conv2dProblemSize problem_size, + TensorRefIn tensor_in, ///< input tensor + TensorRefOut tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRefScalar tensor_scale, ///< scale tensor + TensorRefScalar tensor_bias ///< bias tensor +) { + + ConvertOp convert_op; + + int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_p[kThreadM]; + int thread_q[kThreadM]; + + // Compute N, P, Q coordinates for each row of a thread's tile + int64_t PQ = int64_t(problem_size.P) * problem_size.Q; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t npq = npq_start + m; + + thread_n[m] = int(npq / PQ); + + int64_t residual = npq % PQ; + thread_p[m] = int(residual / problem_size.Q); + thread_q[m] = int(residual % problem_size.Q); + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 
0; n < kThreadN; ++n) { + int thread_k = k_start + n; + if (thread_k < problem_size.K) { + + ScalarType scale = alpha; + if(tensor_scale.good()) + scale = tensor_scale.at({0, thread_k}); + + ScalarType bias = ScalarType(0); + if(tensor_bias.good()) + bias = tensor_bias.at({0, thread_k}); + + tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( + scale * ScalarType( + tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) + ) + bias); + } + } + } + } + +} + +} + +/// Apply scale and bias on a tensor +template < + typename ElementIn, ///< Input Type + typename ElementOut, ///< Output Type + typename Layout, ///< Layout of input/output tensor + typename ScalarType, ///< alpha Type + typename LayoutScaleBias, ///< Layout of scale and bias + typename ConvertOp = NumericConverter +> +void TensorScaleBiasGemm( + gemm::GemmCoord problem_size, + TensorRef tensor_in, ///< input tensor + TensorRef tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRef tensor_scale, ///< scale tensor + TensorRef tensor_bias ///< bias tensor +) { + + using OutputTile = MatrixShape<4, 4>; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) + ); + + kernel::TensorScaleBiasGemm< + TensorRef, + TensorRef, + ScalarType, + TensorRef, + OutputTile, + ConvertOp + ><<< grid, block >>> ( + problem_size, + tensor_in, + tensor_out, + alpha, + tensor_scale, + tensor_bias + ); +} + +/// Apply scale and bias on a tensor +template < + typename ElementIn, ///< Input Type + typename ElementOut, ///< Output Type + typename Layout, ///< Layout of input/output tensor + typename ScalarType, ///< alpha Type + typename LayoutScaleBias, ///< Layout of scale and bias + typename ConvertOp = NumericConverter +> +void TensorScaleBiasConv2d( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_in, ///< input tensor + TensorRef tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRef tensor_scale, ///< scale tensor + TensorRef tensor_bias ///< bias tensor +) { + + int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q; + int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + + kernel::TensorScaleBiasConv2d< + TensorRef, + TensorRef, + ScalarType, + TensorRef, + ConvertOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block >>> ( + problem_size, + tensor_in, + tensor_out, + alpha, + tensor_scale, + tensor_bias + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h index 228bff27..6229b595 100644 --- 
a/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h @@ -745,7 +745,6 @@ public: this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_B1_; - if (warp_mma_k > 0) warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], warp_transformed_frag_B1[warp_mma_k % 2], diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h index 7f9dba0d..8104f638 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h @@ -82,6 +82,11 @@ template < /// Iterates over the intermediate accumulator tile // (concept::MmaTensorOpFragmentIterator) typename FragmentIteratorA1_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// WarpIterator to load Scale or Bias vector from threadblock fragment + typename FragmentIteratorA1ScaleBias_, /// Iterates over tiles of B operand in global memory // (concept: ReadableTileIterator | ForwardTileIterator | // MaskedTileIterator) @@ -126,6 +131,10 @@ public: using Shape1 = Shape1_; ///< Iterates over intermediate accumulator tile using FragmentIteratorA1 = FragmentIteratorA1_; + ///< Iterates over tiles of the scale and bias vectors in global memory + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; + ///< WarpIterator to load Scale or Bias vector from threadblock fragment + using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_; ///< Iterates over tiles of B operand in global memory using IteratorB1 = IteratorB1_; ///< Policy describing tuning details @@ -140,6 +149,9 @@ public: ///< Epilogue after 1st Gemm using OutputOp = OutputOp_; + + static const bool PerChannelScale = (OutputOp::kScale == + epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; @@ -154,6 +166,9 @@ public: /// Warp-level Mma using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; /// Fragment of accumulator tile using FragmentC1 = typename Policy1::Operator::FragmentC; @@ -217,6 +232,8 @@ public: using WarpLoadedFragmentB0 = typename Operator0::FragmentB; /// Warp Fragment of operand A1 loaded from accmulator tile using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment; + using WarpLoadedFragmentA1ScaleBias = + typename FragmentIteratorA1ScaleBias::Fragment; using WarpLoadedFragmentB1 = typename Operator1::FragmentB; using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; @@ -381,11 +398,15 @@ public: int gemm_k_iterations_0, ///< destination accumulator tile FragmentC1 &accum, - ///< iterator over A operand in global memory + ///< iterator over A0 operand in global memory IteratorA0 iterator_A0, - ///< iterator over B operand in global memory + ///< iterator over B0 operand in global memory IteratorB0 iterator_B0, - ///< iterator over B operand in global memory + ///< iterator over A1 operand scale vector in global memory + IteratorAccumulatorScaleBias 
iterator_A1_scale, + ///< iterator over A1 operand bias vector in global memory + IteratorAccumulatorScaleBias iterator_A1_bias, + ///< iterator over B1 operand in global memory IteratorB1 iterator_B1, ///< initial value of accumulator FragmentC0 const &src_accum, @@ -623,6 +644,20 @@ public: /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + FragmentA1ScaleBias tb_frag_A1_scale; + FragmentA1ScaleBias tb_frag_A1_bias; + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); + + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + // // Prologue @@ -678,18 +713,29 @@ public: // Pair of fragments used to overlap shared memory loads and math // instructions WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2]; WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; Operator1 warp_mma1; - this->warp_tile_iterator_B1_.set_kgroup_index(0); - - warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], output_op_0); - this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + if(PerChannelScale) { + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + } + warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]); + ++warp_tile_iterator_A1_bias_; + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], + warp_loaded_frag_A1_scale[0], + warp_loaded_frag_A1_bias[0], + output_op_0); ++warp_tile_iterator_A1_; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); ++this->warp_tile_iterator_B1_; iterator_B1.clear_mask(gemm_k_iterations_1 == 0); @@ -717,15 +763,37 @@ public: for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + // Load threadblock-level scale/bias vector from global memory + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + } + + // Load warp-level scale bias fragment from threadblock scale/bias vector + if(PerChannelScale) { + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + ++warp_tile_iterator_A1_scale_; + } + warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]); + ++warp_tile_iterator_A1_bias_; + + // Load warp-level tile from accumulator fragment + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2], + output_op_0); + ++warp_tile_iterator_A1_; + // Load warp-level tiles from shared memory, wrapping to k offset if // this is the last group as the case may be. 
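// Editor's sketch (illustrative annotation, not part of the patch): the load above forms the
// second GEMM's A1 operand directly from the first GEMM's accumulator fragment, applying the
// per-channel scale/bias vectors through output_op_0. When the output op is configured for
// OnlyAlphaPerChannelScaling (PerChannelScale == true), the scale comes from the vector;
// otherwise the scalar alpha is used. Per element, roughly (placeholder names, not CUTLASS
// symbols):
void form_a1_operand(const float* accum0, const float* scale_vec, const float* bias_vec,
                     float alpha, bool per_channel_scale, int rows, int cols, float* a1) {
  for (int m = 0; m < rows; ++m) {
    for (int n = 0; n < cols; ++n) {
      float scale = per_channel_scale ? scale_vec[n] : alpha;  // per output channel n
      float x = scale * accum0[m * cols + n] + bias_vec[n];    // bias broadcast along rows
      a1[m * cols + n] = x > 0.f ? x : 0.f;                    // activation of output_op_0 (ReLU here)
    }
  }
}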
-      this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
-
-      warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
       this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
-
-      ++warp_tile_iterator_A1_;
       ++this->warp_tile_iterator_B1_;
       if (warp_mma_k > 0)
diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h
index 2a5fb933..c28f4e49 100644
--- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h
+++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h
@@ -165,6 +165,9 @@ public:
   /// Warp-level Mma
   using Operator0 = typename Policy0::Operator;
+
+  /// Fragment of Scale and Bias loaded from global memory
+  using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;
   /// Fragment of accumulator tile
   using FragmentC1 = typename Policy1::Operator::FragmentC;
@@ -418,11 +421,15 @@ public:
       int gemm_k_iterations_0,
       ///< destination accumulator tile
       FragmentC1 &accum,
-      ///< iterator over A operand in global memory
+      ///< iterator over A0 operand in global memory
       IteratorA0 iterator_A0,
-      ///< iterator over B operand in global memory
+      ///< iterator over B0 operand in global memory
       IteratorB0 iterator_B0,
-      ///< iterator over B operand in global memory
+      ///< iterator over A1 operand scale vector in global memory
+      IteratorAccumulatorScaleBias iterator_accum0_scale,
+      ///< iterator over A1 operand bias vector in global memory
+      IteratorAccumulatorScaleBias iterator_accum0_bias,
+      ///< iterator over B1 operand in global memory
       IteratorB1 iterator_B1,
       ///< initial value of accumulator
       FragmentC0 const &src_accum,
@@ -658,7 +665,7 @@ public:
     /// Epilogue for the first Implicit Gemm
     Epilogue0 epilogue0;
-    epilogue0(output_op_0, smem_iterator_D0_, accum0);
+    epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias);
     __syncthreads();
diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h
index df272289..4e39fda5 100644
--- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h
+++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h
@@ -76,6 +76,11 @@ template <
     /// Iterates over the intermediate accumulator tile
     //  (concept::MmaTensorOpFragmentIterator)
     typename FragmentIteratorA1_,
+    /// Iterates over vectors of scale and bias vector in global memory
+    //  (concept: VectorIterator)
+    typename IteratorAccumulatorScaleBias_,
+    /// FragmentIterator to load Scale or Bias vector from threadblock fragment
+    typename FragmentIteratorA1ScaleBias_,
     /// Iterates over tiles of B operand in global memory
     //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
     typename IteratorB1_,
@@ -129,6 +134,9 @@ public:
   using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
   using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile
+  using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory
+  using FragmentIteratorA1ScaleBias =
+    FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment
   using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
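// Editor's note (illustrative annotation, not part of the patch): in the *_smem_accumulator
// variants the same per-channel scale/bias is applied one step earlier, by the first epilogue
// while it stages D0 into shared memory (the extended epilogue0(...) call above), rather than
// while A1 fragments are loaded from registers. A rough tile-level view, with placeholder names:
void stage_d0_to_smem(const float* accum0, const float* scale_vec, const float* bias_vec,
                      int rows, int cols, float* smem_d0) {
  // scale_vec may degenerate to a single scalar alpha when per-channel scaling is not enabled
  for (int m = 0; m < rows; ++m) {
    for (int n = 0; n < cols; ++n) {
      float x = scale_vec[n] * accum0[m * cols + n] + bias_vec[n];
      smem_d0[m * cols + n] = x > 0.f ? x : 0.f;  // activation; the second GEMM reads its A1 tiles from here
    }
  }
}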
using Policy1 = Policy1_; ///< Policy describing tuning details @@ -140,6 +148,9 @@ public: using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm + static const bool PerChannelScale = (OutputOp::kScale == + epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); + using TransformA0 = TransformA0_; using TransformB0 = TransformB0_; using TransformB1 = TransformB1_; @@ -160,6 +171,9 @@ public: /// Warp-level Mma using Operator0 = typename Policy0::Operator; + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + /// Fragment of operand B loaded from global memory using FragmentB1 = typename IteratorB1::Fragment; @@ -190,6 +204,9 @@ private: using WarpFragmentB0 = typename Operator0::FragmentB; /// Warp Fragment of operand A1 loaded from accmulator tile using WarpFragmentA1 = typename FragmentIteratorA1::Fragment; + /// Warp Fragment of operand A1 scale and bias loaded from threadblock fragment + using WarpFragmentA1ScaleBias = + typename FragmentIteratorA1ScaleBias::Fragment; using WarpFragmentB1 = typename Operator1::FragmentB; protected: @@ -248,6 +265,8 @@ public: FragmentC1 &accum, ///< destination accumulator tile IteratorA0 iterator_A, ///< iterator over A operand in global memory IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory + IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory + IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory FragmentC0 const &src_accum, ///< source accumualtor tile OutputOp output_op_0, ///< epilogue operation after 1st Gemm @@ -387,13 +406,26 @@ public: // Prologue // + FragmentA1ScaleBias tb_frag_A1_scale; + FragmentA1ScaleBias tb_frag_A1_bias; + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); FragmentB1 tb_frag_B1; + if(PerChannelScale) + tb_frag_A1_scale.clear(); + tb_frag_A1_bias.clear(); tb_frag_B1.clear(); // The last kblock is loaded in the prolog + if(PerChannelScale) + iterator_A1_scale.load(tb_frag_A1_scale); + iterator_A1_bias.load(tb_frag_A1_bias); iterator_B1.load(tb_frag_B1); + if(PerChannelScale) + ++iterator_A1_scale; + ++iterator_A1_bias; ++iterator_B1; this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); @@ -403,15 +435,24 @@ public: __syncthreads(); // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA1ScaleBias warp_frag_A1_scale[2]; + WarpFragmentA1ScaleBias warp_frag_A1_bias[2]; WarpFragmentA1 warp_frag_A1[2]; WarpFragmentB1 warp_frag_B1[2]; this->warp_tile_iterator_B1_.set_kgroup_index(0); - warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0); + if(PerChannelScale) + warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]); + warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]); + warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0], + warp_frag_A1_bias[0], output_op_0); this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); ++warp_tile_iterator_A1_; + if(PerChannelScale) + ++warp_tile_iterator_A1_scale_; + ++warp_tile_iterator_A1_bias_; ++this->warp_tile_iterator_B1_; Operator1 warp_mma1; @@ -461,13 +502,31 @@ public: } smem_write_stage_idx ^= 1; + + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + 
tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; } this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); - warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0); + if(PerChannelScale) + warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], + warp_frag_A1_scale[(warp_mma_k + 1) % 2], + warp_frag_A1_bias[(warp_mma_k + 1) % 2], + output_op_0); this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + if(PerChannelScale) + ++warp_tile_iterator_A1_scale_; + ++warp_tile_iterator_A1_bias_; ++warp_tile_iterator_A1_; ++this->warp_tile_iterator_B1_; diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h index 4cd89ed9..b548c857 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h @@ -286,6 +286,8 @@ public: FragmentC1 &accum, ///< destination accumulator tile IteratorA0 iterator_A, ///< iterator over A operand in global memory IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory + IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory + IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory FragmentC0 const &src_accum, ///< source accumualtor tile OutputOp output_op_0, ///< epilogue operation after 1st Gemm @@ -419,7 +421,7 @@ public: /// Epilogue for the first Implicit Gemm Epilogue0 epilogue0; - epilogue0(output_op_0, smem_iterator_D0_, accum0); + epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); __syncthreads(); diff --git a/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h b/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h index 02e8e206..3c12e05c 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h +++ b/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h @@ -40,6 +40,10 @@ #include "cutlass/transform/threadblock/predicated_tile_iterator.h" #include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/warp/vector_fragment_iterator.h" + #include "cutlass/gemm/threadblock/default_mma_core_sm70.h" #include "cutlass/gemm/threadblock/default_mma_core_sm75.h" #include "cutlass/gemm/threadblock/default_mma_core_sm80.h" @@ -170,6 +174,22 @@ struct DefaultB2bMma; + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using 
FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + // Define iterators over tiles from the B operand using IteratorB1 = cutlass::transform::threadblock::PredicatedTileIterator< @@ -181,6 +201,7 @@ struct DefaultB2bMma; + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; using AccessTypeB1 = cutlass::Array; @@ -290,6 +329,7 @@ struct DefaultB2bMma; + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + // Define iterators over tiles from the B operand using IteratorB1 = cutlass::transform::threadblock::PredicatedTileIterator< @@ -384,12 +440,12 @@ struct DefaultB2bMma; - // Define the threadblock-scoped pipelined matrix multiply using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined< typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, IteratorB0, typename MmaCore0::SmemIteratorB, typename MmaCore1::Shape, FragmentIteratorA1, + IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, IteratorB1, typename MmaCore1::SmemIteratorB, ElementAccumulator, layout::ColumnMajorInterleaved, EpilogueOutputOp, @@ -479,6 +535,23 @@ struct DefaultB2bMma; + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, 
+      LayoutScaleBias, InstructionShape, kElementsPerAccess>;
+
  // Define iterators over tiles from the B operand
  using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
  using IteratorB1 =
@@ -494,6 +567,7 @@ struct DefaultB2bMma,
      EpilogueOutputOp,
diff --git a/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h
index b0a76896..ea1a258f 100644
--- a/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h
+++ b/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h
@@ -559,7 +559,7 @@ struct DefaultB2bMma,
-          cutlass::MatrixShape,
+          cutlass::MatrixShape,
          ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
    >;
diff --git a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h
index 5743adc7..ea20021b 100644
--- a/include/cutlass/epilogue/thread/linear_combination_relu.h
+++ b/include/cutlass/epilogue/thread/linear_combination_relu.h
@@ -162,6 +162,8 @@ public:
    if (Scale == ScaleType::OnlyAlphaScaling) return false;
+    if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false;
+
    if (Scale == ScaleType::Nothing) return false;
    return beta_ != ElementCompute(0);
@@ -389,6 +391,8 @@ public:
    if (Scale == ScaleType::OnlyAlphaScaling) return false;
+    if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false;
+
    if (Scale == ScaleType::Nothing) return false;
    return beta_ != ElementCompute(0);
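(Illustrative aside, not part of this patch: the two hunks above extend the is_source_needed() checks in LinearCombinationRelu so that the new ScaleType::OnlyAlphaPerChannelScaling mode, like the other alpha-only modes, reports that no source tensor C is needed. In that mode the fused first stage reduces to a per-channel scale and bias followed by ReLU; the standalone helper below is only a functional sketch of that computation, and its name and plain float/std::vector types are assumptions made for illustration rather than anything that exists in CUTLASS.)

#include <algorithm>
#include <cstddef>
#include <vector>

// Functional sketch only: D0[m][k] = ReLU(scale[k] * accum0[m][k] + bias[k]) for a
// row-major M x K accumulator tile. With per-channel scaling disabled, every scale[k]
// would simply hold the scalar alpha0 instead.
inline void scale_bias_relu_sketch(std::vector<float> const &accum0,  // M x K accumulators
                                   std::vector<float> const &scale,   // K per-channel scales
                                   std::vector<float> const &bias,    // K per-channel biases
                                   std::vector<float> &d0,            // M x K output
                                   std::size_t M, std::size_t K) {
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t k = 0; k < K; ++k) {
      float v = scale[k] * accum0[m * K + k] + bias[k];
      d0[m * K + k] = std::max(v, 0.0f);  // ReLU
    }
  }
}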
diff --git a/tools/util/include/cutlass/util/reference/device/convolution.h b/tools/util/include/cutlass/util/reference/device/convolution.h
index 6f6ede63..8c00b779 100644
--- a/tools/util/include/cutlass/util/reference/device/convolution.h
+++ b/tools/util/include/cutlass/util/reference/device/convolution.h
@@ -82,9 +82,7 @@ __global__ void Conv2dFprop(
  TensorRef tensor_y_in,
  TensorRef tensor_y_out,
  ElementCompute alpha,
-  ElementCompute beta,
-  TensorRef tensor_scale,
-  TensorRef tensor_bias
+  ElementCompute beta
  ) {
  ConvertOp convert_op;
@@ -186,26 +184,13 @@ __global__ void Conv2dFprop(
        int thread_k = k_start + n;
        if (thread_k < problem_size.K) {
-          if(alpha == ElementCompute()) { // use per-channel scale and bias
-            ElementCompute scale = tensor_scale.at({0, thread_k});
-            ElementCompute bias = tensor_bias.at({0, thread_k});
-            tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
-              scale * ElementCompute(accum[m][n]) + bias);
+          ElementCompute c_ref = ElementCompute();
+          if (beta != ElementCompute()) {
+            c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}));
          }
-          else if(tensor_bias.good()) { // use per-channel bias
-            ElementCompute bias = tensor_bias.at({0, thread_k});
-            tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
-              alpha * ElementCompute(accum[m][n]) + bias);
-          }
-          else {
-            ElementCompute c_ref = ElementCompute();
-            if (beta != ElementCompute()) {
-              c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}));
-            }
-            tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
-              alpha * ElementCompute(accum[m][n]) + beta * c_ref);
-          }
+          tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
+            alpha * ElementCompute(accum[m][n]) + beta * c_ref);
        }
      }
    }
@@ -1015,9 +1000,7 @@ Status Conv2dFprop(
  TensorRef tensor_y_out,
  ElementCompute alpha,
  ElementCompute beta,
-  cudaStream_t stream = nullptr,
-  TensorRef tensor_scale = TensorRef(),
-  TensorRef tensor_bias = TensorRef() ) {
+  cudaStream_t stream = nullptr) {
  //
  // Blocking factors improve performance of reference implementation
@@ -1056,9 +1039,7 @@ Status Conv2dFprop(
      tensor_y_in,
      tensor_y_out,
      alpha,
-      beta,
-      tensor_scale,
-      tensor_bias
+      beta
    );
    cudaError_t result = cudaPeekAtLastError();
@@ -1448,9 +1429,7 @@ Status Conv2d(
  TensorRef tensor_D,
  ElementCompute alpha,
  ElementCompute beta,
-  cudaStream_t stream = nullptr,
-  TensorRef tensor_scale = TensorRef(),
-  TensorRef tensor_bias = TensorRef() ) {
+  cudaStream_t stream = nullptr) {
  switch (convolutional_operator) {
  case conv::Operator::kFprop:
@@ -1461,7 +1440,7 @@ Status Conv2d(
      ElementCompute, ElementAccumulator, ConvertOp, InnerProductOp
-    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream, tensor_scale, tensor_bias);
+    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
    break;
  case conv::Operator::kDgrad: