b2b bias vector support (#482)

* b2b bias vector support

* add files

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Haicheng Wu 2022-04-30 07:16:15 -04:00 committed by GitHub
parent 86ce09aed1
commit ec2b4fd85d
34 changed files with 1096 additions and 324 deletions


@@ -61,6 +61,29 @@ When applying the above constraint to convolutions, it is required that the 2nd
kernel doesn't have halos such that data used by each threadblock doesn't depend on any other
threadblock. Typically this requires the 2nd Convolution uses 1x1 filter without any paddings.
# Build and run
- Run cmake at top-level CUTLASS
- `make 13_two_tensor_op_fusion`
- Run individual benchmarks
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm75_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm75_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm80_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm80_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm75_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm75_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm80_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm80_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm75_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm75_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm80_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm80_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_shmem`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_rf`
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_shmem`
# Copyright
Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

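The constraint described at the top of this README excerpt (the second convolution restricted to a 1x1 filter with no padding) is what makes the fusion threadblock-local. Below is a minimal host-side sketch, not part of this commit and with hypothetical helper names, of why that holds: with a 1x1 filter, stride 1, and no padding, every output pixel of the second convolution reads only the co-located pixel of the first convolution's output, so no halo data from neighboring threadblocks is needed.

```cpp
// Minimal host-side sketch (not part of this commit): a 1x1, no-padding
// convolution over NHWC activations. Output pixel (n, p, q) depends only on
// input pixel (n, p, q) across channels, so the data a threadblock needs is
// entirely produced by the threadblock that computed the same tile of conv1.
#include <cstddef>
#include <vector>

// in: N x H x W x C (NHWC), filt: K x C (1x1 filter), out: N x H x W x K.
void conv2d_1x1_nhwc(const std::vector<float>& in, int N, int H, int W, int C,
                     const std::vector<float>& filt, int K,
                     std::vector<float>& out) {
  out.assign(static_cast<size_t>(N) * H * W * K, 0.f);
  for (int n = 0; n < N; ++n)
    for (int p = 0; p < H; ++p)
      for (int q = 0; q < W; ++q)
        for (int k = 0; k < K; ++k) {
          float acc = 0.f;
          for (int c = 0; c < C; ++c)
            acc += in[((size_t(n) * H + p) * W + q) * C + c] * filt[size_t(k) * C + c];
          // Depends only on position (n, p, q) of the input.
          out[((size_t(n) * H + p) * W + q) * K + k] = acc;
        }
}
```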

@@ -54,6 +54,7 @@
#include "cutlass/core_io.h"
#include "cutlass/util/tensor_view_io.h"
#include "reference/device/tensor_scale_bias.h"
#include "helper.h"
#define CHECK_GT(val1, val2) \
@@ -153,6 +154,7 @@ public:
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      std::cerr << "Not implemented\n";
    }
  }
@@ -407,6 +409,7 @@ public:
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
  cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
  cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
  cutlass::HostTensor<ElementAccumulator, typename B2bConv2d::LayoutC> tensor_Z0_reference;
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;
  cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
@@ -487,6 +490,7 @@ public:
    if(alpha0 == ElementCompute(0)) //per-channel scale
      tensor_Scale0.resize({1, problem_size_0.K});
    tensor_Bias0.resize({1, problem_size_0.K});
    tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
    tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
    tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
    tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
@@ -607,22 +611,35 @@ public:
      typename B2bConv2d::LayoutA,
      typename B2bConv2d::ElementB,
      typename B2bConv2d::LayoutB,
-     typename B2bConv2d::ElementC,
      ElementAccumulator,
      typename B2bConv2d::LayoutC,
-     ElementCompute,
      ElementAccumulator,
      ElementAccumulator
    >(
      kConvolutionalOperator,
      problem_size_0,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
-     tensor_C0.device_ref(),
      tensor_Z0_reference.device_ref(),
      tensor_Z0_reference.device_ref(),
      ElementAccumulator(1), // intermediate alpha = 1
      ElementAccumulator(0)  // beta = 0
    );
    cutlass::reference::device::TensorScaleBiasConv2d<
      ElementAccumulator,
      typename B2bConv2d::ElementC,
      typename B2bConv2d::LayoutC,
      ElementCompute,
      typename B2bConv2d::LayoutScaleBias
    >(
      problem_size_0,
      tensor_Z0_reference.device_ref(),
      tensor_D0_reference.device_ref(),
      alpha0,
-     beta0,
-     nullptr, // stream
      tensor_Scale0.device_ref(),
-     tensor_Bias0.device_ref());
      tensor_Bias0.device_ref()
    );
    if(relu) {
      cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());

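In the verification path above, the reference for the first convolution is now computed in two steps: a plain implicit-GEMM convolution into an accumulator tensor `tensor_Z0_reference` (intermediate alpha = 1, beta = 0), followed by `TensorScaleBiasConv2d`, which applies the per-channel scale and bias. Below is a hedged host-side sketch of the element-wise math this appears to implement; the function and variable names are ours, not CUTLASS's.

```cpp
// Hedged host-side sketch of the per-channel epilogue the reference path above
// appears to compute (NOT the CUTLASS reference kernel itself):
//   D[n,p,q,k] = s * Z[n,p,q,k] + bias[k], with s = scale[k] when per-channel
//   scaling is enabled (alpha0 == 0 in the test harness) and s = alpha0 otherwise,
//   followed by an optional ReLU.
#include <algorithm>
#include <cstddef>
#include <vector>

void scale_bias_relu_reference(const std::vector<float>& Z,      // N*P*Q x K accumulators
                               std::vector<float>& D,            // same shape as Z
                               const std::vector<float>& scale,  // K entries (may be empty)
                               const std::vector<float>& bias,   // K entries
                               float alpha, int K, bool relu) {
  const bool per_channel = (alpha == 0.0f) && !scale.empty();
  D.resize(Z.size());
  for (size_t i = 0; i < Z.size(); ++i) {
    int k = static_cast<int>(i % K);          // channel index in channels-last order
    float s = per_channel ? scale[k] : alpha;
    float v = s * Z[i] + bias[k];
    D[i] = relu ? std::max(v, 0.0f) : v;
  }
}
```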

@ -44,6 +44,7 @@
#include "cutlass/util/reference/device/gemm.h" #include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h" #include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h"
#include "helper.h" #include "helper.h"
#define CHECK_GT(val1, val2) \ #define CHECK_GT(val1, val2) \
@ -68,6 +69,7 @@ struct B2bNonFusedGemmRun
cutlass::Distribution::Kind init_A; cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B; cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C; cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_Bias;
uint64_t seed; uint64_t seed;
// //
@ -78,9 +80,10 @@ struct B2bNonFusedGemmRun
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080 uint64_t seed_ = 2080
): ):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }
/// Helper to initialize a tensor view /// Helper to initialize a tensor view
template <typename Element, typename Layout> template <typename Element, typename Layout>
@ -97,7 +100,7 @@ struct B2bNonFusedGemmRun
else if (dist_kind == cutlass::Distribution::Identity) { else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view); cutlass::reference::host::TensorFillIdentity(view);
} }
else if (dist_kind == cutlass::Distribution::Gaussian) { else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
@ -106,9 +109,14 @@ struct B2bNonFusedGemmRun
cutlass::reference::host::BlockFillSequential( cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity()); view.data(), view.capacity());
} }
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else { else {
// TODO: Implement the rest
std::cerr << "Not implemented\n"; std::cerr << "Not implemented\n";
return false; return false;
} }
@ -147,6 +155,10 @@ struct B2bNonFusedGemmRun
typename Gemm0::ElementC, typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
ElementCompute,
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
cutlass::HostTensor< cutlass::HostTensor<
typename Gemm0::ElementC, typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
@ -163,6 +175,10 @@ struct B2bNonFusedGemmRun
typename Gemm1::ElementC, typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
ElementCompute,
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
cutlass::HostTensor< cutlass::HostTensor<
typename Gemm1::ElementC, typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
@ -175,8 +191,10 @@ struct B2bNonFusedGemmRun
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));
cutlass::reference::host::TensorFill( cutlass::reference::host::TensorFill(
tensor_D0.host_view()); tensor_D0.host_view());
@ -190,9 +208,11 @@ struct B2bNonFusedGemmRun
tensor_A0.sync_device(); tensor_A0.sync_device();
tensor_B0.sync_device(); tensor_B0.sync_device();
tensor_C0.sync_device(); tensor_C0.sync_device();
tensor_Bias0.sync_device();
tensor_D0.sync_device(); tensor_D0.sync_device();
tensor_B1.sync_device(); tensor_B1.sync_device();
tensor_C1.sync_device(); tensor_C1.sync_device();
tensor_Bias1.sync_device();
tensor_D1.sync_device(); tensor_D1.sync_device();
reference_D0.sync_device(); reference_D0.sync_device();
reference_D1.sync_device(); reference_D1.sync_device();
@ -205,7 +225,7 @@ struct B2bNonFusedGemmRun
problem_size_0, problem_size_0,
tensor_A0.device_ref(), tensor_A0.device_ref(),
tensor_B0.device_ref(), tensor_B0.device_ref(),
tensor_C0.device_ref(), {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
tensor_D0.device_ref(), tensor_D0.device_ref(),
{alpha0, beta0} {alpha0, beta0}
}; };
@ -214,7 +234,7 @@ struct B2bNonFusedGemmRun
problem_size_1, problem_size_1,
tensor_D0.device_ref(), tensor_D0.device_ref(),
tensor_B1.device_ref(), tensor_B1.device_ref(),
tensor_C1.device_ref(), {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
tensor_D1.device_ref(), tensor_D1.device_ref(),
{alpha1, beta1} {alpha1, beta1}
}; };
@ -241,7 +261,6 @@ struct B2bNonFusedGemmRun
// //
// Run the GEMM // Run the GEMM
// //
cudaEvent_t start, stop1, stop2; cudaEvent_t start, stop1, stop2;
cudaEventCreate(&start); cudaEventCreate(&start);
cudaEventCreate(&stop1); cudaEventCreate(&stop1);
@ -256,7 +275,6 @@ struct B2bNonFusedGemmRun
} }
cudaEventRecord(stop1); cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) { for(int i = 0; i < runs; i++) {
status = gemm_op_1(); status = gemm_op_1();
CUTLASS_CHECK(status); CUTLASS_CHECK(status);
@ -298,7 +316,7 @@ struct B2bNonFusedGemmRun
tensor_A0.device_ref(), tensor_A0.device_ref(),
tensor_B0.device_ref(), tensor_B0.device_ref(),
beta0, beta0,
tensor_C0.device_ref(), {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
reference_D0.device_ref() reference_D0.device_ref()
); );
@ -312,7 +330,7 @@ struct B2bNonFusedGemmRun
reference_D0.device_ref(), reference_D0.device_ref(),
tensor_B1.device_ref(), tensor_B1.device_ref(),
beta1, beta1,
tensor_C1.device_ref(), {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
reference_D1.device_ref() reference_D1.device_ref()
); );
@ -325,7 +343,6 @@ struct B2bNonFusedGemmRun
reference_D0.sync_host(); reference_D0.sync_host();
reference_D1.sync_host(); reference_D1.sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
@ -349,13 +366,14 @@ struct B2bNonFusedGemmRun
<< "A0 =\n" << tensor_A0.host_view() << "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view() << "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view() << "\nC0 =\n" << tensor_C0.host_view()
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
<< "\nD0 =\n" << tensor_D0.host_view() << "\nD0 =\n" << tensor_D0.host_view()
<< "\nB1 =\n" << tensor_B1.host_view() << "\nB1 =\n" << tensor_B1.host_view()
<< "\nC1 =\n" << tensor_C1.host_view() << "\nC1 =\n" << tensor_C1.host_view()
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
<< "\n\nReference =\n" << reference_D1.host_view() << "\n\nReference =\n" << reference_D1.host_view()
<< "\nComputed =\n" << tensor_D1.host_view(); << "\nComputed =\n" << tensor_D1.host_view();
} }
return passed; return passed;
} }
}; };
@ -372,6 +390,8 @@ struct B2bFusedGemmRun
cutlass::Distribution::Kind init_A; cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B; cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C; cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_Scale;
cutlass::Distribution::Kind init_Bias;
uint64_t seed; uint64_t seed;
// //
@ -382,9 +402,12 @@ struct B2bFusedGemmRun
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080 uint64_t seed_ = 2080
): ):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } init_A(init_A_), init_B(init_B_), init_C(init_C_),
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
/// Helper to initialize a tensor view /// Helper to initialize a tensor view
template <typename Element, typename Layout> template <typename Element, typename Layout>
@ -410,9 +433,14 @@ struct B2bFusedGemmRun
cutlass::reference::host::BlockFillSequential( cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity()); view.data(), view.capacity());
} }
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else { else {
// TODO: Implement the rest
std::cerr << "Not implemented\n"; std::cerr << "Not implemented\n";
return false; return false;
} }
@ -451,6 +479,21 @@ struct B2bFusedGemmRun
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.n()});
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
cutlass::HostTensor<
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
@ -463,6 +506,10 @@ struct B2bFusedGemmRun
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
ElementCompute,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
@ -475,21 +522,29 @@ struct B2bFusedGemmRun
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
if(alpha0 == ElementCompute(0)) //per-channel scale
CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014));
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));
cutlass::reference::host::TensorFill( cutlass::reference::host::TensorFill(
tensor_D1.host_view()); tensor_D1.host_view());
cutlass::reference::host::TensorFill( cutlass::reference::host::TensorFill(
reference_D0.host_view()); reference_D0.host_view());
cutlass::reference::host::TensorFill( cutlass::reference::host::TensorFill(
reference_D1.host_view()); reference_D1.host_view());
tensor_A0.sync_device(); tensor_A0.sync_device();
tensor_B0.sync_device(); tensor_B0.sync_device();
tensor_C0.sync_device(); tensor_C0.sync_device();
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.sync_device();
tensor_Bias0.sync_device();
tensor_B1.sync_device(); tensor_B1.sync_device();
tensor_C1.sync_device(); tensor_C1.sync_device();
tensor_Bias1.sync_device();
tensor_D1.sync_device(); tensor_D1.sync_device();
reference_D0.sync_device(); reference_D0.sync_device();
reference_D1.sync_device(); reference_D1.sync_device();
@ -504,8 +559,10 @@ struct B2bFusedGemmRun
tensor_A0.device_ref(), tensor_A0.device_ref(),
tensor_B0.device_ref(), tensor_B0.device_ref(),
tensor_C0.device_ref(), tensor_C0.device_ref(),
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref(),
tensor_B1.device_ref(), tensor_B1.device_ref(),
tensor_C1.device_ref(), {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(), tensor_D1.device_ref(),
{alpha0, beta0}, {alpha0, beta0},
{alpha1, beta1}, {alpha1, beta1},
@ -524,7 +581,6 @@ struct B2bFusedGemmRun
<< " ThreadblockShape1::kN = problem_size_1.N" << std::endl; << " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
} }
status = b2b_gemm_op.initialize(arguments); status = b2b_gemm_op.initialize(arguments);
CUTLASS_CHECK(status); CUTLASS_CHECK(status);
@@ -561,21 +617,42 @@ struct B2bFusedGemmRun
    //
    // Verify
    //
    cutlass::reference::device::Gemm<
        typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
        typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
        ElementAccumulator, typename B2bGemm::LayoutC,
        ElementAccumulator, ElementAccumulator>
        reference_gemm_0;
    cutlass::reference::device::Gemm<
        typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
        typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
        typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
        ElementAccumulator, typename B2bGemm::Operator>
-       reference_gemm_0, reference_gemm_1;
        reference_gemm_1;
    reference_gemm_0(
      problem_size_0,
-     alpha0,
      ElementAccumulator(1), //intermediate alpha=1
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
-     beta0,
      ElementAccumulator(0), //beta = 0
-     tensor_C0.device_ref(),
      reference_Z0.device_ref(),
-     reference_D0.device_ref()
      reference_Z0.device_ref(),
      ElementAccumulator(0)
    );
    cutlass::reference::device::TensorScaleBiasGemm<
      ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
      ElementCompute, typename B2bGemm::LayoutScaleBias
    > (
      problem_size_0,
      reference_Z0.device_ref(),
      reference_D0.device_ref(),
      alpha0,
      tensor_Scale0.device_ref(),
      tensor_Bias0.device_ref()
    );
if(relu) { if(relu) {
@ -588,18 +665,15 @@ struct B2bFusedGemmRun
reference_D0.device_ref(), reference_D0.device_ref(),
tensor_B1.device_ref(), tensor_B1.device_ref(),
beta1, beta1,
tensor_C1.device_ref(), {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
reference_D1.device_ref() reference_D1.device_ref()
); );
if(relu) { if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view()); cutlass::reference::device::TensorReLu(reference_D1.device_view());
} }
cudaDeviceSynchronize(); cudaDeviceSynchronize();
reference_D0.sync_host(); reference_D0.sync_host();
reference_D1.sync_host(); reference_D1.sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
@ -610,7 +684,8 @@ struct B2bFusedGemmRun
tensor_D1.host_view()); tensor_D1.host_view());
CHECK_TRUE(passed); CHECK_TRUE(passed);
if (!passed) { if (!passed)
{
std::stringstream fname; std::stringstream fname;
@ -623,12 +698,14 @@ struct B2bFusedGemmRun
<< "A0 =\n" << tensor_A0.host_view() << "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view() << "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view() << "\nC0 =\n" << tensor_C0.host_view()
<< "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
<< "\nB1 =\n" << tensor_B1.host_view() << "\nB1 =\n" << tensor_B1.host_view()
<< "\nC1 =\n" << tensor_C1.host_view() << "\nC1 =\n" << tensor_C1.host_view()
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
<< "\n\nReference =\n" << reference_D1.host_view() << "\n\nReference =\n" << reference_D1.host_view()
<< "\nComputed =\n" << tensor_D1.host_view(); << "\nComputed =\n" << tensor_D1.host_view();
} }
return passed; return passed;
} }

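A recurring pattern in the test harness above is passing the bias vector as the epilogue's C operand via `{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}`: a tensor reference whose leading (row) stride is zero, so every output row reads the same length-N bias vector and `beta * C` degenerates into a per-row bias add. A small self-contained sketch of the indexing trick, in plain C++ without CUTLASS types:

```cpp
// Hedged sketch (plain C++, no CUTLASS types) of the stride-0 bias trick used
// above: with a row-major leading stride of 0, element (row, col) maps to
// bias[col] for every row, so the epilogue's "+ beta * C" adds the same bias
// vector to each output row.
#include <cstdio>

inline long offset_row_major(long row, long col, long leading_stride) {
  return row * leading_stride + col;   // leading_stride == 0  =>  offset == col
}

int main() {
  const float bias[4] = {0.5f, -1.0f, 2.0f, 0.0f};
  for (long row = 0; row < 3; ++row)
    for (long col = 0; col < 4; ++col)
      std::printf("C(%ld,%ld) = %f\n", row, col,
                  bias[offset_row_major(row, col, /*leading_stride=*/0)]);
  return 0;
}
```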

@@ -55,6 +55,7 @@
#include "cutlass/core_io.h"
#include "cutlass/util/tensor_view_io.h"
#include "reference/device/tensor_scale_bias.h"
#include "helper.h"
#define CHECK_GT(val1, val2) \
@@ -91,14 +92,14 @@ public:
  cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0;
  cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0_reordered;
  cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_C0;
- cutlass::HostTensor<typename Conv2d0::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias0;
  cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_Bias0;
  cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_computed;
  cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_reference;
  cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1;
  cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1_reordered;
  cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_C1;
- cutlass::HostTensor<typename Conv2d1::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias1;
  cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d0::LayoutC> tensor_Bias1;
  cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_computed;
  cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_reference;
@@ -379,11 +380,13 @@ public:
      << "\nB0:\n" << tensor_B0.host_view() << "\n"
      << "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n"
      << "\nC0:\n" << tensor_C0.host_view() << "\n"
      << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
      << "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n"
      << "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n"
      << "\nB1:\n" << tensor_B1.host_view() << "\n"
      << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n"
      << "\nC1:\n" << tensor_C1.host_view() << "\n"
      << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
      << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
      << "\nD1 computed:\n" << tensor_D1_computed.host_view();
@@ -421,12 +424,13 @@ public:
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
  cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
  cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
  cutlass::HostTensor<ElementAccumulator, typename B2bConv2d::LayoutC> tensor_Z0_reference;
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;
  cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
  cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1_reordered;
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C1;
- cutlass::HostTensor<typename B2bConv2d::ElementCompute, typename B2bConv2d::LayoutC> tensor_Bias1;
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_Bias1;
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_computed;
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_reference;
@@ -503,6 +507,7 @@ public:
    if(alpha0 == ElementCompute(0)) //per-channel scale
      tensor_Scale0.resize({1, problem_size_0.K});
    tensor_Bias0.resize({1, problem_size_0.K});
    tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
    tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
    tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
    tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
@@ -632,23 +637,36 @@ public:
      typename B2bConv2d::LayoutA,
      typename B2bConv2d::ElementB,
      typename B2bConv2d::LayoutB,
-     typename B2bConv2d::ElementC,
-     typename B2bConv2d::LayoutC,
-     ElementCompute,
      ElementAccumulator,
-     cutlass::NumericConverterClamp<typename B2bConv2d::ElementC, ElementCompute>
      typename B2bConv2d::LayoutC,
      ElementAccumulator,
      ElementAccumulator
    >(
      kConvolutionalOperator,
      problem_size_0,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
-     tensor_C0.device_ref(),
      tensor_Z0_reference.device_ref(),
      tensor_Z0_reference.device_ref(),
      ElementAccumulator(1), // intermediate alpha = 1
      ElementAccumulator(0)  // beta = 0
    );
    cutlass::reference::device::TensorScaleBiasConv2d<
      ElementAccumulator,
      typename B2bConv2d::ElementC,
      typename B2bConv2d::LayoutC,
      ElementCompute,
      typename B2bConv2d::LayoutScaleBias,
      cutlass::NumericConverterClamp<typename B2bConv2d::ElementC, ElementCompute>
    >(
      problem_size_0,
      tensor_Z0_reference.device_ref(),
      tensor_D0_reference.device_ref(),
      alpha0,
-     beta0,
-     nullptr, // stream
      tensor_Scale0.device_ref(),
-     tensor_Bias0.device_ref());
      tensor_Bias0.device_ref()
    );
    if(relu) {
      cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
@@ -716,6 +734,7 @@ public:
      << "\nB1:\n" << tensor_B1.host_view() << "\n"
      << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n"
      << "\nC1:\n" << tensor_C1.host_view() << "\n"
      << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
      << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
      << "\nD1 computed:\n" << tensor_D1_computed.host_view();

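The int8 interleaved path above threads `cutlass::NumericConverterClamp<typename B2bConv2d::ElementC, ElementCompute>` through the reference scale/bias step so that the float epilogue result is saturated into the narrow output type rather than wrapping. Below is a hedged illustration of that saturating behavior in plain C++; it is only a sketch of the semantics, not the library implementation.

```cpp
// Hedged sketch of a clamping float -> int8 conversion, illustrating the
// saturating behavior the s8 verification path relies on. This is NOT
// cutlass::NumericConverterClamp itself, just the idea: round, then clamp to
// the representable range of the destination type instead of wrapping.
#include <algorithm>
#include <cmath>
#include <cstdint>

inline int8_t clamp_to_s8(float x) {
  float r = std::nearbyint(x);                    // round to nearest
  r = std::min(127.0f, std::max(-128.0f, r));     // saturate to [-128, 127]
  return static_cast<int8_t>(r);
}
// Example: clamp_to_s8(300.4f) == 127, clamp_to_s8(-1000.0f) == -128.
```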

@ -28,7 +28,6 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* *
**************************************************************************************************/ **************************************************************************************************/
#pragma once #pragma once
#include <iostream> #include <iostream>
@ -46,6 +45,7 @@
#include "cutlass/util/reference/device/gemm.h" #include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h" #include "cutlass/util/reference/device/tensor_relu.h"
#include "reference/device/tensor_scale_bias.h"
#include "helper.h" #include "helper.h"
#define CHECK_GT(val1, val2) \ #define CHECK_GT(val1, val2) \
@ -68,6 +68,7 @@ struct B2bInterleavedNonFusedGemmRun
cutlass::Distribution::Kind init_A; cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B; cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C; cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_Bias;
uint64_t seed; uint64_t seed;
// //
@ -78,9 +79,10 @@ struct B2bInterleavedNonFusedGemmRun
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080 uint64_t seed_ = 2080
): ):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }
/// Helper to initialize a tensor view /// Helper to initialize a tensor view
template <typename Element, typename Layout> template <typename Element, typename Layout>
@ -97,14 +99,23 @@ struct B2bInterleavedNonFusedGemmRun
else if (dist_kind == cutlass::Distribution::Identity) { else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view); cutlass::reference::host::TensorFillIdentity(view);
} }
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) { else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential( cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity()); view.data(), view.capacity());
} }
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else { else {
// TODO: Implement the rest
std::cerr << "Not implemented\n"; std::cerr << "Not implemented\n";
return false; return false;
} }
@ -147,6 +158,10 @@ struct B2bInterleavedNonFusedGemmRun
typename Gemm0::ElementC, typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
cutlass::HostTensor< cutlass::HostTensor<
typename Gemm0::ElementC, typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
@ -167,6 +182,10 @@ struct B2bInterleavedNonFusedGemmRun
typename Gemm1::ElementC, typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
cutlass::HostTensor< cutlass::HostTensor<
typename Gemm1::ElementC, typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
@ -179,8 +198,10 @@ struct B2bInterleavedNonFusedGemmRun
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));
//Reorder B0 and B1 //Reorder B0 and B1
cutlass::reorder_column<InterleavedK_>( cutlass::reorder_column<InterleavedK_>(
@ -201,10 +222,12 @@ struct B2bInterleavedNonFusedGemmRun
tensor_B0.sync_device(); tensor_B0.sync_device();
tensor_B0_reordered.sync_device(); tensor_B0_reordered.sync_device();
tensor_C0.sync_device(); tensor_C0.sync_device();
tensor_Bias0.sync_device();
tensor_D0.sync_device(); tensor_D0.sync_device();
tensor_B1.sync_device(); tensor_B1.sync_device();
tensor_B1_reordered.sync_device(); tensor_B1_reordered.sync_device();
tensor_C1.sync_device(); tensor_C1.sync_device();
tensor_Bias1.sync_device();
tensor_D1.sync_device(); tensor_D1.sync_device();
reference_D0.sync_device(); reference_D0.sync_device();
reference_D1.sync_device(); reference_D1.sync_device();
@ -217,7 +240,7 @@ struct B2bInterleavedNonFusedGemmRun
problem_size_0, problem_size_0,
tensor_A0.device_ref(), tensor_A0.device_ref(),
tensor_B0_reordered.device_ref(), tensor_B0_reordered.device_ref(),
tensor_C0.device_ref(), {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
tensor_D0.device_ref(), tensor_D0.device_ref(),
{alpha0, beta0} {alpha0, beta0}
}; };
@ -226,7 +249,7 @@ struct B2bInterleavedNonFusedGemmRun
problem_size_1, problem_size_1,
tensor_D0.device_ref(), tensor_D0.device_ref(),
tensor_B1_reordered.device_ref(), tensor_B1_reordered.device_ref(),
tensor_C1.device_ref(), {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
tensor_D1.device_ref(), tensor_D1.device_ref(),
{alpha1, beta1} {alpha1, beta1}
}; };
@ -265,8 +288,7 @@ struct B2bInterleavedNonFusedGemmRun
CUTLASS_CHECK(status); CUTLASS_CHECK(status);
} }
cudaEventRecord(stop1); cudaEventRecord(stop1);
for(int i = 0; i < runs; i++) { for(int i = 0; i < runs; i++) {
status = gemm_op_1(); status = gemm_op_1();
@ -286,7 +308,6 @@ struct B2bInterleavedNonFusedGemmRun
tensor_D0.sync_host(); tensor_D0.sync_host();
tensor_D1.sync_host(); tensor_D1.sync_host();
bool passed = false;
// //
// Verify // Verify
// //
@ -310,7 +331,7 @@ struct B2bInterleavedNonFusedGemmRun
tensor_A0.device_ref(), tensor_A0.device_ref(),
tensor_B0.device_ref(), tensor_B0.device_ref(),
beta0, beta0,
tensor_C0.device_ref(), {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
reference_D0.device_ref() reference_D0.device_ref()
); );
@ -323,8 +344,8 @@ struct B2bInterleavedNonFusedGemmRun
alpha1, alpha1,
reference_D0.device_ref(), reference_D0.device_ref(),
tensor_B1.device_ref(), tensor_B1.device_ref(),
beta1, beta1,
tensor_C1.device_ref(), {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
reference_D1.device_ref() reference_D1.device_ref()
); );
@ -332,6 +353,7 @@ struct B2bInterleavedNonFusedGemmRun
cutlass::reference::device::TensorReLu(reference_D1.device_view()); cutlass::reference::device::TensorReLu(reference_D1.device_view());
} }
// Wait for kernels to finish
cudaDeviceSynchronize(); cudaDeviceSynchronize();
reference_D0.sync_host(); reference_D0.sync_host();
reference_D1.sync_host(); reference_D1.sync_host();
@ -341,7 +363,7 @@ struct B2bInterleavedNonFusedGemmRun
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
passed = cutlass::reference::host::TensorEquals( bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(), reference_D1.host_view(),
tensor_D1.host_view()); tensor_D1.host_view());
@ -360,10 +382,12 @@ struct B2bInterleavedNonFusedGemmRun
<< "\nB0 =\n" << tensor_B0.host_view() << "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view() << "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
<< "\nC0 =\n" << tensor_C0.host_view() << "\nC0 =\n" << tensor_C0.host_view()
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
<< "\nD0 =\n" << tensor_D0.host_view() << "\nD0 =\n" << tensor_D0.host_view()
<< "\nB1 =\n" << tensor_B1.host_view() << "\nB1 =\n" << tensor_B1.host_view()
<< "\nB1_reordered =\n" << tensor_B1_reordered.host_view() << "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
<< "\nC1 =\n" << tensor_C1.host_view() << "\nC1 =\n" << tensor_C1.host_view()
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
<< "\n\nReference =\n" << reference_D1.host_view() << "\n\nReference =\n" << reference_D1.host_view()
<< "\nComputed =\n" << tensor_D1.host_view(); << "\nComputed =\n" << tensor_D1.host_view();
} }
@ -383,6 +407,8 @@ struct B2bInterleavedFusedGemmRun
cutlass::Distribution::Kind init_A; cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B; cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C; cutlass::Distribution::Kind init_C;
cutlass::Distribution::Kind init_Scale;
cutlass::Distribution::Kind init_Bias;
uint64_t seed; uint64_t seed;
// //
@ -393,9 +419,12 @@ struct B2bInterleavedFusedGemmRun
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080 uint64_t seed_ = 2080
): ):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } init_A(init_A_), init_B(init_B_), init_C(init_C_),
init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
/// Helper to initialize a tensor view /// Helper to initialize a tensor view
template <typename Element, typename Layout> template <typename Element, typename Layout>
@ -413,13 +442,22 @@ struct B2bInterleavedFusedGemmRun
cutlass::reference::host::TensorFillIdentity(view); cutlass::reference::host::TensorFillIdentity(view);
} }
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) { else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential( cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity()); view.data(), view.capacity());
} }
else if (dist_kind == cutlass::Distribution::AllZeros) {
cutlass::reference::host::TensorFill(view, Element(0));
}
else if (dist_kind == cutlass::Distribution::AllOnes) {
cutlass::reference::host::TensorFill(view, Element(1));
}
else { else {
// TODO: Implement the rest
std::cerr << "Not implemented\n"; std::cerr << "Not implemented\n";
return false; return false;
} }
@ -437,7 +475,7 @@ struct B2bInterleavedFusedGemmRun
ElementCompute alpha0 = ElementCompute(1), ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0), ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1), ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0), ElementCompute beta1 = ElementCompute(0),
bool relu = true, bool relu = true,
int warm_ups = 1, int warm_ups = 1,
int runs = 100) { int runs = 100) {
@ -462,6 +500,21 @@ struct B2bInterleavedFusedGemmRun
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.resize({1, problem_size_0.n()});
cutlass::HostTensor<
typename B2bGemm::ElementScaleBias,
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
cutlass::HostTensor<
ElementAccumulator,
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
@ -478,6 +531,10 @@ struct B2bInterleavedFusedGemmRun
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
cutlass::HostTensor< cutlass::HostTensor<
typename B2bGemm::ElementC, typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
@ -490,8 +547,12 @@ struct B2bInterleavedFusedGemmRun
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
if(alpha0 == ElementCompute(0)) //per-channel scale
CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014));
CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));
//Reorder B0 //Reorder B0
cutlass::reorder_column<16>( cutlass::reorder_column<16>(
@ -510,9 +571,13 @@ struct B2bInterleavedFusedGemmRun
tensor_B0.sync_device(); tensor_B0.sync_device();
tensor_B0_reordered.sync_device(); tensor_B0_reordered.sync_device();
tensor_C0.sync_device(); tensor_C0.sync_device();
if(alpha0 == ElementCompute(0)) //per-channel scale
tensor_Scale0.sync_device();
tensor_Bias0.sync_device();
tensor_B1.sync_device(); tensor_B1.sync_device();
tensor_B1_reordered.sync_device(); tensor_B1_reordered.sync_device();
tensor_C1.sync_device(); tensor_C1.sync_device();
tensor_Bias1.sync_device();
tensor_D1.sync_device(); tensor_D1.sync_device();
reference_D0.sync_device(); reference_D0.sync_device();
reference_D1.sync_device(); reference_D1.sync_device();
@ -527,12 +592,13 @@ struct B2bInterleavedFusedGemmRun
tensor_A0.device_ref(), tensor_A0.device_ref(),
tensor_B0_reordered.device_ref(), tensor_B0_reordered.device_ref(),
tensor_C0.device_ref(), tensor_C0.device_ref(),
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref(),
tensor_B1_reordered.device_ref(), tensor_B1_reordered.device_ref(),
tensor_C1.device_ref(), {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
tensor_D1.device_ref(), tensor_D1.device_ref(),
{alpha0, beta0}, {alpha0, beta0},
{alpha1, beta1}, {alpha1, beta1},
1, /*threadblock_swizzle_k_tile*/
}; };
B2bGemm b2b_gemm_op; B2bGemm b2b_gemm_op;
@ -581,25 +647,45 @@ struct B2bInterleavedFusedGemmRun
tensor_D1.sync_host(); tensor_D1.sync_host();
bool passed = false;
// //
// Verify // Verify
// //
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
ElementAccumulator, typename B2bGemm::LayoutC,
ElementAccumulator, ElementAccumulator>
reference_gemm_0;
cutlass::reference::device::Gemm< cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA, typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB, typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator, typename B2bGemm::Operator> ElementAccumulator, typename B2bGemm::Operator>
reference_gemm_0, reference_gemm_1; reference_gemm_1;
reference_gemm_0( reference_gemm_0(
problem_size_0, problem_size_0,
alpha0, ElementAccumulator(1), //intermediate alpha=1
tensor_A0.device_ref(), tensor_A0.device_ref(),
tensor_B0.device_ref(), tensor_B0.device_ref(),
beta0, ElementAccumulator(0), //beta = 0
tensor_C0.device_ref(), reference_Z0.device_ref(),
reference_D0.device_ref() reference_Z0.device_ref(),
ElementAccumulator(0)
);
cutlass::reference::device::TensorScaleBiasGemm<
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
ElementCompute, typename B2bGemm::LayoutScaleBias
> (
problem_size_0,
reference_Z0.device_ref(),
reference_D0.device_ref(),
alpha0,
tensor_Scale0.device_ref(),
tensor_Bias0.device_ref()
); );
if(relu) { if(relu) {
@ -612,29 +698,27 @@ struct B2bInterleavedFusedGemmRun
reference_D0.device_ref(), reference_D0.device_ref(),
tensor_B1.device_ref(), tensor_B1.device_ref(),
beta1, beta1,
tensor_C1.device_ref(), {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
reference_D1.device_ref() reference_D1.device_ref()
); );
if(relu) { if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view()); cutlass::reference::device::TensorReLu(reference_D1.device_view());
} }
cudaDeviceSynchronize(); cudaDeviceSynchronize();
reference_D0.sync_host(); reference_D0.sync_host();
reference_D1.sync_host(); reference_D1.sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
passed = cutlass::reference::host::TensorEquals( bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(), reference_D1.host_view(),
tensor_D1.host_view()); tensor_D1.host_view());
CHECK_TRUE(passed); CHECK_TRUE(passed);
if (!passed) { if (!passed)
{
std::stringstream fname; std::stringstream fname;
@ -648,9 +732,12 @@ struct B2bInterleavedFusedGemmRun
<< "\nB0 =\n" << tensor_B0.host_view() << "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view() << "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
<< "\nC0 =\n" << tensor_C0.host_view() << "\nC0 =\n" << tensor_C0.host_view()
<< "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
<< "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
<< "\nB1 =\n" << tensor_B1.host_view() << "\nB1 =\n" << tensor_B1.host_view()
<< "\nB1_reordered =\n" << tensor_B1_reordered.host_view() << "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
<< "\nC1 =\n" << tensor_C1.host_view() << "\nC1 =\n" << tensor_C1.host_view()
<< "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
<< "\n\nReference =\n" << reference_D1.host_view() << "\n\nReference =\n" << reference_D1.host_view()
<< "\nComputed =\n" << tensor_D1.host_view(); << "\nComputed =\n" << tensor_D1.host_view();
} }


@@ -158,6 +158,10 @@ class B2bGemm {
 static ComplexTransform const kTransformA = ComplexTransform::kNone;
 static ComplexTransform const kTransformB = ComplexTransform::kNone;
+/// Derived types
+using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
+using LayoutScaleBias = layout::RowMajor;
 /// Define the kernel
 using B2bGemmKernel = typename kernel::DefaultB2bGemm<
 ElementA,
@@ -197,6 +201,8 @@ class B2bGemm {
 TensorRef<ElementA const, LayoutA> ref_A0;
 TensorRef<ElementB const, LayoutB> ref_B0;
 TensorRef<ElementC const, LayoutC> ref_C0;
+TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0;
+TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0;
 TensorRef<ElementB const, LayoutB> ref_B1;
 TensorRef<ElementC const, LayoutC> ref_C1;
 TensorRef<ElementC, LayoutC> ref_D1;
@@ -222,6 +228,8 @@ class B2bGemm {
 TensorRef<ElementA const, LayoutA> ref_A0_,
 TensorRef<ElementB const, LayoutB> ref_B0_,
 TensorRef<ElementC const, LayoutC> ref_C0_,
+TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0_,
+TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0_,
 TensorRef<ElementB const, LayoutB> ref_B1_,
 TensorRef<ElementC const, LayoutC> ref_C1_,
 TensorRef<ElementC, LayoutC> ref_D1_,
@@ -236,6 +244,8 @@ class B2bGemm {
 ref_A0(ref_A0_),
 ref_B0(ref_B0_),
 ref_C0(ref_C0_),
+ref_Scale0(ref_Scale0_),
+ref_Bias0(ref_Bias0_),
 ref_B1(ref_B1_),
 ref_C1(ref_C1_),
 ref_D1(ref_D1_),
@@ -348,6 +358,8 @@ public:
 args.ref_A0.non_const_ref(),
 args.ref_B0.non_const_ref(),
 args.ref_C0.non_const_ref(),
+args.ref_Scale0.non_const_ref(),
+args.ref_Bias0.non_const_ref(),
 args.ref_B1.non_const_ref(),
 args.ref_C1.non_const_ref(),
 args.ref_D1,
@@ -368,12 +380,14 @@ public:
 }
 }
-params_.ref_A0.reset(args.ref_A.non_const_ref().data());
-params_.ref_B0.reset(args.ref_B.non_const_ref().data());
-params_.ref_C0.reset(args.ref_C.non_const_ref().data());
-params_.ref_B1.reset(args.ref_B.non_const_ref().data());
-params_.ref_C1.reset(args.ref_C.non_const_ref().data());
-params_.ref_D1.reset(args.ref_D.data());
+params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
+params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
+params_.ref_C0.reset(args.ref_C0.non_const_ref().data());
+params_.ref_Scale0.reset(args.ref_Scale0.non_const_ref().data());
+params_.ref_Bias0.reset(args.ref_Bias0.non_const_ref().data());
+params_.ref_B1.reset(args.ref_B1.non_const_ref().data());
+params_.ref_C1.reset(args.ref_C1.non_const_ref().data());
+params_.ref_D1.reset(args.ref_D1.data());
 params_.output_op_0 = args.epilogue0;
 params_.output_op_1 = args.epilogue1;
 params_.semaphore = static_cast<int *>(workspace);
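For context, the new `ref_Scale0` / `ref_Bias0` fields expect row vectors of extent {1, N0} in the `LayoutScaleBias` (row-major) layout defined above. The sketch below shows one way to prepare such vectors on the host with the CUTLASS utility headers; it is illustrative only and not part of this commit, and the extent `n0` and the `half_t` element type are assumptions.

```c++
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"

void prepare_scale_bias_vectors() {
  // Assumed N dimension of the first GEMM.
  int const n0 = 576;

  // Per-channel scale and bias vectors of extent {1, n0}, row-major to match LayoutScaleBias.
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> tensor_Scale0({1, n0});
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> tensor_Bias0({1, n0});

  // Fill with neutral values (scale = 1, bias = 0) and copy to the device.
  cutlass::reference::host::TensorFill(tensor_Scale0.host_view(), cutlass::half_t(1));
  cutlass::reference::host::TensorFill(tensor_Bias0.host_view(), cutlass::half_t(0));
  tensor_Scale0.sync_device();
  tensor_Bias0.sync_device();

  // tensor_Scale0.device_ref() / tensor_Bias0.device_ref() are the values the new
  // ref_Scale0 / ref_Bias0 Arguments fields take.
}
```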
View File
@@ -68,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
 using ElementCompute = cutlass::half_t;
 ElementCompute alpha0 = ElementCompute(1);
-ElementCompute beta0 = ElementCompute(0);
+ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(1); //use beta for bias
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
-using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
-using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
+using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
+using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
+using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
 using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
@@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
 128 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2,
@@ -151,14 +151,15 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75_rf_res() {
 using ElementCompute = cutlass::half_t;
 ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
 ElementCompute beta0 = ElementCompute(0);
 ElementCompute alpha1 = ElementCompute(1);
 ElementCompute beta1 = ElementCompute(1); //use beta for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
+using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
 using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
-using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
+using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
 using EpilogueOutputOp0 =
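The epilogue changes above are what carry the bias in the non-fused baseline: with `ScaleType::NoBetaScaling` the source operand is added to the scaled accumulator without being multiplied by beta, so the bias tensor can simply be bound as the epilogue source, while the previously used `OnlyAlphaScaling` mode does not consume the source at all. A simplified scalar sketch of the distinction (illustrative only; the real operator is `cutlass::epilogue::thread::LinearCombination`, which works on vectorized fragments with numeric conversion):

```c++
// Simplified scalar view of the two epilogue modes touched by this change.
// Illustrative only; not the actual CUTLASS implementation.
float epilogue_default(float alpha, float beta, float accum, float source) {
  return alpha * accum + beta * source;  // ScaleType::Default
}

float epilogue_no_beta_scaling(float alpha, float accum, float source) {
  return alpha * accum + source;         // ScaleType::NoBetaScaling: source (the bias) added unscaled
}
```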
View File
@@ -68,13 +68,13 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
 using ElementCompute = cutlass::half_t;
 ElementCompute alpha0 = ElementCompute(1);
-ElementCompute beta0 = ElementCompute(0);
+ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
-using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
-using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
+using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
+using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
+using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
 using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
@@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
 128 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2,
@@ -118,7 +118,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
 128 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2,
@@ -151,9 +151,10 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75_shmem() {
 using ElementCompute = cutlass::half_t;
 ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
 ElementCompute beta0 = ElementCompute(0);
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@@ -176,7 +177,7 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75_shmem() {
 128 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
View File
@@ -69,14 +69,14 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
 using ElementCompute = cutlass::half_t;
 ElementCompute alpha0 = ElementCompute(1);
-ElementCompute beta0 = ElementCompute(0);
+ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
 using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
-using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
+using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
 using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
@@ -94,7 +94,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
 128 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 3,
@@ -118,7 +118,8 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
 ElementC,
 128 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 3,
@@ -150,9 +151,10 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80_rf_res() {
 using ElementCompute = cutlass::half_t;
 ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
 ElementCompute beta0 = ElementCompute(0);
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
@@ -174,7 +176,8 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80_rf_res() {
 ElementC,
 128 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
 using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
View File
@@ -69,13 +69,13 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
 using ElementCompute = cutlass::half_t;
 ElementCompute alpha0 = ElementCompute(1);
-ElementCompute beta0 = ElementCompute(0);
+ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
-using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
+using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
 using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@@ -94,7 +94,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
 128 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 3,
@@ -118,7 +118,8 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
 ElementC,
 128 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 3,
@@ -151,9 +152,10 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80_shmem() {
 using ElementCompute = cutlass::half_t;
 ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
 ElementCompute beta0 = ElementCompute(0);
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@@ -175,7 +177,8 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80_shmem() {
 ElementC,
 128 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
 const bool SmemAccumulator = true;
View File
@@ -68,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
 using ElementCompute = float;
 ElementCompute alpha0 = ElementCompute(1);
-ElementCompute beta0 = ElementCompute(0);
+ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
-using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
-using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
-using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
+using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
+using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
+using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
+using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
 using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
 using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
@@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
 64 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2,
@@ -117,7 +117,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
 ElementC,
 64 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2,
@@ -151,14 +152,15 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {
 using ElementCompute = float;
 ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
 ElementCompute beta0 = ElementCompute(0);
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
+using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
 using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
-using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
+using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
 using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
 using EpilogueOutputOp0 =
@@ -175,7 +177,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {
 ElementC,
 64 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
View File
@@ -68,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
 using ElementCompute = float;
 ElementCompute alpha0 = ElementCompute(1);
-ElementCompute beta0 = ElementCompute(0);
+ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
-using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
-using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
-using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
+using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
+using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
+using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
+using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
 using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
 using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
@@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
 64 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2,
@@ -117,7 +117,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
 ElementC,
 64 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2,
@@ -150,9 +151,10 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() {
 using ElementCompute = float;
 ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
 ElementCompute beta0 = ElementCompute(0);
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@@ -174,7 +176,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() {
 ElementC,
 64 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
View File
@@ -68,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
 using ElementCompute = float;
 ElementCompute alpha0 = ElementCompute(1);
-ElementCompute beta0 = ElementCompute(0);
+ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
-using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
-using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
-using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
+using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
+using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
+using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
+using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
 using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
@@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
 64 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 3,
@@ -117,7 +117,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
 ElementC,
 64 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 3,
@@ -151,14 +152,15 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80_rf_res() {
 using ElementCompute = float;
 ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
 ElementCompute beta0 = ElementCompute(0);
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
+using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
 using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
-using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
+using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
 using EpilogueOutputOp0 =
@@ -175,7 +177,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80_rf_res() {
 ElementC,
 64 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
View File
@@ -68,13 +68,13 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
 using ElementCompute = float;
 ElementCompute alpha0 = ElementCompute(1);
-ElementCompute beta0 = ElementCompute(0);
+ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
-using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
-using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
+using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
+using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
+using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
 using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
@@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
 64 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 3,
@@ -117,7 +117,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
 ElementC,
 64 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 3,
@@ -150,9 +151,10 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80_shmem() {
 using ElementCompute = float;
 ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
 ElementCompute beta0 = ElementCompute(0);
 ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
@@ -174,7 +176,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80_shmem() {
 ElementC,
 64 / cutlass::sizeof_bits<ElementC>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
 const bool SmemAccumulator = true;
View File
@@ -55,10 +55,10 @@ bool run_nonfused_gemm_f16() {
 using ElementAccumulator = cutlass::half_t;
 using ElementCompute = cutlass::half_t;
-ElementCompute alpha0 = ElementCompute(2);
-ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute alpha0 = ElementCompute(1);
+ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias
+ElementCompute alpha1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@@ -84,7 +84,7 @@ bool run_nonfused_gemm_f16() {
 128 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2
@@ -106,7 +106,8 @@ bool run_nonfused_gemm_f16() {
 ElementOutput,
 128 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2
@@ -131,10 +132,11 @@ bool run_fused_gemm_f16_rf_res() {
 using ElementAccumulator = cutlass::half_t;
 using ElementCompute = cutlass::half_t;
-ElementCompute alpha0 = ElementCompute(2);
+ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
 ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
+ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
@@ -156,7 +158,8 @@ bool run_fused_gemm_f16_rf_res() {
 ElementOutput,
 128 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
 using B2bGemm = cutlass::gemm::device::B2bGemm<
View File
@@ -55,14 +55,14 @@ bool run_nonfused_gemm_f16() {
 using ElementAccumulator = cutlass::half_t;
 using ElementCompute = cutlass::half_t;
-ElementCompute alpha0 = ElementCompute(2);
-ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute alpha0 = ElementCompute(1);
+ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias
+ElementCompute alpha1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
-using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
+using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
 using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
@@ -84,7 +84,7 @@ bool run_nonfused_gemm_f16() {
 128 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2
@@ -106,7 +106,8 @@ bool run_nonfused_gemm_f16() {
 ElementOutput,
 128 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2
@@ -130,10 +131,11 @@ bool run_fused_gemm_f16_shmem() {
 using ElementAccumulator = cutlass::half_t;
 using ElementCompute = cutlass::half_t;
-ElementCompute alpha0 = ElementCompute(2);
+ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
 ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
+ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@@ -155,7 +157,8 @@ bool run_fused_gemm_f16_shmem() {
 ElementOutput,
 128 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
View File
@@ -55,15 +55,15 @@ bool run_nonfused_gemm_f16_sm80() {
 using ElementAccumulator = cutlass::half_t;
 using ElementCompute = cutlass::half_t;
-ElementCompute alpha0 = ElementCompute(2);
-ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute alpha0 = ElementCompute(1);
+ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
+ElementCompute alpha1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
-using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
-using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
-using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
+using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
+using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
+using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
+using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
 using Gemm0 = cutlass::gemm::device::Gemm<
@@ -84,7 +84,7 @@ bool run_nonfused_gemm_f16_sm80() {
 128 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 3
@@ -106,7 +106,8 @@ bool run_nonfused_gemm_f16_sm80() {
 ElementOutput,
 128 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 3
@@ -130,15 +131,16 @@ bool run_fused_gemm_f16_sm80_rf_res() {
 using ElementAccumulator = cutlass::half_t;
 using ElementCompute = cutlass::half_t;
-ElementCompute alpha0 = ElementCompute(2);
-ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
+ElementCompute beta0 = ElementCompute(0);
+ElementCompute alpha1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
-using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
+using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
+using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
 using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
-using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
+using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
 using EpilogueOutputOp0 =
@@ -155,11 +157,10 @@ bool run_fused_gemm_f16_sm80_rf_res() {
 ElementOutput,
 128 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
 using B2bGemm = cutlass::gemm::device::B2bGemm<
 cutlass::half_t,
 cutlass::layout::RowMajor,
View File
@@ -55,10 +55,10 @@ bool run_nonfused_gemm_f16_sm80() {
 using ElementAccumulator = cutlass::half_t;
 using ElementCompute = cutlass::half_t;
-ElementCompute alpha0 = ElementCompute(2);
-ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute alpha0 = ElementCompute(1);
+ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
+ElementCompute alpha1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@@ -84,7 +84,7 @@ bool run_nonfused_gemm_f16_sm80() {
 128 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 3
@@ -106,7 +106,8 @@ bool run_nonfused_gemm_f16_sm80() {
 ElementOutput,
 128 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 3
@@ -130,10 +131,11 @@ bool run_fused_gemm_f16_sm80_shmem() {
 using ElementAccumulator = cutlass::half_t;
 using ElementCompute = cutlass::half_t;
-ElementCompute alpha0 = ElementCompute(2);
-ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
+ElementCompute beta0 = ElementCompute(0);
+ElementCompute alpha1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@@ -155,7 +157,8 @@ bool run_fused_gemm_f16_sm80_shmem() {
 ElementOutput,
 128 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
View File
@@ -55,10 +55,10 @@ bool run_nonfused_gemm_s8() {
 using ElementAccumulator = int32_t;
 using ElementCompute = float;
-ElementCompute alpha0 = ElementCompute(2);
-ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute alpha0 = ElementCompute(1);
+ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias
+ElementCompute alpha1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
@@ -84,7 +84,7 @@ bool run_nonfused_gemm_s8() {
 64 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2
@@ -106,7 +106,8 @@ bool run_nonfused_gemm_s8() {
 ElementOutput,
 64 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2
@@ -131,10 +132,11 @@ bool run_fused_gemm_s8_rf_res() {
 using ElementAccumulator = int32_t;
 using ElementCompute = float;
-ElementCompute alpha0 = ElementCompute(2);
+ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
 ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
+ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
@@ -156,7 +158,8 @@ bool run_fused_gemm_s8_rf_res() {
 ElementOutput,
 64 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
 using B2bGemm = cutlass::gemm::device::B2bGemm<
@@ -200,7 +203,7 @@ int main() {
 &run_fused_gemm_s8_rf_res
 };
-return testRun(75, funcs, "gemm f16 RF residency");
+return testRun(75, funcs, "gemm int8 RF residency");
 }
View File
@@ -55,15 +55,15 @@ bool run_nonfused_gemm_s8() {
 using ElementAccumulator = int32_t;
 using ElementCompute = float;
-ElementCompute alpha0 = ElementCompute(2);
-ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute alpha0 = ElementCompute(1);
+ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias
+ElementCompute alpha1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias
-using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
-using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
-using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
+using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
+using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
+using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
+using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
 using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
 using Gemm0 = cutlass::gemm::device::Gemm<
@@ -84,7 +84,7 @@ bool run_nonfused_gemm_s8() {
 64 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2
@@ -106,7 +106,8 @@ bool run_nonfused_gemm_s8() {
 ElementOutput,
 64 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
 2
@@ -130,10 +131,11 @@ bool run_fused_gemm_s8_shmem() {
 using ElementAccumulator = int32_t;
 using ElementCompute = float;
-ElementCompute alpha0 = ElementCompute(2);
+ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
 ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
+ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@@ -155,7 +157,8 @@ bool run_fused_gemm_s8_shmem() {
 ElementOutput,
 64 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
-ElementCompute
+ElementCompute,
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
 const bool SmemAccumulator = true;
@@ -202,7 +205,7 @@ int main() {
 &run_fused_gemm_s8_shmem
 };
-return testRun(75, funcs, "gemm s8 shmem staing");
+return testRun(75, funcs, "gemm int8 shmem staing");
 }
View File
@@ -55,15 +55,15 @@ bool run_nonfused_gemm_s8_sm80() {
 using ElementAccumulator = int32_t;
 using ElementCompute = float;
-ElementCompute alpha0 = ElementCompute(2);
-ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute alpha0 = ElementCompute(1);
+ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
+ElementCompute alpha1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
-using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
-using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
-using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
+using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
+using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
+using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
+using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
 using Gemm0 = cutlass::gemm::device::Gemm<
@@ -84,7 +84,7 @@ bool run_nonfused_gemm_s8_sm80() {
 64 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
 3,
@@ -111,7 +111,7 @@ bool run_nonfused_gemm_s8_sm80() {
 64 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
 3,
@@ -140,15 +140,16 @@ bool run_fused_gemm_s8_sm80_rf_res() {
 using ElementAccumulator = int32_t;
 using ElementCompute = float;
-ElementCompute alpha0 = ElementCompute(2);
+ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
 ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
+ElementCompute alpha1 = ElementCompute(1);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
+using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
 using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
-using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
+using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
 using EpilogueOutputOp0 =
@@ -166,7 +167,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
 64 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
 const bool SmemAccumulator = false;
View File
@@ -55,14 +55,14 @@ bool run_nonfused_gemm_s8_sm80() {
 using ElementAccumulator = int32_t;
 using ElementCompute = float;
-ElementCompute alpha0 = ElementCompute(2);
-ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute alpha0 = ElementCompute(1);
+ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
+ElementCompute alpha1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
-using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
-using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
-using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
+using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
+using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
+using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
 using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
@@ -84,7 +84,7 @@ bool run_nonfused_gemm_s8_sm80() {
 64 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
 3,
@@ -111,7 +111,7 @@ bool run_nonfused_gemm_s8_sm80() {
 64 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >,
 cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
 3,
@@ -139,10 +139,11 @@ bool run_fused_gemm_s8_sm80_shmem() {
 using ElementAccumulator = int32_t;
 using ElementCompute = float;
-ElementCompute alpha0 = ElementCompute(2);
-ElementCompute beta0 = ElementCompute(0);
-ElementCompute alpha1 = ElementCompute(2);
-ElementCompute beta1 = ElementCompute(0);
+ElementCompute alpha0 = ElementCompute(1);
+//Fused kernel has built-in bias, setting beta=0
+ElementCompute beta0 = ElementCompute(0);
+ElementCompute alpha1 = ElementCompute(1);
+ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
 using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
 using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
@@ -165,7 +166,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
 64 / cutlass::sizeof_bits<ElementOutput>::value,
 ElementAccumulator,
 ElementCompute,
-cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+cutlass::epilogue::thread::ScaleType::NoBetaScaling
 >;
 const bool SmemAccumulator = true;
View File
@ -79,6 +79,8 @@ struct B2bGemm {
typename B2bMma::IteratorB0::TensorRef ref_B0; typename B2bMma::IteratorB0::TensorRef ref_B0;
typename Epilogue::OutputTileIterator::Params params_C0; typename Epilogue::OutputTileIterator::Params params_C0;
typename Epilogue::OutputTileIterator::TensorRef ref_C0; typename Epilogue::OutputTileIterator::TensorRef ref_C0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
typename B2bMma::IteratorB1::Params params_B1; typename B2bMma::IteratorB1::Params params_B1;
typename B2bMma::IteratorB1::TensorRef ref_B1; typename B2bMma::IteratorB1::TensorRef ref_B1;
typename Epilogue::OutputTileIterator::Params params_C1; typename Epilogue::OutputTileIterator::Params params_C1;
@ -109,6 +111,8 @@ struct B2bGemm {
typename B2bMma::IteratorA0::TensorRef ref_A0, typename B2bMma::IteratorA0::TensorRef ref_A0,
typename B2bMma::IteratorB0::TensorRef ref_B0, typename B2bMma::IteratorB0::TensorRef ref_B0,
typename Epilogue::OutputTileIterator::TensorRef ref_C0, typename Epilogue::OutputTileIterator::TensorRef ref_C0,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0,
typename B2bMma::IteratorB1::TensorRef ref_B1, typename B2bMma::IteratorB1::TensorRef ref_B1,
typename Epilogue::OutputTileIterator::TensorRef ref_C1, typename Epilogue::OutputTileIterator::TensorRef ref_C1,
typename Epilogue::OutputTileIterator::TensorRef ref_D1, typename Epilogue::OutputTileIterator::TensorRef ref_D1,
@ -126,6 +130,8 @@ struct B2bGemm {
ref_B0(ref_B0), ref_B0(ref_B0),
params_C0(ref_C0.layout()), params_C0(ref_C0.layout()),
ref_C0(ref_C0), ref_C0(ref_C0),
ref_Scale0(ref_Scale0),
ref_Bias0(ref_Bias0),
params_B1(ref_B1.layout()), params_B1(ref_B1.layout()),
ref_B1(ref_B1), ref_B1(ref_B1),
params_C1(ref_C1.layout()), params_C1(ref_C1.layout()),
@ -305,6 +311,29 @@ struct B2bGemm {
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32; int lane_idx = threadIdx.x % 32;
// Construct iterators to accumulator scale/bias vector
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
params.ref_Scale0.data(),
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
MatrixCoord(
0, threadblock_tile_offset.n() * B2bMma::Shape0::kN
)
);
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
params.ref_Bias0.data(),
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
MatrixCoord(
0, threadblock_tile_offset.n() * B2bMma::Shape0::kN
)
);
// //
// Main loop // Main loop
// //
@ -322,7 +351,8 @@ struct B2bGemm {
if (!kSplitKSerial || gemm_k_iterations_0 > 0) { if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
// Compute threadblock-scoped matrix multiply-add // Compute threadblock-scoped matrix multiply-add
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, iterator_B1, src_accum, output_op_0); b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
} }
// //
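The kernel above constructs `iterator_Scale0` and `iterator_Bias0` over a 1 x N extent and offsets them by `threadblock_tile_offset.n() * B2bMma::Shape0::kN`, so each threadblock only reads the slice of the per-channel vectors that covers its own output-tile columns. A minimal CUDA sketch of that indexing follows; it is illustrative only and not the CUTLASS iterator.

```cpp
#include <cuda_runtime.h>

// Illustrative only: block b copies the TILE_N-wide slice of a length-N per-channel
// scale/bias vector that belongs to output-tile column b, mirroring the MatrixCoord
// offset (0, threadblock_tile_offset.n() * Shape0::kN) used above.
template <int TILE_N>
__global__ void load_scale_bias_slice(const float* scale, const float* bias, int N,
                                      float* tile_scale, float* tile_bias) {
  int col = blockIdx.x * TILE_N + threadIdx.x;  // blockIdx.x plays the role of tile_offset.n()
  if (threadIdx.x < TILE_N && col < N) {
    tile_scale[col] = scale[col];
    tile_bias[col]  = bias[col];
  }
}
```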

View File

@ -338,7 +338,7 @@ struct DefaultB2bConv2dFprop <
cutlass::transform::threadblock::VectorIterator< cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator< cutlass::transform::threadblock::PredicatedVectorAccessIterator<
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>, cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kK>, cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
ElementScaleBias, LayoutScaleBias, kElementsPerAccess> ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
>; >;

View File

@ -0,0 +1,275 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Defines device-side reference kernels that apply per-channel scale and bias vectors to GEMM and Conv2d outputs (out = scale * in + bias). Note, the operations defined
in this header are not specialized for any particular data layout and are therefore not
intended to offer the best possible performance. Rather, they are intended to be generic
reference implementations to support the CUTLASS unit tests.
*/
#pragma once
// Cutlass includes
#include "cutlass/cutlass.h"
#include "cutlass/tensor_view.h"
#include "cutlass/gemm/gemm.h"
// Headers for MatrixCoord/MatrixShape, NumericConverter, and Conv2dProblemSize used below
#include "cutlass/matrix_coord.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/conv/conv2d_problem_size.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace reference {
namespace device {
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace kernel {
template <
typename TensorRefIn, ///< Input TensorRef Type
typename TensorRefOut, ///< Output TensorRef Type
typename ScalarType, ///< alpha Type
typename TensorRefScalar, ///< Scale/Bias TensorRef Type
typename OutputTile,
typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>
>
__global__ void TensorScaleBiasGemm(
gemm::GemmCoord problem_size,
TensorRefIn tensor_in, ///< input tensor
TensorRefOut tensor_out, ///< output tensor
ScalarType alpha, ///< alpha
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias ///< bias tensor
) {
ConvertOp convert_op;
MatrixCoord output_coord(
MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow),
MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn)
);
// Update the output tensor
for (int j = 0; j < OutputTile::kRow; ++j) {
for (int i = 0; i < OutputTile::kColumn; ++i) {
MatrixCoord coord = output_coord + MatrixCoord(i, j);
if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
ScalarType scale = alpha;
if(tensor_scale.good())
scale = tensor_scale.at({0, coord.column()});
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
bias = tensor_bias.at({0, coord.column()});
tensor_out.at(coord) = convert_op(
scale * ScalarType(tensor_in.at(coord)) + bias);
}
}
}
}
template <
typename TensorRefIn, ///< Input TensorRef Type
typename TensorRefOut, ///< Output TensorRef Type
typename ScalarType, ///< alpha Type
typename TensorRefScalar, ///< Scale/Bias TensorRef Type
typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
int kThreadM = 4, // shape of a thread's tile in the GEMM M dimension
int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension
int kCtaShapeM = 16, // shape of a threadblock in units of threads
int kCtaShapeN = 8 // shape of a threadblock in units of threads
>
__global__ void TensorScaleBiasConv2d(
conv::Conv2dProblemSize problem_size,
TensorRefIn tensor_in, ///< input tensor
TensorRefOut tensor_out, ///< output tensor
ScalarType alpha, ///< alpha
TensorRefScalar tensor_scale, ///< scale tensor
TensorRefScalar tensor_bias ///< bias tensor
) {
ConvertOp convert_op;
int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;
int thread_n[kThreadM];
int thread_p[kThreadM];
int thread_q[kThreadM];
// Compute N, P, Q coordinates for each row of a thread's tile
int64_t PQ = int64_t(problem_size.P) * problem_size.Q;
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
int64_t npq = npq_start + m;
thread_n[m] = int(npq / PQ);
int64_t residual = npq % PQ;
thread_p[m] = int(residual / problem_size.Q);
thread_q[m] = int(residual % problem_size.Q);
}
// Write out the results
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int thread_k = k_start + n;
if (thread_k < problem_size.K) {
ScalarType scale = alpha;
if(tensor_scale.good())
scale = tensor_scale.at({0, thread_k});
ScalarType bias = ScalarType(0);
if(tensor_bias.good())
bias = tensor_bias.at({0, thread_k});
tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
scale * ScalarType(
tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})
) + bias);
}
}
}
}
}
}
/// Apply scale and bias on a tensor
template <
typename ElementIn, ///< Input Type
typename ElementOut, ///< Output Type
typename Layout, ///< Layout of input/output tensor
typename ScalarType, ///< alpha Type
typename LayoutScaleBias, ///< Layout of scale and bias
typename ConvertOp = NumericConverter<ElementOut, ScalarType>
>
void TensorScaleBiasGemm(
gemm::GemmCoord problem_size,
TensorRef<ElementIn, Layout> tensor_in, ///< input tensor
TensorRef<ElementOut, Layout> tensor_out, ///< output tensor
ScalarType alpha, ///< alpha
TensorRef<ScalarType, LayoutScaleBias> tensor_scale, ///< scale tensor
TensorRef<ScalarType, LayoutScaleBias> tensor_bias ///< bias tensor
) {
using OutputTile = MatrixShape<4, 4>;
dim3 block(16, 8);
dim3 grid(
(problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
(problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn)
);
kernel::TensorScaleBiasGemm<
TensorRef<ElementIn, Layout>,
TensorRef<ElementOut, Layout>,
ScalarType,
TensorRef<ScalarType, LayoutScaleBias>,
OutputTile,
ConvertOp
><<< grid, block >>> (
problem_size,
tensor_in,
tensor_out,
alpha,
tensor_scale,
tensor_bias
);
}
/// Apply scale and bias on a tensor
template <
typename ElementIn, ///< Input Type
typename ElementOut, ///< Output Type
typename Layout, ///< Layout of input/output tensor
typename ScalarType, ///< alpha Type
typename LayoutScaleBias, ///< Layout of scale and bias
typename ConvertOp = NumericConverter<ElementOut, ScalarType>
>
void TensorScaleBiasConv2d(
conv::Conv2dProblemSize problem_size,
TensorRef<ElementIn, Layout> tensor_in, ///< input tensor
TensorRef<ElementOut, Layout> tensor_out, ///< output tensor
ScalarType alpha, ///< alpha
TensorRef<ScalarType, LayoutScaleBias> tensor_scale, ///< scale tensor
TensorRef<ScalarType, LayoutScaleBias> tensor_bias ///< bias tensor
) {
int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension
int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension
int const kCtaShapeM = 16; // shape of a threadblock in units of threads
int const kCtaShapeN = 8; // shape of a threadblock in units of threads
int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q;
int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);
dim3 block(kCtaShapeM, kCtaShapeN);
dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));
kernel::TensorScaleBiasConv2d<
TensorRef<ElementIn, Layout>,
TensorRef<ElementOut, Layout>,
ScalarType,
TensorRef<ScalarType, LayoutScaleBias>,
ConvertOp,
kThreadM,
kThreadN,
kCtaShapeM,
kCtaShapeN
><<< grid, block >>> (
problem_size,
tensor_in,
tensor_out,
alpha,
tensor_scale,
tensor_bias
);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace reference
} // namespace cutlass
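The reference kernels above compute `out = scale[k] * in + bias[k]` over the channel dimension, falling back to the scalar `alpha` when no scale vector is provided and to zero when no bias vector is provided. As a mental model, a host-side equivalent for an NHWC tensor is just four nested loops; the sketch below is plain C++ and not part of the file, with empty vectors standing in for `tensor_scale.good()` / `tensor_bias.good()` being false.

```cpp
#include <vector>

// Host-side restatement of TensorScaleBiasConv2d's math for an NHWC tensor of shape
// [N, P, Q, K]: out[n,p,q,k] = scale[k] * in[n,p,q,k] + bias[k].
void scale_bias_nhwc(const std::vector<float>& in, std::vector<float>& out,
                     int N, int P, int Q, int K, float alpha,
                     const std::vector<float>& scale, const std::vector<float>& bias) {
  for (int n = 0; n < N; ++n)
    for (int p = 0; p < P; ++p)
      for (int q = 0; q < Q; ++q)
        for (int k = 0; k < K; ++k) {
          int idx = ((n * P + p) * Q + q) * K + k;
          float s = scale.empty() ? alpha : scale[k];  // no scale vector -> scalar alpha
          float b = bias.empty() ? 0.f : bias[k];      // no bias vector  -> zero
          out[idx] = s * in[idx] + b;
        }
}
```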

View File

@ -745,7 +745,6 @@ public:
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_B1_; ++this->warp_tile_iterator_B1_;
if (warp_mma_k > 0) if (warp_mma_k > 0)
warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
warp_transformed_frag_B1[warp_mma_k % 2], warp_transformed_frag_B1[warp_mma_k % 2],

View File

@ -82,6 +82,11 @@ template <
/// Iterates over the intermediate accumulator tile /// Iterates over the intermediate accumulator tile
// (concept::MmaTensorOpFragmentIterator) // (concept::MmaTensorOpFragmentIterator)
typename FragmentIteratorA1_, typename FragmentIteratorA1_,
/// Iterates over vectors of scale and bias vector in global memory
// (concept: VectorIterator)
typename IteratorAccumulatorScaleBias_,
/// WarpIterator to load Scale or Bias vector from threadblock fragment
typename FragmentIteratorA1ScaleBias_,
/// Iterates over tiles of B operand in global memory /// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | // (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator) // MaskedTileIterator)
@ -126,6 +131,10 @@ public:
using Shape1 = Shape1_; using Shape1 = Shape1_;
///< Iterates over intermediate accumulator tile ///< Iterates over intermediate accumulator tile
using FragmentIteratorA1 = FragmentIteratorA1_; using FragmentIteratorA1 = FragmentIteratorA1_;
///< Iterates over tiles of the scale and bias vectors in global memory
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
///< WarpIterator to load Scale or Bias vector from threadblock fragment
using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_;
///< Iterates over tiles of B operand in global memory ///< Iterates over tiles of B operand in global memory
using IteratorB1 = IteratorB1_; using IteratorB1 = IteratorB1_;
///< Policy describing tuning details ///< Policy describing tuning details
@ -140,6 +149,9 @@ public:
///< Epilogue after 1st Gemm ///< Epilogue after 1st Gemm
using OutputOp = OutputOp_; using OutputOp = OutputOp_;
static const bool PerChannelScale = (OutputOp::kScale ==
epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);
static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
@ -154,6 +166,9 @@ public:
/// Warp-level Mma /// Warp-level Mma
using Operator0 = typename Policy0::Operator; using Operator0 = typename Policy0::Operator;
/// Fragment of Scale and Bias loaded from global memory
using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;
/// Fragment of accumulator tile /// Fragment of accumulator tile
using FragmentC1 = typename Policy1::Operator::FragmentC; using FragmentC1 = typename Policy1::Operator::FragmentC;
@ -217,6 +232,8 @@ public:
using WarpLoadedFragmentB0 = typename Operator0::FragmentB; using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
/// Warp Fragment of operand A1 loaded from accumulator tile /// Warp Fragment of operand A1 loaded from accumulator tile
using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment; using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment;
using WarpLoadedFragmentA1ScaleBias =
typename FragmentIteratorA1ScaleBias::Fragment;
using WarpLoadedFragmentB1 = typename Operator1::FragmentB; using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
@ -381,11 +398,15 @@ public:
int gemm_k_iterations_0, int gemm_k_iterations_0,
///< destination accumulator tile ///< destination accumulator tile
FragmentC1 &accum, FragmentC1 &accum,
///< iterator over A operand in global memory ///< iterator over A0 operand in global memory
IteratorA0 iterator_A0, IteratorA0 iterator_A0,
///< iterator over B operand in global memory ///< iterator over B0 operand in global memory
IteratorB0 iterator_B0, IteratorB0 iterator_B0,
///< iterator over B operand in global memory ///< iterator over A1 operand scale vector in global memory
IteratorAccumulatorScaleBias iterator_A1_scale,
///< iterator over A1 operand bias vector in global memory
IteratorAccumulatorScaleBias iterator_A1_bias,
///< iterator over B1 operand in global memory
IteratorB1 iterator_B1, IteratorB1 iterator_B1,
///< initial value of accumulator ///< initial value of accumulator
FragmentC0 const &src_accum, FragmentC0 const &src_accum,
@ -623,6 +644,20 @@ public:
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
FragmentIteratorA1 warp_tile_iterator_A1_(accum0); FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
FragmentA1ScaleBias tb_frag_A1_scale;
FragmentA1ScaleBias tb_frag_A1_bias;
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale);
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias);
if(PerChannelScale) {
tb_frag_A1_scale.clear();
iterator_A1_scale.load(tb_frag_A1_scale);
++iterator_A1_scale;
}
tb_frag_A1_bias.clear();
iterator_A1_bias.load(tb_frag_A1_bias);
++iterator_A1_bias;
// //
// Prologue // Prologue
@ -678,18 +713,29 @@ public:
// Pair of fragments used to overlap shared memory loads and math // Pair of fragments used to overlap shared memory loads and math
// instructions // instructions
WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2];
WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2];
WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
Operator1 warp_mma1; Operator1 warp_mma1;
this->warp_tile_iterator_B1_.set_kgroup_index(0); if(PerChannelScale) {
warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]);
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], output_op_0); ++warp_tile_iterator_A1_scale_;
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); }
warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]);
++warp_tile_iterator_A1_bias_;
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0],
warp_loaded_frag_A1_scale[0],
warp_loaded_frag_A1_bias[0],
output_op_0);
++warp_tile_iterator_A1_; ++warp_tile_iterator_A1_;
this->warp_tile_iterator_B1_.set_kgroup_index(0);
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
++this->warp_tile_iterator_B1_; ++this->warp_tile_iterator_B1_;
iterator_B1.clear_mask(gemm_k_iterations_1 == 0); iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
@ -717,15 +763,37 @@ public:
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
++warp_mma_k) { ++warp_mma_k) {
// Load threadblock-level scale/bias vector from global memory
if (warp_mma_k + 1 == Base::kWarpGemmIterations1) {
if(PerChannelScale) {
tb_frag_A1_scale.clear();
iterator_A1_scale.load(tb_frag_A1_scale);
++iterator_A1_scale;
}
tb_frag_A1_bias.clear();
iterator_A1_bias.load(tb_frag_A1_bias);
++iterator_A1_bias;
}
// Load warp-level scale bias fragment from threadblock scale/bias vector
if(PerChannelScale) {
warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]);
++warp_tile_iterator_A1_scale_;
}
warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]);
++warp_tile_iterator_A1_bias_;
// Load warp-level tile from accumulator fragment
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2],
warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2],
output_op_0);
++warp_tile_iterator_A1_;
// Load warp-level tiles from shared memory, wrapping to k offset if // Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be. // this is the last group as the case may be.
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
++warp_tile_iterator_A1_;
++this->warp_tile_iterator_B1_; ++this->warp_tile_iterator_B1_;
if (warp_mma_k > 0) if (warp_mma_k > 0)
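The multistage mainloop above keeps two copies of each scale/bias warp fragment and flips between them on every `warp_mma_k` step, the same ping-pong used for the A1 and B1 fragments, so the next fragment is loaded while the current one feeds the warp MMA. A stripped-down sketch of that double-buffering pattern follows; `load_next` and `compute` are generic stand-ins for the iterator load and warp-level MMA, not CUTLASS calls.

```cpp
#include <array>
#include <cstddef>

// Generic double-buffering skeleton: while slot (k % 2) is consumed, slot ((k + 1) % 2)
// is filled for the next iteration -- the pattern used for warp_loaded_frag_A1_scale/bias.
template <typename Fragment, typename LoadFn, typename ComputeFn>
void pingpong_mainloop(std::size_t iterations, LoadFn load_next, ComputeFn compute) {
  std::array<Fragment, 2> frag;
  load_next(frag[0]);                       // prologue: fill slot 0
  for (std::size_t k = 0; k < iterations; ++k) {
    load_next(frag[(k + 1) % 2]);           // prefetch into the other slot
    compute(frag[k % 2]);                   // consume the slot loaded previously
  }
}
```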

View File

@ -165,6 +165,9 @@ public:
/// Warp-level Mma /// Warp-level Mma
using Operator0 = typename Policy0::Operator; using Operator0 = typename Policy0::Operator;
/// Fragment of Scale and Bias loaded from global memory
using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;
/// Fragment of accumulator tile /// Fragment of accumulator tile
using FragmentC1 = typename Policy1::Operator::FragmentC; using FragmentC1 = typename Policy1::Operator::FragmentC;
@ -418,11 +421,15 @@ public:
int gemm_k_iterations_0, int gemm_k_iterations_0,
///< destination accumulator tile ///< destination accumulator tile
FragmentC1 &accum, FragmentC1 &accum,
///< iterator over A operand in global memory ///< iterator over A0 operand in global memory
IteratorA0 iterator_A0, IteratorA0 iterator_A0,
///< iterator over B operand in global memory ///< iterator over B0 operand in global memory
IteratorB0 iterator_B0, IteratorB0 iterator_B0,
///< iterator over B operand in global memory ///< iterator over A1 operand scale vector in global memory
IteratorAccumulatorScaleBias iterator_accum0_scale,
///< iterator over A1 operand bias vector in global memory
IteratorAccumulatorScaleBias iterator_accum0_bias,
///< iterator over B1 operand in global memory
IteratorB1 iterator_B1, IteratorB1 iterator_B1,
///< initial value of accumulator ///< initial value of accumulator
FragmentC0 const &src_accum, FragmentC0 const &src_accum,
@ -658,7 +665,7 @@ public:
/// Epilogue for the first Implicit Gemm /// Epilogue for the first Implicit Gemm
Epilogue0 epilogue0; Epilogue0 epilogue0;
epilogue0(output_op_0, smem_iterator_D0_, accum0); epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias);
__syncthreads(); __syncthreads();
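In this shared-memory-accumulator variant, the scale and bias iterators are handed straight to `epilogue0`, which applies them while the first GEMM's accumulators are written to shared memory, so the second GEMM reads an already-scaled A1 operand. A simplified CUDA sketch of that kind of fused store is shown below; `TILE_M`/`TILE_N`, the flat layout, and the launch shape (`blockDim = dim3(TILE_N, TILE_M)`) are assumptions for illustration, not the CUTLASS epilogue.

```cpp
#include <cuda_runtime.h>

// Illustrative fused epilogue-to-shared-memory store: D0_smem = scale[n] * accum0 + bias[n].
// One thread per accumulator element of a TILE_M x TILE_N threadblock tile.
template <int TILE_M, int TILE_N>
__global__ void epilogue0_to_smem(const float* accum0, const float* scale, const float* bias,
                                  float* d0_out /* stands in for the second GEMM's input */) {
  __shared__ float smem[TILE_M * TILE_N];
  int m = threadIdx.y, n = threadIdx.x;
  if (m < TILE_M && n < TILE_N) {
    int idx = m * TILE_N + n;
    smem[idx] = scale[n] * accum0[idx] + bias[n];  // per-channel scale/bias applied on the fly
  }
  __syncthreads();
  // The real kernel would now run the second GEMM on smem; here we just copy it out.
  if (m < TILE_M && n < TILE_N) d0_out[m * TILE_N + n] = smem[m * TILE_N + n];
}
```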

View File

@ -76,6 +76,11 @@ template <
/// Iterates over the intermediate accumulator tile /// Iterates over the intermediate accumulator tile
// (concept::MmaTensorOpFragmentIterator) // (concept::MmaTensorOpFragmentIterator)
typename FragmentIteratorA1_, typename FragmentIteratorA1_,
/// Iterates over vectors of scale and bias vector in global memory
// (concept: VectorIterator)
typename IteratorAccumulatorScaleBias_,
/// FragmentIterator to load Scale or Bias vector from threadblock fragment
typename FragmentIteratorA1ScaleBias_,
/// Iterates over tiles of B operand in global memory /// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorB1_, typename IteratorB1_,
@ -129,6 +134,9 @@ public:
using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory
using FragmentIteratorA1ScaleBias =
FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
using Policy1 = Policy1_; ///< Policy describing tuning details using Policy1 = Policy1_; ///< Policy describing tuning details
@ -140,6 +148,9 @@ public:
using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm
static const bool PerChannelScale = (OutputOp::kScale ==
epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);
using TransformA0 = TransformA0_; using TransformA0 = TransformA0_;
using TransformB0 = TransformB0_; using TransformB0 = TransformB0_;
using TransformB1 = TransformB1_; using TransformB1 = TransformB1_;
@ -160,6 +171,9 @@ public:
/// Warp-level Mma /// Warp-level Mma
using Operator0 = typename Policy0::Operator; using Operator0 = typename Policy0::Operator;
/// Fragment of Scale and Bias loaded from global memory
using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;
/// Fragment of operand B loaded from global memory /// Fragment of operand B loaded from global memory
using FragmentB1 = typename IteratorB1::Fragment; using FragmentB1 = typename IteratorB1::Fragment;
@ -190,6 +204,9 @@ private:
using WarpFragmentB0 = typename Operator0::FragmentB; using WarpFragmentB0 = typename Operator0::FragmentB;
/// Warp Fragment of operand A1 loaded from accumulator tile /// Warp Fragment of operand A1 loaded from accumulator tile
using WarpFragmentA1 = typename FragmentIteratorA1::Fragment; using WarpFragmentA1 = typename FragmentIteratorA1::Fragment;
/// Warp Fragment of operand A1 scale and bias loaded from threadblock fragment
using WarpFragmentA1ScaleBias =
typename FragmentIteratorA1ScaleBias::Fragment;
using WarpFragmentB1 = typename Operator1::FragmentB; using WarpFragmentB1 = typename Operator1::FragmentB;
protected: protected:
@ -248,6 +265,8 @@ public:
FragmentC1 &accum, ///< destination accumulator tile FragmentC1 &accum, ///< destination accumulator tile
IteratorA0 iterator_A, ///< iterator over A operand in global memory IteratorA0 iterator_A, ///< iterator over A operand in global memory
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory
IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
FragmentC0 const &src_accum, ///< source accumulator tile FragmentC0 const &src_accum, ///< source accumulator tile
OutputOp output_op_0, ///< epilogue operation after 1st Gemm OutputOp output_op_0, ///< epilogue operation after 1st Gemm
@ -387,13 +406,26 @@ public:
// Prologue // Prologue
// //
FragmentA1ScaleBias tb_frag_A1_scale;
FragmentA1ScaleBias tb_frag_A1_bias;
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale);
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias);
FragmentB1 tb_frag_B1; FragmentB1 tb_frag_B1;
if(PerChannelScale)
tb_frag_A1_scale.clear();
tb_frag_A1_bias.clear();
tb_frag_B1.clear(); tb_frag_B1.clear();
// The last kblock is loaded in the prolog // The last kblock is loaded in the prolog
if(PerChannelScale)
iterator_A1_scale.load(tb_frag_A1_scale);
iterator_A1_bias.load(tb_frag_A1_bias);
iterator_B1.load(tb_frag_B1); iterator_B1.load(tb_frag_B1);
if(PerChannelScale)
++iterator_A1_scale;
++iterator_A1_bias;
++iterator_B1; ++iterator_B1;
this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
@ -403,15 +435,24 @@ public:
__syncthreads(); __syncthreads();
// Pair of fragments used to overlap shared memory loads and math instructions // Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA1ScaleBias warp_frag_A1_scale[2];
WarpFragmentA1ScaleBias warp_frag_A1_bias[2];
WarpFragmentA1 warp_frag_A1[2]; WarpFragmentA1 warp_frag_A1[2];
WarpFragmentB1 warp_frag_B1[2]; WarpFragmentB1 warp_frag_B1[2];
this->warp_tile_iterator_B1_.set_kgroup_index(0); this->warp_tile_iterator_B1_.set_kgroup_index(0);
warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0); if(PerChannelScale)
warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]);
warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]);
warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0],
warp_frag_A1_bias[0], output_op_0);
this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);
++warp_tile_iterator_A1_; ++warp_tile_iterator_A1_;
if(PerChannelScale)
++warp_tile_iterator_A1_scale_;
++warp_tile_iterator_A1_bias_;
++this->warp_tile_iterator_B1_; ++this->warp_tile_iterator_B1_;
Operator1 warp_mma1; Operator1 warp_mma1;
@ -461,13 +502,31 @@ public:
} }
smem_write_stage_idx ^= 1; smem_write_stage_idx ^= 1;
if(PerChannelScale) {
tb_frag_A1_scale.clear();
iterator_A1_scale.load(tb_frag_A1_scale);
++iterator_A1_scale;
}
tb_frag_A1_bias.clear();
iterator_A1_bias.load(tb_frag_A1_bias);
++iterator_A1_bias;
} }
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0); if(PerChannelScale)
warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]);
warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]);
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2],
warp_frag_A1_scale[(warp_mma_k + 1) % 2],
warp_frag_A1_bias[(warp_mma_k + 1) % 2],
output_op_0);
this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
if(PerChannelScale)
++warp_tile_iterator_A1_scale_;
++warp_tile_iterator_A1_bias_;
++warp_tile_iterator_A1_; ++warp_tile_iterator_A1_;
++this->warp_tile_iterator_B1_; ++this->warp_tile_iterator_B1_;
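`PerChannelScale` is a compile-time flag derived from the epilogue's scale type; when it is false the scale-vector loads guarded above are skipped entirely and only the bias vector is streamed, with the scalar alpha used in place of a per-channel scale. A small sketch of the same compile-time gating is below (plain C++17; the enum and names are placeholders for the CUTLASS types).

```cpp
// Compile-time gating of the per-channel scale path, mirroring the PerChannelScale checks above.
enum class ScaleKind { OnlyAlphaScaling, OnlyAlphaPerChannelScaling, NoBetaScaling };

template <ScaleKind kScale>
struct Mainloop {
  static constexpr bool kPerChannelScale = (kScale == ScaleKind::OnlyAlphaPerChannelScaling);

  static float apply(float accum, float alpha, float scale_n, float bias_n) {
    if constexpr (kPerChannelScale)
      return scale_n * accum + bias_n;   // scale vector is loaded and used
    else
      return alpha * accum + bias_n;     // scale loads compiled out; scalar alpha instead
  }
};
```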

View File

@ -286,6 +286,8 @@ public:
FragmentC1 &accum, ///< destination accumulator tile FragmentC1 &accum, ///< destination accumulator tile
IteratorA0 iterator_A, ///< iterator over A operand in global memory IteratorA0 iterator_A, ///< iterator over A operand in global memory
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory
IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
FragmentC0 const &src_accum, ///< source accumulator tile FragmentC0 const &src_accum, ///< source accumulator tile
OutputOp output_op_0, ///< epilogue operation after 1st Gemm OutputOp output_op_0, ///< epilogue operation after 1st Gemm
@ -419,7 +421,7 @@ public:
/// Epilogue for the first Implicit Gemm /// Epilogue for the first Implicit Gemm
Epilogue0 epilogue0; Epilogue0 epilogue0;
epilogue0(output_op_0, smem_iterator_D0_, accum0); epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias);
__syncthreads(); __syncthreads();

View File

@ -40,6 +40,10 @@
#include "cutlass/transform/threadblock/predicated_tile_iterator.h" #include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" #include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
#include "cutlass/transform/threadblock/vector_iterator.h"
#include "cutlass/transform/warp/vector_fragment_iterator.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" #include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" #include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" #include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
@ -170,6 +174,22 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
MmaCore1::Shape::kK, //kBlocksColumn MmaCore1::Shape::kK, //kBlocksColumn
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>; ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>;
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
static int const kElementsPerAccess = 2;
using IteratorAccumulatorScaleBias =
cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
>;
// Warp-level iterators to load scale and bias vectors
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
// Define iterators over tiles from the B operand // Define iterators over tiles from the B operand
using IteratorB1 = using IteratorB1 =
cutlass::transform::threadblock::PredicatedTileIterator< cutlass::transform::threadblock::PredicatedTileIterator<
@ -181,6 +201,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
IteratorB0, typename MmaCore0::SmemIteratorB, IteratorB0, typename MmaCore0::SmemIteratorB,
typename MmaCore1::Shape, FragmentIteratorA1, typename MmaCore1::Shape, FragmentIteratorA1,
IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias,
IteratorB1, typename MmaCore1::SmemIteratorB, IteratorB1, typename MmaCore1::SmemIteratorB,
ElementAccumulator, layout::RowMajor, ElementAccumulator, layout::RowMajor,
EpilogueOutputOp, EpilogueOutputOp,
@ -276,6 +297,24 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
MmaCore1::Shape::kK, //kBlocksColumn MmaCore1::Shape::kK, //kBlocksColumn
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>; ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>;
/// Define iterators over tiles from scale/bias vectors
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
static int const kElementsPerAccess = 2;
using IteratorAccumulatorScaleBias =
cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
>;
// Warp-level iterators to load scale and bias vectors
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
// Define iterators over tiles from the B operand // Define iterators over tiles from the B operand
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
using AccessTypeB1 = cutlass::Array<ElementB, kAlignmentB>; using AccessTypeB1 = cutlass::Array<ElementB, kAlignmentB>;
@ -290,6 +329,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
MmaCore0::kCacheOpA, MmaCore0::kCacheOpA,
IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
typename MmaCore1::Shape, FragmentIteratorA1, typename MmaCore1::Shape, FragmentIteratorA1,
IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias,
IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
ElementAccumulator, layout::RowMajor, ElementAccumulator, layout::RowMajor,
EpilogueOutputOp, EpilogueOutputOp,
@ -377,6 +417,22 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
ElementAccumulator, ElementA, AccumulatorLayout, ElementAccumulator, ElementA, AccumulatorLayout,
InstructionShape, EpilogueOutputOp>; InstructionShape, EpilogueOutputOp>;
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
static int const kElementsPerAccess = 4;
using IteratorAccumulatorScaleBias =
cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
>;
// Warp-level iterators to load scale and bias vectors
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
// Define iterators over tiles from the B operand // Define iterators over tiles from the B operand
using IteratorB1 = using IteratorB1 =
cutlass::transform::threadblock::PredicatedTileIterator< cutlass::transform::threadblock::PredicatedTileIterator<
@ -384,12 +440,12 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>; ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;
// Define the threadblock-scoped pipelined matrix multiply // Define the threadblock-scoped pipelined matrix multiply
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined< using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
IteratorB0, typename MmaCore0::SmemIteratorB, IteratorB0, typename MmaCore0::SmemIteratorB,
typename MmaCore1::Shape, FragmentIteratorA1, typename MmaCore1::Shape, FragmentIteratorA1,
IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias,
IteratorB1, typename MmaCore1::SmemIteratorB, IteratorB1, typename MmaCore1::SmemIteratorB,
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>, ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
EpilogueOutputOp, EpilogueOutputOp,
@ -479,6 +535,23 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
ElementAccumulator, ElementA, AccumulatorLayout, ElementAccumulator, ElementA, AccumulatorLayout,
InstructionShape, EpilogueOutputOp>; InstructionShape, EpilogueOutputOp>;
/// Define iterators over tiles from scale/bias vectors
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
static int const kElementsPerAccess = 4;
using IteratorAccumulatorScaleBias =
cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
>;
// Warp-level iterators to load scale and bias vectors
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
LayoutScaleBias, InstructionShape, kElementsPerAccess>;
// Define iterators over tiles from the B operand // Define iterators over tiles from the B operand
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
using IteratorB1 = using IteratorB1 =
@ -494,6 +567,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
MmaCore0::kCacheOpA, MmaCore0::kCacheOpA,
IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
typename MmaCore1::Shape, FragmentIteratorA1, typename MmaCore1::Shape, FragmentIteratorA1,
IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias,
IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>, ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
EpilogueOutputOp, EpilogueOutputOp,
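In these defaults, `kElementsPerAccess` sets the vector width of each scale/bias load: 2 elements on the f16 specializations and 4 on the int8 ones. A plain CUDA sketch of what a vectorized read of a length-N row vector looks like is given below; it assumes N is a multiple of the access width and suitable alignment, and is not the CUTLASS `PredicatedVectorAccessIterator` itself.

```cpp
#include <cuda_runtime.h>

// Vectorized read of a length-N scale vector, kElementsPerAccess floats per load.
// Assumes N % kElementsPerAccess == 0 and alignment friendly to wide loads.
template <int kElementsPerAccess>
__global__ void read_scale_vector(const float* __restrict__ scale,
                                  float* __restrict__ out, int N) {
  int base = (blockIdx.x * blockDim.x + threadIdx.x) * kElementsPerAccess;
  if (base + kElementsPerAccess <= N) {
    float frag[kElementsPerAccess];
    #pragma unroll
    for (int i = 0; i < kElementsPerAccess; ++i)  // contiguous reads the compiler can widen
      frag[i] = scale[base + i];
    #pragma unroll
    for (int i = 0; i < kElementsPerAccess; ++i)
      out[base + i] = frag[i];
  }
}
```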

View File

@ -559,7 +559,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
cutlass::transform::threadblock::VectorIterator< cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator< cutlass::transform::threadblock::PredicatedVectorAccessIterator<
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>, cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kK>, cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
ElementScaleBias, LayoutScaleBias, kElementsPerAccess> ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
>; >;

View File

@ -162,6 +162,8 @@ public:
if (Scale == ScaleType::OnlyAlphaScaling) return false; if (Scale == ScaleType::OnlyAlphaScaling) return false;
if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false;
if (Scale == ScaleType::Nothing) return false; if (Scale == ScaleType::Nothing) return false;
return beta_ != ElementCompute(0); return beta_ != ElementCompute(0);
@ -389,6 +391,8 @@ public:
if (Scale == ScaleType::OnlyAlphaScaling) return false; if (Scale == ScaleType::OnlyAlphaScaling) return false;
if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false;
if (Scale == ScaleType::Nothing) return false; if (Scale == ScaleType::Nothing) return false;
return beta_ != ElementCompute(0); return beta_ != ElementCompute(0);
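The change above makes `is_source_needed()` also return false for `OnlyAlphaPerChannelScaling`, so the epilogue skips loading the source operand C for the fused per-channel bias path. A condensed, free-standing restatement of that predicate is below; it is illustrative and not the full CUTLASS class, and the enum members merely mirror the scale types named in this hunk.

```cpp
// Condensed version of the is_source_needed() logic after this change: the source
// operand C is only read when beta can actually contribute to the output.
enum class ScaleType { Default, NoBetaScaling, OnlyAlphaScaling, OnlyAlphaPerChannelScaling, Nothing };

bool is_source_needed(ScaleType scale, float beta) {
  if (scale == ScaleType::NoBetaScaling) return true;                 // beta fixed to 1
  if (scale == ScaleType::OnlyAlphaScaling) return false;
  if (scale == ScaleType::OnlyAlphaPerChannelScaling) return false;   // added by this commit
  if (scale == ScaleType::Nothing) return false;
  return beta != 0.0f;
}
```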

View File

@ -82,9 +82,7 @@ __global__ void Conv2dFprop(
TensorRef<ElementC, LayoutC> tensor_y_in, TensorRef<ElementC, LayoutC> tensor_y_in,
TensorRef<ElementC, LayoutC> tensor_y_out, TensorRef<ElementC, LayoutC> tensor_y_out,
ElementCompute alpha, ElementCompute alpha,
ElementCompute beta, ElementCompute beta
TensorRef<ElementCompute, layout::RowMajor> tensor_scale,
TensorRef<ElementCompute, layout::RowMajor> tensor_bias
) { ) {
ConvertOp convert_op; ConvertOp convert_op;
@ -186,26 +184,13 @@ __global__ void Conv2dFprop(
int thread_k = k_start + n; int thread_k = k_start + n;
if (thread_k < problem_size.K) { if (thread_k < problem_size.K) {
if(alpha == ElementCompute()) { // use per-channel scale and bias ElementCompute c_ref = ElementCompute();
ElementCompute scale = tensor_scale.at({0, thread_k}); if (beta != ElementCompute()) {
ElementCompute bias = tensor_bias.at({0, thread_k}); c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}));
tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
scale * ElementCompute(accum[m][n]) + bias);
} }
else if(tensor_bias.good()) { // use per-channel bias
ElementCompute bias = tensor_bias.at({0, thread_k});
tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
alpha * ElementCompute(accum[m][n]) + bias);
}
else {
ElementCompute c_ref = ElementCompute();
if (beta != ElementCompute()) {
c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}));
}
tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
alpha * ElementCompute(accum[m][n]) + beta * c_ref); alpha * ElementCompute(accum[m][n]) + beta * c_ref);
}
} }
} }
} }
@ -1015,9 +1000,7 @@ Status Conv2dFprop(
TensorRef<ElementC, LayoutC> tensor_y_out, TensorRef<ElementC, LayoutC> tensor_y_out,
ElementCompute alpha, ElementCompute alpha,
ElementCompute beta, ElementCompute beta,
cudaStream_t stream = nullptr, cudaStream_t stream = nullptr) {
TensorRef<ElementCompute, layout::RowMajor> tensor_scale = TensorRef<ElementCompute, layout::RowMajor>(),
TensorRef<ElementCompute, layout::RowMajor> tensor_bias = TensorRef<ElementCompute, layout::RowMajor>() ) {
// //
// Blocking factors improve performance of reference implementation // Blocking factors improve performance of reference implementation
@ -1056,9 +1039,7 @@ Status Conv2dFprop(
tensor_y_in, tensor_y_in,
tensor_y_out, tensor_y_out,
alpha, alpha,
beta, beta
tensor_scale,
tensor_bias
); );
cudaError_t result = cudaPeekAtLastError(); cudaError_t result = cudaPeekAtLastError();
@ -1448,9 +1429,7 @@ Status Conv2d(
TensorRef<ElementC, LayoutC> tensor_D, TensorRef<ElementC, LayoutC> tensor_D,
ElementCompute alpha, ElementCompute alpha,
ElementCompute beta, ElementCompute beta,
cudaStream_t stream = nullptr, cudaStream_t stream = nullptr) {
TensorRef<ElementCompute, layout::RowMajor> tensor_scale = TensorRef<ElementCompute, layout::RowMajor>(),
TensorRef<ElementCompute, layout::RowMajor> tensor_bias = TensorRef<ElementCompute, layout::RowMajor>() ) {
switch (convolutional_operator) { switch (convolutional_operator) {
case conv::Operator::kFprop: case conv::Operator::kFprop:
@ -1461,7 +1440,7 @@ Status Conv2d(
ElementCompute, ElementCompute,
ElementAccumulator, ElementAccumulator,
ConvertOp, InnerProductOp ConvertOp, InnerProductOp
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream, tensor_scale, tensor_bias); >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
break; break;
case conv::Operator::kDgrad: case conv::Operator::kDgrad: