b2b bias vector support (#482)

* b2b bias vector support
* add files

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>

parent 86ce09aed1 · commit ec2b4fd85d
@@ -61,6 +61,29 @@ When applying the above constraint to convolutions, it is required that the 2nd
kernel doesn't have halos, so that the data used by each threadblock doesn't depend on any other
threadblock. Typically this requires that the 2nd convolution use a 1x1 filter without any padding.
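To make the constraint concrete, a halo-free 2nd-convolution problem might be set up as below. This is a minimal sketch, assuming the usual NHWC convention; the dimension values are illustrative and the `Conv2dProblemSize` overload should be checked against `cutlass/conv/conv2d_problem_size.h`:

```cpp
#include "cutlass/conv/conv2d_problem_size.h"

// Illustrative dimensions: the 2nd conv consumes the 1st conv's NHWC output.
int N = 32, P0 = 56, Q0 = 56, K0 = 64;  // output extent of the 1st conv
int K1 = 128;                           // output channels of the 2nd conv

cutlass::conv::Conv2dProblemSize problem_size_1(
    {N, P0, Q0, K0},   // input size  (N, H, W, C)
    {K1, 1, 1, K0},    // filter size (K, R, S, C): 1x1 spatial footprint
    {0, 0, 0, 0},      // zero padding -> no halo between threadblocks
    {1, 1},            // unit stride
    {1, 1});           // unit dilation
```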

# Build and run

- Run cmake at top-level CUTLASS
- `make 13_two_tensor_op_fusion`
- Run individual benchmarks
  - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm75_rf`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm75_shmem`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm80_rf`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_f16_sm80_shmem`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm75_rf`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm75_shmem`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm80_rf`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_convs_s8_sm80_shmem`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm75_rf`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm75_shmem`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm80_rf`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_f16_sm80_shmem`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_rf`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_shmem`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_rf`
  - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_shmem`

# Copyright

Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
@@ -54,6 +54,7 @@
#include "cutlass/core_io.h"
#include "cutlass/util/tensor_view_io.h"

#include "reference/device/tensor_scale_bias.h"
#include "helper.h"

#define CHECK_GT(val1, val2) \
@@ -153,6 +154,7 @@ public:
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      std::cerr << "Not implemented\n";
    }
  }

@@ -407,6 +409,7 @@ public:
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
  cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
  cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
  cutlass::HostTensor<ElementAccumulator, typename B2bConv2d::LayoutC> tensor_Z0_reference;
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;

  cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
@@ -487,6 +490,7 @@ public:
    if(alpha0 == ElementCompute(0)) //per-channel scale
      tensor_Scale0.resize({1, problem_size_0.K});
    tensor_Bias0.resize({1, problem_size_0.K});
    tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
    tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
    tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
    tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1));
@@ -607,22 +611,35 @@ public:
      typename B2bConv2d::LayoutA,
      typename B2bConv2d::ElementB,
      typename B2bConv2d::LayoutB,
      typename B2bConv2d::ElementC,
      ElementAccumulator,
      typename B2bConv2d::LayoutC,
      ElementCompute,
      ElementAccumulator,
      ElementAccumulator
    >(
      kConvolutionalOperator,
      problem_size_0,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      tensor_C0.device_ref(),
      tensor_Z0_reference.device_ref(),
      tensor_Z0_reference.device_ref(),
      ElementAccumulator(1), // intermediate alpha = 1
      ElementAccumulator(0) // beta = 0
    );

    cutlass::reference::device::TensorScaleBiasConv2d<
      ElementAccumulator,
      typename B2bConv2d::ElementC,
      typename B2bConv2d::LayoutC,
      ElementCompute,
      typename B2bConv2d::LayoutScaleBias
    >(
      problem_size_0,
      tensor_Z0_reference.device_ref(),
      tensor_D0_reference.device_ref(),
      alpha0,
      beta0,
      nullptr, // stream
      tensor_Scale0.device_ref(),
      tensor_Bias0.device_ref());
      tensor_Bias0.device_ref()
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
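Note how the reference path above splits the fused epilogue into two stages: a plain convolution with alpha = 1, beta = 0 that produces the raw accumulator `tensor_Z0_reference`, then a `TensorScaleBiasConv2d` pass. A scalar sketch of the second stage's per-channel math follows — this is my paraphrase of the test's `alpha0 == 0` convention, not code from this commit:

```cpp
// alpha0 == 0 is the test's sentinel for "use the per-channel scale vector".
// For output channel k:
//   D0[n,p,q,k] = scale0[k] * Z0[n,p,q,k] + bias0[k]   (per-channel, alpha0 == 0)
//   D0[n,p,q,k] = alpha0    * Z0[n,p,q,k] + bias0[k]   (scalar alpha otherwise)
float apply_scale_bias(float z, float alpha0, float scale_k, float bias_k) {
  float scale = (alpha0 == 0.0f) ? scale_k : alpha0;
  return scale * z + bias_k;
}
```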

@@ -44,6 +44,7 @@
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"

#include "reference/device/tensor_scale_bias.h"
#include "helper.h"

#define CHECK_GT(val1, val2) \
@@ -68,6 +69,7 @@ struct B2bNonFusedGemmRun
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  cutlass::Distribution::Kind init_Bias;
  uint64_t seed;

  //
@@ -78,9 +80,10 @@ struct B2bNonFusedGemmRun
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
    init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
@@ -107,8 +110,13 @@ struct B2bNonFusedGemmRun
      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity());
    }
    else if (dist_kind == cutlass::Distribution::AllZeros) {
      cutlass::reference::host::TensorFill(view, Element(0));
    }
    else if (dist_kind == cutlass::Distribution::AllOnes) {
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      // TODO: Implement the rest
      std::cerr << "Not implemented\n";
      return false;
    }
@@ -147,6 +155,10 @@ struct B2bNonFusedGemmRun
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());

    cutlass::HostTensor<
      ElementCompute,
      typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});

    cutlass::HostTensor<
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
@@ -163,6 +175,10 @@ struct B2bNonFusedGemmRun
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());

    cutlass::HostTensor<
      ElementCompute,
      typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
@@ -175,8 +191,10 @@ struct B2bNonFusedGemmRun
    CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
    CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
    CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
    CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
    CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
    CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
    CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));

    cutlass::reference::host::TensorFill(
      tensor_D0.host_view());
@@ -190,9 +208,11 @@ struct B2bNonFusedGemmRun
    tensor_A0.sync_device();
    tensor_B0.sync_device();
    tensor_C0.sync_device();
    tensor_Bias0.sync_device();
    tensor_D0.sync_device();
    tensor_B1.sync_device();
    tensor_C1.sync_device();
    tensor_Bias1.sync_device();
    tensor_D1.sync_device();
    reference_D0.sync_device();
    reference_D1.sync_device();
@@ -205,7 +225,7 @@ struct B2bNonFusedGemmRun
      problem_size_0,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      tensor_C0.device_ref(),
      {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
      tensor_D0.device_ref(),
      {alpha0, beta0}
    };
@@ -214,7 +234,7 @@ struct B2bNonFusedGemmRun
      problem_size_1,
      tensor_D0.device_ref(),
      tensor_B1.device_ref(),
      tensor_C1.device_ref(),
      {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
      tensor_D1.device_ref(),
      {alpha1, beta1}
    };
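The `{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}` argument above is the heart of the bias-vector support: a `TensorRef` with a zero row stride makes every output row read the same bias row. A standalone sketch of the idea, with illustrative types:

```cpp
#include "cutlass/layout/matrix.h"
#include "cutlass/tensor_ref.h"

// bias points at n elements on the device, one per output column (assumed
// allocated elsewhere). A RowMajor layout with leading dimension 0 maps
// (row, col) -> col for every row, so the epilogue's "C" operand becomes a
// broadcast bias row instead of a full matrix.
float *bias = nullptr;  // assumption: device allocation of size n
cutlass::layout::RowMajor broadcast_layout(0);
cutlass::TensorRef<float, cutlass::layout::RowMajor> ref_bias(bias, broadcast_layout);
```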
@@ -241,7 +261,6 @@ struct B2bNonFusedGemmRun
    //
    // Run the GEMM
    //

    cudaEvent_t start, stop1, stop2;
    cudaEventCreate(&start);
    cudaEventCreate(&stop1);
@@ -256,7 +275,6 @@ struct B2bNonFusedGemmRun
    }
    cudaEventRecord(stop1);
    for(int i = 0; i < runs; i++) {

      status = gemm_op_1();

      CUTLASS_CHECK(status);
@@ -298,7 +316,7 @@ struct B2bNonFusedGemmRun
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      beta0,
      tensor_C0.device_ref(),
      {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
      reference_D0.device_ref()
    );

@@ -312,7 +330,7 @@ struct B2bNonFusedGemmRun
      reference_D0.device_ref(),
      tensor_B1.device_ref(),
      beta1,
      tensor_C1.device_ref(),
      {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
      reference_D1.device_ref()
    );

@@ -325,7 +343,6 @@ struct B2bNonFusedGemmRun
    reference_D0.sync_host();
    reference_D1.sync_host();

    CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
@@ -349,13 +366,14 @@ struct B2bNonFusedGemmRun
        << "A0 =\n" << tensor_A0.host_view()
        << "\nB0 =\n" << tensor_B0.host_view()
        << "\nC0 =\n" << tensor_C0.host_view()
        << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
        << "\nD0 =\n" << tensor_D0.host_view()
        << "\nB1 =\n" << tensor_B1.host_view()
        << "\nC1 =\n" << tensor_C1.host_view()
        << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
        << "\n\nReference =\n" << reference_D1.host_view()
        << "\nComputed =\n" << tensor_D1.host_view();
    }

    return passed;
  }
};
@@ -372,6 +390,8 @@ struct B2bFusedGemmRun
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  cutlass::Distribution::Kind init_Scale;
  cutlass::Distribution::Kind init_Bias;
  uint64_t seed;

  //
@@ -382,9 +402,12 @@ struct B2bFusedGemmRun
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
    init_A(init_A_), init_B(init_B_), init_C(init_C_),
    init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
@@ -411,8 +434,13 @@ struct B2bFusedGemmRun
      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity());
    }
    else if (dist_kind == cutlass::Distribution::AllZeros) {
      cutlass::reference::host::TensorFill(view, Element(0));
    }
    else if (dist_kind == cutlass::Distribution::AllOnes) {
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      // TODO: Implement the rest
      std::cerr << "Not implemented\n";
      return false;
    }
@@ -451,6 +479,21 @@ struct B2bFusedGemmRun
      typename B2bGemm::ElementC,
      typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());

    cutlass::HostTensor<
      typename B2bGemm::ElementScaleBias,
      typename B2bGemm::LayoutScaleBias> tensor_Scale0;

    if(alpha0 == ElementCompute(0)) //per-channel scale
      tensor_Scale0.resize({1, problem_size_0.n()});

    cutlass::HostTensor<
      typename B2bGemm::ElementScaleBias,
      typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});

    cutlass::HostTensor<
      ElementAccumulator,
      typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
      typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
@@ -463,6 +506,10 @@ struct B2bFusedGemmRun
      typename B2bGemm::ElementC,
      typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());

    cutlass::HostTensor<
      ElementCompute,
      typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
      typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
@@ -475,8 +522,12 @@ struct B2bFusedGemmRun
    CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
    CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
    CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
    if(alpha0 == ElementCompute(0)) //per-channel scale
      CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014));
    CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013));
    CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
    CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
    CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));

    cutlass::reference::host::TensorFill(
      tensor_D1.host_view());
@@ -488,8 +539,12 @@ struct B2bFusedGemmRun
    tensor_A0.sync_device();
    tensor_B0.sync_device();
    tensor_C0.sync_device();
    if(alpha0 == ElementCompute(0)) //per-channel scale
      tensor_Scale0.sync_device();
    tensor_Bias0.sync_device();
    tensor_B1.sync_device();
    tensor_C1.sync_device();
    tensor_Bias1.sync_device();
    tensor_D1.sync_device();
    reference_D0.sync_device();
    reference_D1.sync_device();
@@ -504,8 +559,10 @@ struct B2bFusedGemmRun
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      tensor_C0.device_ref(),
      tensor_Scale0.device_ref(),
      tensor_Bias0.device_ref(),
      tensor_B1.device_ref(),
      tensor_C1.device_ref(),
      {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
      tensor_D1.device_ref(),
      {alpha0, beta0},
      {alpha1, beta1},
@@ -524,7 +581,6 @@ struct B2bFusedGemmRun
        << " ThreadblockShape1::kN = problem_size_1.N" << std::endl;
    }

    status = b2b_gemm_op.initialize(arguments);

    CUTLASS_CHECK(status);
@@ -561,21 +617,42 @@ struct B2bFusedGemmRun
    //
    // Verify
    //

    cutlass::reference::device::Gemm<
      typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
      typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
      ElementAccumulator, typename B2bGemm::LayoutC,
      ElementAccumulator, ElementAccumulator>
      reference_gemm_0;

    cutlass::reference::device::Gemm<
      typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
      typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
      typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
      ElementAccumulator, typename B2bGemm::Operator>
      reference_gemm_0, reference_gemm_1;
      reference_gemm_1;

    reference_gemm_0(
      problem_size_0,
      alpha0,
      ElementAccumulator(1), //intermediate alpha=1
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      beta0,
      tensor_C0.device_ref(),
      reference_D0.device_ref()
      ElementAccumulator(0), //beta = 0
      reference_Z0.device_ref(),
      reference_Z0.device_ref(),
      ElementAccumulator(0)
    );

    cutlass::reference::device::TensorScaleBiasGemm<
      ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
      ElementCompute, typename B2bGemm::LayoutScaleBias
    > (
      problem_size_0,
      reference_Z0.device_ref(),
      reference_D0.device_ref(),
      alpha0,
      tensor_Scale0.device_ref(),
      tensor_Bias0.device_ref()
    );

    if(relu) {
@@ -588,19 +665,16 @@ struct B2bFusedGemmRun
      reference_D0.device_ref(),
      tensor_B1.device_ref(),
      beta1,
      tensor_C1.device_ref(),
      {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
      reference_D1.device_ref()
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D1.device_view());
    }
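Putting the verification steps together, the fused kernel is checked against this chain — a summary sketch of the code above, not part of the commit itself:

```cpp
// Z0 = A0 * B0                        reference GEMM, alpha = 1, beta = 0
// D0 = scale0 .* Z0 + bias0           TensorScaleBiasGemm; scale0 is the
//                                     per-column vector when alpha0 == 0,
//                                     otherwise the scalar alpha0
// D0 = relu(D0)                       optional intermediate activation
// D1 = alpha1 * D0 * B1 + bias1       bias1 broadcast via the stride-0 C ref
// D1 = relu(D1)                       optional final activation
```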

    cudaDeviceSynchronize();
    reference_D0.sync_host();
    reference_D1.sync_host();

    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
@@ -610,7 +684,8 @@ struct B2bFusedGemmRun
      tensor_D1.host_view());

    CHECK_TRUE(passed);
    if (!passed) {
    if (!passed)
    {

      std::stringstream fname;

@@ -623,12 +698,14 @@ struct B2bFusedGemmRun
        << "A0 =\n" << tensor_A0.host_view()
        << "\nB0 =\n" << tensor_B0.host_view()
        << "\nC0 =\n" << tensor_C0.host_view()
        << "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
        << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
        << "\nB1 =\n" << tensor_B1.host_view()
        << "\nC1 =\n" << tensor_C1.host_view()
        << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
        << "\n\nReference =\n" << reference_D1.host_view()
        << "\nComputed =\n" << tensor_D1.host_view();
    }

    return passed;
  }

@@ -55,6 +55,7 @@
#include "cutlass/core_io.h"
#include "cutlass/util/tensor_view_io.h"

#include "reference/device/tensor_scale_bias.h"
#include "helper.h"

#define CHECK_GT(val1, val2) \
@@ -91,14 +92,14 @@ public:
  cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0;
  cutlass::HostTensor<typename Conv2d0::ElementB, typename Conv2d0::LayoutB> tensor_B0_reordered;
  cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_C0;
  cutlass::HostTensor<typename Conv2d0::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias0;
  cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_Bias0;
  cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_computed;
  cutlass::HostTensor<typename Conv2d0::ElementC, typename Conv2d0::LayoutC> tensor_D0_reference;

  cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1;
  cutlass::HostTensor<typename Conv2d1::ElementB, typename Conv2d1::LayoutB> tensor_B1_reordered;
  cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_C1;
  cutlass::HostTensor<typename Conv2d1::ElementCompute, typename Conv2d0::LayoutC> tensor_Bias1;
  cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d0::LayoutC> tensor_Bias1;
  cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_computed;
  cutlass::HostTensor<typename Conv2d1::ElementC, typename Conv2d1::LayoutC> tensor_D1_reference;

@@ -379,11 +380,13 @@ public:
      << "\nB0:\n" << tensor_B0.host_view() << "\n"
      << "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n"
      << "\nC0:\n" << tensor_C0.host_view() << "\n"
      << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
      << "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n"
      << "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n"
      << "\nB1:\n" << tensor_B1.host_view() << "\n"
      << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n"
      << "\nC1:\n" << tensor_C1.host_view() << "\n"
      << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
      << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
      << "\nD1 computed:\n" << tensor_D1_computed.host_view();

@@ -421,12 +424,13 @@ public:
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C0;
  cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Scale0;
  cutlass::HostTensor<typename B2bConv2d::ElementScaleBias, typename B2bConv2d::LayoutScaleBias> tensor_Bias0;
  cutlass::HostTensor<ElementAccumulator, typename B2bConv2d::LayoutC> tensor_Z0_reference;
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D0_reference;

  cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1;
  cutlass::HostTensor<typename B2bConv2d::ElementB, typename B2bConv2d::LayoutB> tensor_B1_reordered;
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_C1;
  cutlass::HostTensor<typename B2bConv2d::ElementCompute, typename B2bConv2d::LayoutC> tensor_Bias1;
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_Bias1;
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_computed;
  cutlass::HostTensor<typename B2bConv2d::ElementC, typename B2bConv2d::LayoutC> tensor_D1_reference;

@@ -503,6 +507,7 @@ public:
    if(alpha0 == ElementCompute(0)) //per-channel scale
      tensor_Scale0.resize({1, problem_size_0.K});
    tensor_Bias0.resize({1, problem_size_0.K});
    tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
    tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0));
    tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
    tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1));
@@ -632,23 +637,36 @@ public:
      typename B2bConv2d::LayoutA,
      typename B2bConv2d::ElementB,
      typename B2bConv2d::LayoutB,
      typename B2bConv2d::ElementC,
      typename B2bConv2d::LayoutC,
      ElementCompute,
      ElementAccumulator,
      cutlass::NumericConverterClamp<typename B2bConv2d::ElementC, ElementCompute>
      typename B2bConv2d::LayoutC,
      ElementAccumulator,
      ElementAccumulator
    >(
      kConvolutionalOperator,
      problem_size_0,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      tensor_C0.device_ref(),
      tensor_Z0_reference.device_ref(),
      tensor_Z0_reference.device_ref(),
      ElementAccumulator(1), // intermediate alpha = 1
      ElementAccumulator(0) // beta = 0
    );

    cutlass::reference::device::TensorScaleBiasConv2d<
      ElementAccumulator,
      typename B2bConv2d::ElementC,
      typename B2bConv2d::LayoutC,
      ElementCompute,
      typename B2bConv2d::LayoutScaleBias,
      cutlass::NumericConverterClamp<typename B2bConv2d::ElementC, ElementCompute>
    >(
      problem_size_0,
      tensor_Z0_reference.device_ref(),
      tensor_D0_reference.device_ref(),
      alpha0,
      beta0,
      nullptr, // stream
      tensor_Scale0.device_ref(),
      tensor_Bias0.device_ref());
      tensor_Bias0.device_ref()
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view());
@@ -716,6 +734,7 @@ public:
      << "\nB1:\n" << tensor_B1.host_view() << "\n"
      << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n"
      << "\nC1:\n" << tensor_C1.host_view() << "\n"
      << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
      << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n"
      << "\nD1 computed:\n" << tensor_D1_computed.host_view();
@@ -28,7 +28,6 @@
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include <iostream>
@@ -46,6 +45,7 @@
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"

#include "reference/device/tensor_scale_bias.h"
#include "helper.h"

#define CHECK_GT(val1, val2) \
@@ -68,6 +68,7 @@ struct B2bInterleavedNonFusedGemmRun
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  cutlass::Distribution::Kind init_Bias;
  uint64_t seed;

  //
@@ -78,9 +79,10 @@ struct B2bInterleavedNonFusedGemmRun
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
    init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
@@ -98,13 +100,22 @@ struct B2bInterleavedNonFusedGemmRun

      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {

      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {

      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity());
    }
    else if (dist_kind == cutlass::Distribution::AllZeros) {
      cutlass::reference::host::TensorFill(view, Element(0));
    }
    else if (dist_kind == cutlass::Distribution::AllOnes) {
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      // TODO: Implement the rest
      std::cerr << "Not implemented\n";
      return false;
    }
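The two new `Distribution` cases make constant fills available for every operand, which is handy when debugging a mismatch. A hypothetical construction (the template arguments and exact instantiation are illustrative, not from this commit; only the constructor order follows the struct above):

```cpp
// With A = B = all-ones and C = zeros, each output element reduces to the
// K-dimension sum plus bias, so wrong bias handling is immediately visible
// in the tensors printed on failure.
using Runner = B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32>;  // hypothetical instantiation
Runner run(
    cutlass::Distribution::AllOnes,   // init_A
    cutlass::Distribution::AllOnes,   // init_B
    cutlass::Distribution::AllZeros,  // init_C
    cutlass::Distribution::AllOnes);  // init_Bias
```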

@@ -147,6 +158,10 @@ struct B2bInterleavedNonFusedGemmRun
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());

    cutlass::HostTensor<
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});

    cutlass::HostTensor<
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
@@ -167,6 +182,10 @@ struct B2bInterleavedNonFusedGemmRun
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());

    cutlass::HostTensor<
      typename Gemm0::ElementC,
      typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
@@ -179,8 +198,10 @@ struct B2bInterleavedNonFusedGemmRun
    CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
    CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
    CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
    CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
    CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
    CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
    CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));

    //Reorder B0 and B1
    cutlass::reorder_column<InterleavedK_>(
@@ -201,10 +222,12 @@ struct B2bInterleavedNonFusedGemmRun
    tensor_B0.sync_device();
    tensor_B0_reordered.sync_device();
    tensor_C0.sync_device();
    tensor_Bias0.sync_device();
    tensor_D0.sync_device();
    tensor_B1.sync_device();
    tensor_B1_reordered.sync_device();
    tensor_C1.sync_device();
    tensor_Bias1.sync_device();
    tensor_D1.sync_device();
    reference_D0.sync_device();
    reference_D1.sync_device();
@@ -217,7 +240,7 @@ struct B2bInterleavedNonFusedGemmRun
      problem_size_0,
      tensor_A0.device_ref(),
      tensor_B0_reordered.device_ref(),
      tensor_C0.device_ref(),
      {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
      tensor_D0.device_ref(),
      {alpha0, beta0}
    };
@@ -226,7 +249,7 @@ struct B2bInterleavedNonFusedGemmRun
      problem_size_1,
      tensor_D0.device_ref(),
      tensor_B1_reordered.device_ref(),
      tensor_C1.device_ref(),
      {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
      tensor_D1.device_ref(),
      {alpha1, beta1}
    };
@@ -266,7 +289,6 @@ struct B2bInterleavedNonFusedGemmRun
      CUTLASS_CHECK(status);
    }
    cudaEventRecord(stop1);

    for(int i = 0; i < runs; i++) {
      status = gemm_op_1();

@@ -286,7 +308,6 @@ struct B2bInterleavedNonFusedGemmRun
    tensor_D0.sync_host();
    tensor_D1.sync_host();

    bool passed = false;
    //
    // Verify
    //
@@ -310,7 +331,7 @@ struct B2bInterleavedNonFusedGemmRun
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      beta0,
      tensor_C0.device_ref(),
      {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
      reference_D0.device_ref()
    );

@@ -324,7 +345,7 @@ struct B2bInterleavedNonFusedGemmRun
      reference_D0.device_ref(),
      tensor_B1.device_ref(),
      beta1,
      tensor_C1.device_ref(),
      {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
      reference_D1.device_ref()
    );

@@ -332,6 +353,7 @@ struct B2bInterleavedNonFusedGemmRun
      cutlass::reference::device::TensorReLu(reference_D1.device_view());
    }

    // Wait for kernels to finish
    cudaDeviceSynchronize();
    reference_D0.sync_host();
    reference_D1.sync_host();
@@ -341,7 +363,7 @@ struct B2bInterleavedNonFusedGemmRun
    CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);

    passed = cutlass::reference::host::TensorEquals(
    bool passed = cutlass::reference::host::TensorEquals(
      reference_D1.host_view(),
      tensor_D1.host_view());

@@ -360,10 +382,12 @@ struct B2bInterleavedNonFusedGemmRun
        << "\nB0 =\n" << tensor_B0.host_view()
        << "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
        << "\nC0 =\n" << tensor_C0.host_view()
        << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
        << "\nD0 =\n" << tensor_D0.host_view()
        << "\nB1 =\n" << tensor_B1.host_view()
        << "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
        << "\nC1 =\n" << tensor_C1.host_view()
        << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
        << "\n\nReference =\n" << reference_D1.host_view()
        << "\nComputed =\n" << tensor_D1.host_view();
    }
@@ -383,6 +407,8 @@ struct B2bInterleavedFusedGemmRun
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  cutlass::Distribution::Kind init_Scale;
  cutlass::Distribution::Kind init_Bias;
  uint64_t seed;

  //
@@ -393,9 +419,12 @@ struct B2bInterleavedFusedGemmRun
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
    init_A(init_A_), init_B(init_B_), init_C(init_C_),
    init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
@@ -413,13 +442,22 @@ struct B2bInterleavedFusedGemmRun

      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {

      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {

      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity());
    }
    else if (dist_kind == cutlass::Distribution::AllZeros) {
      cutlass::reference::host::TensorFill(view, Element(0));
    }
    else if (dist_kind == cutlass::Distribution::AllOnes) {
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      // TODO: Implement the rest
      std::cerr << "Not implemented\n";
      return false;
    }
@@ -462,6 +500,21 @@ struct B2bInterleavedFusedGemmRun
      typename B2bGemm::ElementC,
      typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());

    cutlass::HostTensor<
      typename B2bGemm::ElementScaleBias,
      typename B2bGemm::LayoutScaleBias> tensor_Scale0;

    if(alpha0 == ElementCompute(0)) //per-channel scale
      tensor_Scale0.resize({1, problem_size_0.n()});

    cutlass::HostTensor<
      typename B2bGemm::ElementScaleBias,
      typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});

    cutlass::HostTensor<
      ElementAccumulator,
      typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
      typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
@@ -478,6 +531,10 @@ struct B2bInterleavedFusedGemmRun
      typename B2bGemm::ElementC,
      typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
      typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
      typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
@@ -490,8 +547,12 @@ struct B2bInterleavedFusedGemmRun
    CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
    CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
    CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
    if(alpha0 == ElementCompute(0)) //per-channel scale
      CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014));
    CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013));
    CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
    CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
    CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));

    //Reorder B0
    cutlass::reorder_column<16>(
@@ -510,9 +571,13 @@ struct B2bInterleavedFusedGemmRun
    tensor_B0.sync_device();
    tensor_B0_reordered.sync_device();
    tensor_C0.sync_device();
    if(alpha0 == ElementCompute(0)) //per-channel scale
      tensor_Scale0.sync_device();
    tensor_Bias0.sync_device();
    tensor_B1.sync_device();
    tensor_B1_reordered.sync_device();
    tensor_C1.sync_device();
    tensor_Bias1.sync_device();
    tensor_D1.sync_device();
    reference_D0.sync_device();
    reference_D1.sync_device();
@@ -527,12 +592,13 @@ struct B2bInterleavedFusedGemmRun
      tensor_A0.device_ref(),
      tensor_B0_reordered.device_ref(),
      tensor_C0.device_ref(),
      tensor_Scale0.device_ref(),
      tensor_Bias0.device_ref(),
      tensor_B1_reordered.device_ref(),
      tensor_C1.device_ref(),
      {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
      tensor_D1.device_ref(),
      {alpha0, beta0},
      {alpha1, beta1},
      1, /*threadblock_swizzle_k_tile*/
    };

    B2bGemm b2b_gemm_op;
@@ -581,25 +647,45 @@ struct B2bInterleavedFusedGemmRun

    tensor_D1.sync_host();

    bool passed = false;
    //
    // Verify
    //

    cutlass::reference::device::Gemm<
      typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
      typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
      ElementAccumulator, typename B2bGemm::LayoutC,
      ElementAccumulator, ElementAccumulator>
      reference_gemm_0;

    cutlass::reference::device::Gemm<
      typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
      typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
      typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
      ElementAccumulator, typename B2bGemm::Operator>
      reference_gemm_0, reference_gemm_1;
      reference_gemm_1;

    reference_gemm_0(
      problem_size_0,
      alpha0,
      ElementAccumulator(1), //intermediate alpha=1
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      beta0,
      tensor_C0.device_ref(),
      reference_D0.device_ref()
      ElementAccumulator(0), //beta = 0
      reference_Z0.device_ref(),
      reference_Z0.device_ref(),
      ElementAccumulator(0)
    );

    cutlass::reference::device::TensorScaleBiasGemm<
      ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
      ElementCompute, typename B2bGemm::LayoutScaleBias
    > (
      problem_size_0,
      reference_Z0.device_ref(),
      reference_D0.device_ref(),
      alpha0,
      tensor_Scale0.device_ref(),
      tensor_Bias0.device_ref()
    );

    if(relu) {
@@ -612,15 +698,12 @@ struct B2bInterleavedFusedGemmRun
      reference_D0.device_ref(),
      tensor_B1.device_ref(),
      beta1,
      tensor_C1.device_ref(),
      {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
      reference_D1.device_ref()
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D1.device_view());
    }

    cudaDeviceSynchronize();
    reference_D0.sync_host();
    reference_D1.sync_host();
@@ -629,12 +712,13 @@ struct B2bInterleavedFusedGemmRun
    CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);

    passed = cutlass::reference::host::TensorEquals(
    bool passed = cutlass::reference::host::TensorEquals(
      reference_D1.host_view(),
      tensor_D1.host_view());

    CHECK_TRUE(passed);
    if (!passed) {
    if (!passed)
    {

      std::stringstream fname;

@@ -648,9 +732,12 @@ struct B2bInterleavedFusedGemmRun
        << "\nB0 =\n" << tensor_B0.host_view()
        << "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
        << "\nC0 =\n" << tensor_C0.host_view()
        << "\nScale0:\n" << tensor_Scale0.host_view() << "\n"
        << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
        << "\nB1 =\n" << tensor_B1.host_view()
        << "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
        << "\nC1 =\n" << tensor_C1.host_view()
        << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
        << "\n\nReference =\n" << reference_D1.host_view()
        << "\nComputed =\n" << tensor_D1.host_view();
    }
@@ -158,6 +158,10 @@ class B2bGemm {
  static ComplexTransform const kTransformA = ComplexTransform::kNone;
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

  /// Derived types
  using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute;
  using LayoutScaleBias = layout::RowMajor;

  /// Define the kernel
  using B2bGemmKernel = typename kernel::DefaultB2bGemm<
    ElementA,
@@ -197,6 +201,8 @@ class B2bGemm {
    TensorRef<ElementA const, LayoutA> ref_A0;
    TensorRef<ElementB const, LayoutB> ref_B0;
    TensorRef<ElementC const, LayoutC> ref_C0;
    TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0;
    TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0;
    TensorRef<ElementB const, LayoutB> ref_B1;
    TensorRef<ElementC const, LayoutC> ref_C1;
    TensorRef<ElementC, LayoutC> ref_D1;
@@ -222,6 +228,8 @@ class B2bGemm {
      TensorRef<ElementA const, LayoutA> ref_A0_,
      TensorRef<ElementB const, LayoutB> ref_B0_,
      TensorRef<ElementC const, LayoutC> ref_C0_,
      TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Scale0_,
      TensorRef<ElementScaleBias const, LayoutScaleBias> ref_Bias0_,
      TensorRef<ElementB const, LayoutB> ref_B1_,
      TensorRef<ElementC const, LayoutC> ref_C1_,
      TensorRef<ElementC, LayoutC> ref_D1_,
@@ -236,6 +244,8 @@ class B2bGemm {
      ref_A0(ref_A0_),
      ref_B0(ref_B0_),
      ref_C0(ref_C0_),
      ref_Scale0(ref_Scale0_),
      ref_Bias0(ref_Bias0_),
      ref_B1(ref_B1_),
      ref_C1(ref_C1_),
      ref_D1(ref_D1_),
@@ -348,6 +358,8 @@ public:
      args.ref_A0.non_const_ref(),
      args.ref_B0.non_const_ref(),
      args.ref_C0.non_const_ref(),
      args.ref_Scale0.non_const_ref(),
      args.ref_Bias0.non_const_ref(),
      args.ref_B1.non_const_ref(),
      args.ref_C1.non_const_ref(),
      args.ref_D1,
@@ -368,12 +380,14 @@ public:
      }
    }

    params_.ref_A0.reset(args.ref_A.non_const_ref().data());
    params_.ref_B0.reset(args.ref_B.non_const_ref().data());
    params_.ref_C0.reset(args.ref_C.non_const_ref().data());
    params_.ref_B1.reset(args.ref_B.non_const_ref().data());
    params_.ref_C1.reset(args.ref_C.non_const_ref().data());
    params_.ref_D1.reset(args.ref_D.data());
    params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
    params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
    params_.ref_C0.reset(args.ref_C0.non_const_ref().data());
    params_.ref_Scale0.reset(args.ref_Scale0.non_const_ref().data());
    params_.ref_Bias0.reset(args.ref_Bias0.non_const_ref().data());
    params_.ref_B1.reset(args.ref_B1.non_const_ref().data());
    params_.ref_C1.reset(args.ref_C1.non_const_ref().data());
    params_.ref_D1.reset(args.ref_D1.data());
    params_.output_op_0 = args.epilogue0;
    params_.output_op_1 = args.epilogue1;
    params_.semaphore = static_cast<int *>(workspace);
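On the host side, the extended `Arguments` is populated in the same order as the constructor above. A sketch assembled from the test code earlier in this commit — the `tensor_*` objects come from the test harness, the leading problem-size fields are an assumption not shown in the diff, and `tensor_Scale0` may remain unallocated when a scalar `alpha0` is used:

```cpp
typename B2bGemm::Arguments arguments{
    problem_size_0,
    problem_size_1,
    tensor_A0.device_ref(),
    tensor_B0.device_ref(),
    tensor_C0.device_ref(),
    tensor_Scale0.device_ref(),  // empty unless alpha0 == 0 (per-channel scale)
    tensor_Bias0.device_ref(),
    tensor_B1.device_ref(),
    {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},  // ref_C1: broadcast bias1
    tensor_D1.device_ref(),
    {alpha0, beta0},
    {alpha1, beta1}};
```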

@@ -68,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
  using ElementCompute = cutlass::half_t;

  ElementCompute alpha0 = ElementCompute(1);
  ElementCompute beta0 = ElementCompute(0);
  ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
  ElementCompute alpha1 = ElementCompute(1);
  ElementCompute beta1 = ElementCompute(1); //use beta for bias
  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
  using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
  using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

  using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
@@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
    128 / cutlass::sizeof_bits<ElementC>::value,
    ElementAccumulator,
    ElementCompute,
    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
    cutlass::epilogue::thread::ScaleType::NoBetaScaling
  >,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
  2,
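The `OnlyAlphaScaling` → `NoBetaScaling` switch above is what turns operand C into a pure bias input. My reading of the two modes, worth verifying against `cutlass/epilogue/thread/linear_combination.h`:

```cpp
// Default scaling          : D = alpha * accum + beta * C
// ScaleType::NoBetaScaling : D = alpha * accum + C   (beta implicitly 1)
// Feeding the stride-0 bias ref as C therefore adds the bias row verbatim,
// which matches the "beta=1 for bias" comments in these tests.
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
    ElementC,
    128 / cutlass::sizeof_bits<ElementC>::value,
    ElementAccumulator,
    ElementCompute,
    cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
```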
|
||||
@ -151,14 +151,15 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75_rf_res() {
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
//Fused kernel has built-in bias, setting beta=0
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1); //use beta for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
|
||||
@ -68,13 +68,13 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
@ -118,7 +118,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm75() {
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2,
|
||||
@ -151,9 +151,10 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75_shmem() {
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
//Fused kernel has built-in bias, setting beta=0
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
@ -176,7 +177,7 @@ bool run_fused_conv2d_fprop_optimized_f16_sm75_shmem() {
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
|
||||
|
||||
@ -69,14 +69,14 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
|
||||
@ -94,7 +94,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
@ -118,7 +118,8 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
@ -150,9 +151,10 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80_rf_res() {
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
//Fused kernel has built-in bias, setting beta=0
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
|
||||
@ -174,7 +176,8 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80_rf_res() {
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
using B2bConv2dFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
|
||||
|
||||
@ -69,13 +69,13 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
|
||||
@ -94,7 +94,7 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
@ -118,7 +118,8 @@ bool run_nonfused_conv2d_fprop_optimized_f16_sm80() {
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
3,
|
||||
@ -151,9 +152,10 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80_shmem() {
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
//Fused kernel has built-in bias, setting beta=0
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(0);
|
||||
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
|
||||
@ -175,7 +177,8 @@ bool run_fused_conv2d_fprop_optimized_f16_sm80_shmem() {
|
||||
ElementC,
|
||||
128 / cutlass::sizeof_bits<ElementC>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
@ -68,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(0);
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2,
@ -117,7 +117,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
ElementC,
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2,
@ -151,14 +152,15 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

using EpilogueOutputOp0 =
@ -175,7 +177,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() {
ElementC,
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

@ -68,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(0);
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2,
@ -117,7 +117,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm75() {
ElementC,
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2,
@ -150,9 +151,10 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() {
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@ -174,7 +176,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() {
ElementC,
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

@ -68,14 +68,14 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(0);
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(0);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

using Conv2dFpropKernel0 = typename cutlass::conv::kernel::DefaultConv2dFprop<
@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
@ -117,7 +117,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
ElementC,
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
@ -151,14 +152,15 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80_rf_res() {
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(0);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

using EpilogueOutputOp0 =
@ -175,7 +177,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80_rf_res() {
ElementC,
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

@ -68,13 +68,13 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(0);
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(0);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

@ -93,7 +93,7 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
@ -117,7 +117,8 @@ bool run_nonfused_conv2d_fprop_optimized_s8_sm80() {
ElementC,
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
@ -150,9 +151,10 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80_shmem() {
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(0);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
@ -174,7 +176,8 @@ bool run_fused_conv2d_fprop_optimized_s8_sm80_shmem() {
ElementC,
64 / cutlass::sizeof_bits<ElementC>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

const bool SmemAccumulator = true;

@ -55,10 +55,10 @@ bool run_nonfused_gemm_f16() {
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@ -84,7 +84,7 @@ bool run_nonfused_gemm_f16() {
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
@ -106,7 +106,8 @@ bool run_nonfused_gemm_f16() {
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
@ -131,10 +132,11 @@ bool run_fused_gemm_f16_rf_res() {
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
@ -156,7 +158,8 @@ bool run_fused_gemm_f16_rf_res() {
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

using B2bGemm = cutlass::gemm::device::B2bGemm<

@ -55,14 +55,14 @@ bool run_nonfused_gemm_f16() {
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

@ -84,7 +84,7 @@ bool run_nonfused_gemm_f16() {
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
@ -106,7 +106,8 @@ bool run_nonfused_gemm_f16() {
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
@ -130,10 +131,11 @@ bool run_fused_gemm_f16_shmem() {
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@ -155,7 +157,8 @@ bool run_fused_gemm_f16_shmem() {
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

@ -55,15 +55,15 @@ bool run_nonfused_gemm_f16_sm80() {
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 32>;
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

using Gemm0 = cutlass::gemm::device::Gemm<
@ -84,7 +84,7 @@ bool run_nonfused_gemm_f16_sm80() {
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3
@ -106,7 +106,8 @@ bool run_nonfused_gemm_f16_sm80() {
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3
@ -130,15 +131,16 @@ bool run_fused_gemm_f16_sm80_rf_res() {
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

using EpilogueOutputOp0 =
@ -155,11 +157,10 @@ bool run_fused_gemm_f16_sm80_rf_res() {
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

using B2bGemm = cutlass::gemm::device::B2bGemm<
cutlass::half_t,
cutlass::layout::RowMajor,

@ -55,10 +55,10 @@ bool run_nonfused_gemm_f16_sm80() {
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@ -84,7 +84,7 @@ bool run_nonfused_gemm_f16_sm80() {
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3
@ -106,7 +106,8 @@ bool run_nonfused_gemm_f16_sm80() {
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3
@ -130,10 +131,11 @@ bool run_fused_gemm_f16_sm80_shmem() {
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@ -155,7 +157,8 @@ bool run_fused_gemm_f16_sm80_shmem() {
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

@ -55,10 +55,10 @@ bool run_nonfused_gemm_s8() {
using ElementAccumulator = int32_t;
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
@ -84,7 +84,7 @@ bool run_nonfused_gemm_s8() {
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
@ -106,7 +106,8 @@ bool run_nonfused_gemm_s8() {
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
@ -131,10 +132,11 @@ bool run_fused_gemm_s8_rf_res() {
using ElementAccumulator = int32_t;
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 32>;
@ -156,7 +158,8 @@ bool run_fused_gemm_s8_rf_res() {
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

using B2bGemm = cutlass::gemm::device::B2bGemm<
@ -200,7 +203,7 @@ int main() {
&run_fused_gemm_s8_rf_res
};

return testRun(75, funcs, "gemm f16 RF residency");
return testRun(75, funcs, "gemm int8 RF residency");

}

@ -55,15 +55,15 @@ bool run_nonfused_gemm_s8() {
using ElementAccumulator = int32_t;
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //beta = 1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta = 1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 256, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

using Gemm0 = cutlass::gemm::device::Gemm<
@ -84,7 +84,7 @@ bool run_nonfused_gemm_s8() {
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
@ -106,7 +106,8 @@ bool run_nonfused_gemm_s8() {
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
@ -130,10 +131,11 @@ bool run_fused_gemm_s8_shmem() {
using ElementAccumulator = int32_t;
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 32>;
@ -155,7 +157,8 @@ bool run_fused_gemm_s8_shmem() {
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

const bool SmemAccumulator = true;
@ -202,7 +205,7 @@ int main() {
&run_fused_gemm_s8_shmem
};

return testRun(75, funcs, "gemm s8 shmem staing");
return testRun(75, funcs, "gemm int8 shmem staing");

}

@ -55,15 +55,15 @@ bool run_nonfused_gemm_s8_sm80() {
using ElementAccumulator = int32_t;
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(0);
ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

using Gemm0 = cutlass::gemm::device::Gemm<
@ -84,7 +84,7 @@ bool run_nonfused_gemm_s8_sm80() {
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
@ -111,7 +111,7 @@ bool run_nonfused_gemm_s8_sm80() {
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
@ -140,15 +140,16 @@ bool run_fused_gemm_s8_sm80_rf_res() {
using ElementAccumulator = int32_t;
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

using EpilogueOutputOp0 =
@ -166,7 +167,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

const bool SmemAccumulator = false;

@ -55,14 +55,14 @@ bool run_nonfused_gemm_s8_sm80() {
using ElementAccumulator = int32_t;
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(0);
ElementCompute alpha0 = ElementCompute(1);
ElementCompute beta0 = ElementCompute(1); //beta=1 for bias
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

@ -84,7 +84,7 @@ bool run_nonfused_gemm_s8_sm80() {
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
@ -111,7 +111,7 @@ bool run_nonfused_gemm_s8_sm80() {
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3,
@ -139,10 +139,11 @@ bool run_fused_gemm_s8_sm80_shmem() {
using ElementAccumulator = int32_t;
using ElementCompute = float;

ElementCompute alpha0 = ElementCompute(2);
ElementCompute alpha0 = ElementCompute(1);
//Fused kernel has built-in bias, setting beta=0
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(1);
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias

using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
@ -165,7 +166,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>;

const bool SmemAccumulator = true;
@ -79,6 +79,8 @@ struct B2bGemm {
typename B2bMma::IteratorB0::TensorRef ref_B0;
typename Epilogue::OutputTileIterator::Params params_C0;
typename Epilogue::OutputTileIterator::TensorRef ref_C0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0;
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0;
typename B2bMma::IteratorB1::Params params_B1;
typename B2bMma::IteratorB1::TensorRef ref_B1;
typename Epilogue::OutputTileIterator::Params params_C1;
@ -109,6 +111,8 @@ struct B2bGemm {
typename B2bMma::IteratorA0::TensorRef ref_A0,
typename B2bMma::IteratorB0::TensorRef ref_B0,
typename Epilogue::OutputTileIterator::TensorRef ref_C0,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0,
typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0,
typename B2bMma::IteratorB1::TensorRef ref_B1,
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
@ -126,6 +130,8 @@ struct B2bGemm {
ref_B0(ref_B0),
params_C0(ref_C0.layout()),
ref_C0(ref_C0),
ref_Scale0(ref_Scale0),
ref_Bias0(ref_Bias0),
params_B1(ref_B1.layout()),
ref_B1(ref_B1),
params_C1(ref_C1.layout()),
@ -305,6 +311,29 @@ struct B2bGemm {
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;

// Construct iterators to accumulator scale/bias vector
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
params.ref_Scale0.data(),
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
MatrixCoord(
0, threadblock_tile_offset.n() * B2bMma::Shape0::kN
)
);

typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
params.ref_Bias0.data(),
{1, params.problem_size_0.n()},
thread_idx,
warp_idx,
MatrixCoord(
0, threadblock_tile_offset.n() * B2bMma::Shape0::kN
)
);

//
// Main loop
//
@ -322,7 +351,8 @@ struct B2bGemm {

if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
// Compute threadblock-scoped matrix multiply-add
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, iterator_B1, src_accum, output_op_0);
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
}

//
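
The two vector iterators above read `Scale0` and `Bias0` as 1 x N row vectors (extent `{1, params.problem_size_0.n()}`) indexed by the output column of the first GEMM. A minimal host-side sketch of the math they feed, with assumed names and row-major layout (not code from this commit):

```cpp
// Sketch only: per-channel scale/bias applied to the first GEMM's accumulators
// before they become the A operand of the second GEMM. Z0, scale0, bias0 and
// D0 are illustrative names, not identifiers from the library.
void reference_scale_bias(int M, int N,
                          float const* Z0,      // M x N accumulators of GEMM0
                          float const* scale0,  // length-N, one value per column
                          float const* bias0,   // length-N, one value per column
                          float* D0) {          // M x N scaled and biased output
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      D0[i * N + j] = scale0[j] * Z0[i * N + j] + bias0[j];
    }
  }
}
```
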
@ -338,7 +338,7 @@ struct DefaultB2bConv2dFprop <
cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kK>,
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
>;

@ -0,0 +1,275 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Defines device-side elementwise operations on TensorView. Note, the operations defined
in this header are not specialized for any particular data layout and are therefore not
intended to offer the best possible performance. Rather, they are intended to be generic
reference implementations to support the CUTLASS unit tests.
*/

#pragma once

// Cutlass includes
#include "cutlass/cutlass.h"
#include "cutlass/tensor_view.h"

#include "cutlass/gemm/gemm.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace reference {
namespace device {

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace kernel {

template <
typename TensorRefIn,        ///< Input TensorRef Type
typename TensorRefOut,       ///< Output TensorRef Type
typename ScalarType,         ///< alpha Type
typename TensorRefScalar,    ///< Scale/Bias TensorRef Type
typename OutputTile,
typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>
>
__global__ void TensorScaleBiasGemm(
gemm::GemmCoord problem_size,
TensorRefIn tensor_in,          ///< input tensor
TensorRefOut tensor_out,        ///< output tensor
ScalarType alpha,               ///< alpha
TensorRefScalar tensor_scale,   ///< scale tensor
TensorRefScalar tensor_bias     ///< bias tensor
) {

ConvertOp convert_op;

MatrixCoord output_coord(
MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow),
MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn)
);

// Update the output tensor
for (int j = 0; j < OutputTile::kRow; ++j) {
for (int i = 0; i < OutputTile::kColumn; ++i) {
MatrixCoord coord = output_coord + MatrixCoord(i, j);
if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {

ScalarType scale = alpha;
if(tensor_scale.good())
scale = tensor_scale.at({0, coord.column()});

ScalarType bias = ScalarType(0);

if(tensor_bias.good())
bias = tensor_bias.at({0, coord.column()});

tensor_out.at(coord) = convert_op(
scale * ScalarType(tensor_in.at(coord)) + bias);
}
}
}
}

template <
typename TensorRefIn,        ///< Input TensorRef Type
typename TensorRefOut,       ///< Output TensorRef Type
typename ScalarType,         ///< alpha Type
typename TensorRefScalar,    ///< Scale/Bias TensorRef Type
typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
int kThreadM = 4,       // shape of a thread's tile in the GEMM M dimension
int kThreadN = 4,       // shape of a thread's tile in the GEMM N dimension
int kCtaShapeM = 16,    // shape of a threadblock in units of threads
int kCtaShapeN = 8      // shape of a threadblock in units of threads
>
__global__ void TensorScaleBiasConv2d(
conv::Conv2dProblemSize problem_size,
TensorRefIn tensor_in,          ///< input tensor
TensorRefOut tensor_out,        ///< output tensor
ScalarType alpha,               ///< alpha
TensorRefScalar tensor_scale,   ///< scale tensor
TensorRefScalar tensor_bias     ///< bias tensor
) {

ConvertOp convert_op;

int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

int thread_n[kThreadM];
int thread_p[kThreadM];
int thread_q[kThreadM];

// Compute N, P, Q coordinates for each row of a thread's tile
int64_t PQ = int64_t(problem_size.P) * problem_size.Q;

CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {

int64_t npq = npq_start + m;

thread_n[m] = int(npq / PQ);

int64_t residual = npq % PQ;
thread_p[m] = int(residual / problem_size.Q);
thread_q[m] = int(residual % problem_size.Q);
}

// Write out the results
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < kThreadM; ++m) {
if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kThreadN; ++n) {
int thread_k = k_start + n;
if (thread_k < problem_size.K) {

ScalarType scale = alpha;
if(tensor_scale.good())
scale = tensor_scale.at({0, thread_k});

ScalarType bias = ScalarType(0);
if(tensor_bias.good())
bias = tensor_bias.at({0, thread_k});

tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
scale * ScalarType(
tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})
) + bias);
}
}
}
}

}

}

/// Apply scale and bias on a tensor
template <
typename ElementIn,          ///< Input Type
typename ElementOut,         ///< Output Type
typename Layout,             ///< Layout of input/output tensor
typename ScalarType,         ///< alpha Type
typename LayoutScaleBias,    ///< Layout of scale and bias
typename ConvertOp = NumericConverter<ElementOut, ScalarType>
>
void TensorScaleBiasGemm(
gemm::GemmCoord problem_size,
TensorRef<ElementIn, Layout> tensor_in,                   ///< input tensor
TensorRef<ElementOut, Layout> tensor_out,                 ///< output tensor
ScalarType alpha,                                         ///< alpha
TensorRef<ScalarType, LayoutScaleBias> tensor_scale,      ///< scale tensor
TensorRef<ScalarType, LayoutScaleBias> tensor_bias        ///< bias tensor
) {

using OutputTile = MatrixShape<4, 4>;

dim3 block(16, 8);

dim3 grid(
(problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
(problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn)
);

kernel::TensorScaleBiasGemm<
TensorRef<ElementIn, Layout>,
TensorRef<ElementOut, Layout>,
ScalarType,
TensorRef<ScalarType, LayoutScaleBias>,
OutputTile,
ConvertOp
><<< grid, block >>> (
problem_size,
tensor_in,
tensor_out,
alpha,
tensor_scale,
tensor_bias
);
}

/// Apply scale and bias on a tensor
template <
typename ElementIn,          ///< Input Type
typename ElementOut,         ///< Output Type
typename Layout,             ///< Layout of input/output tensor
typename ScalarType,         ///< alpha Type
typename LayoutScaleBias,    ///< Layout of scale and bias
typename ConvertOp = NumericConverter<ElementOut, ScalarType>
>
void TensorScaleBiasConv2d(
conv::Conv2dProblemSize problem_size,
TensorRef<ElementIn, Layout> tensor_in,                   ///< input tensor
TensorRef<ElementOut, Layout> tensor_out,                 ///< output tensor
ScalarType alpha,                                         ///< alpha
TensorRef<ScalarType, LayoutScaleBias> tensor_scale,      ///< scale tensor
TensorRef<ScalarType, LayoutScaleBias> tensor_bias        ///< bias tensor
) {

int const kThreadM = 4;       // shape of a thread's tile in the GEMM M dimension
int const kThreadN = 4;       // shape of a thread's tile in the GEMM N dimension
int const kCtaShapeM = 16;    // shape of a threadblock in units of threads
int const kCtaShapeN = 8;     // shape of a threadblock in units of threads

int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q;
int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM);

dim3 block(kCtaShapeM, kCtaShapeN);
dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN));

kernel::TensorScaleBiasConv2d<
TensorRef<ElementIn, Layout>,
TensorRef<ElementOut, Layout>,
ScalarType,
TensorRef<ScalarType, LayoutScaleBias>,
ConvertOp,
kThreadM,
kThreadN,
kCtaShapeM,
kCtaShapeN
><<< grid, block >>> (
problem_size,
tensor_in,
tensor_out,
alpha,
tensor_scale,
tensor_bias
);
}

///////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace reference
} // namespace cutlass
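
A plausible call site for the helper above (the tensor variables are assumptions for illustration, standing in for `HostTensor` objects sized to the first GEMM's output): passing an empty `tensor_scale` ref makes the kernel fall back to the scalar `alpha`, per the `tensor_scale.good()` check.

```cpp
// Hypothetical usage of the new reference helper, not code from this commit.
cutlass::reference::device::TensorScaleBiasGemm(
    problem_size_0,
    tensor_Z0_reference.device_ref(),  // input: raw GEMM0 accumulators
    tensor_D0_reference.device_ref(),  // output: scale * in + bias, converted
    alpha0,                            // used when no per-channel scale is given
    tensor_Scale0.device_ref(),        // per-channel scale, or an empty ref
    tensor_Bias0.device_ref());        // per-channel bias, or an empty ref
```
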
@ -745,7 +745,6 @@ public:
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_B1_;

if (warp_mma_k > 0)
warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2],
warp_transformed_frag_B1[warp_mma_k % 2],

@ -82,6 +82,11 @@ template <
/// Iterates over the intermediate accumulator tile
// (concept::MmaTensorOpFragmentIterator)
typename FragmentIteratorA1_,
/// Iterates over vectors of scale and bias vector in global memory
// (concept: VectorIterator)
typename IteratorAccumulatorScaleBias_,
/// WarpIterator to load Scale or Bias vector from threadblock fragment
typename FragmentIteratorA1ScaleBias_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
@ -126,6 +131,10 @@ public:
using Shape1 = Shape1_;
///< Iterates over intermediate accumulator tile
using FragmentIteratorA1 = FragmentIteratorA1_;
///< Iterates over tiles of the scale and bias vectors in global memory
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_;
///< WarpIterator to load Scale or Bias vector from threadblock fragment
using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_;
///< Iterates over tiles of B operand in global memory
using IteratorB1 = IteratorB1_;
///< Policy describing tuning details
@ -141,6 +150,9 @@ public:
///< Epilogue after 1st Gemm
using OutputOp = OutputOp_;

static const bool PerChannelScale = (OutputOp::kScale ==
epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);

static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0;
static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0;
static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
@ -155,6 +167,9 @@ public:
/// Warp-level Mma
using Operator0 = typename Policy0::Operator;

/// Fragment of Scale and Bias loaded from global memory
using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;

/// Fragment of accumulator tile
using FragmentC1 = typename Policy1::Operator::FragmentC;

@ -217,6 +232,8 @@ public:
using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
/// Warp Fragment of operand A1 loaded from accmulator tile
using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment;
using WarpLoadedFragmentA1ScaleBias =
typename FragmentIteratorA1ScaleBias::Fragment;
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA;
using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
@ -381,11 +398,15 @@ public:
int gemm_k_iterations_0,
///< destination accumulator tile
FragmentC1 &accum,
///< iterator over A operand in global memory
///< iterator over A0 operand in global memory
IteratorA0 iterator_A0,
///< iterator over B operand in global memory
///< iterator over B0 operand in global memory
IteratorB0 iterator_B0,
///< iterator over B operand in global memory
///< iterator over A1 operand scale vector in global memory
IteratorAccumulatorScaleBias iterator_A1_scale,
///< iterator over A1 operand bias vector in global memory
IteratorAccumulatorScaleBias iterator_A1_bias,
///< iterator over B1 operand in global memory
IteratorB1 iterator_B1,
///< initial value of accumulator
FragmentC0 const &src_accum,
@ -623,6 +644,20 @@ public:

/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
FragmentA1ScaleBias tb_frag_A1_scale;
FragmentA1ScaleBias tb_frag_A1_bias;
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale);
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias);

if(PerChannelScale) {
tb_frag_A1_scale.clear();
iterator_A1_scale.load(tb_frag_A1_scale);
++iterator_A1_scale;
}
tb_frag_A1_bias.clear();
iterator_A1_bias.load(tb_frag_A1_bias);
++iterator_A1_bias;

//
// Prologue
@ -678,18 +713,29 @@ public:
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2];
WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2];
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];

Operator1 warp_mma1;

this->warp_tile_iterator_B1_.set_kgroup_index(0);

warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], output_op_0);
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
if(PerChannelScale) {
warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]);
++warp_tile_iterator_A1_scale_;
}
warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]);
++warp_tile_iterator_A1_bias_;

warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0],
warp_loaded_frag_A1_scale[0],
warp_loaded_frag_A1_bias[0],
output_op_0);
++warp_tile_iterator_A1_;

this->warp_tile_iterator_B1_.set_kgroup_index(0);
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]);
++this->warp_tile_iterator_B1_;

iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
@ -717,15 +763,37 @@ public:
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
++warp_mma_k) {

// Load threadblock-level scale/bias vector from global memory
if (warp_mma_k + 1 == Base::kWarpGemmIterations1) {
if(PerChannelScale) {
tb_frag_A1_scale.clear();
iterator_A1_scale.load(tb_frag_A1_scale);
++iterator_A1_scale;
}
tb_frag_A1_bias.clear();
iterator_A1_bias.load(tb_frag_A1_bias);
++iterator_A1_bias;
}

// Load warp-level scale bias fragment from threadblock scale/bias vector
if(PerChannelScale) {
warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]);
++warp_tile_iterator_A1_scale_;
}
warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]);
++warp_tile_iterator_A1_bias_;

// Load warp-level tile from accumulator fragment
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2],
warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2],
output_op_0);
++warp_tile_iterator_A1_;

// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.

this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);

warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);

++warp_tile_iterator_A1_;
++this->warp_tile_iterator_B1_;

if (warp_mma_k > 0)
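
The `% 2` indexing in the loop above is the usual two-deep ping-pong: iteration `warp_mma_k` consumes buffer `warp_mma_k % 2` while buffer `(warp_mma_k + 1) % 2` is being filled, so the scale/bias loads overlap the warp-level MMAs. A self-contained sketch of the idiom (illustrative only, not the kernel's actual types):

```cpp
#include <array>
#include <cstdio>

// Double-buffering sketch: prefetch into one slot while computing on the other.
int main() {
  std::array<int, 2> frag{};
  int next_fragment = 0;
  frag[0] = next_fragment++;             // prologue load, like warp_loaded_frag_*[0]
  for (int k = 0; k < 4; ++k) {
    frag[(k + 1) % 2] = next_fragment++; // prefetch the fragment for iteration k+1
    std::printf("iteration %d consumes fragment %d\n", k, frag[k % 2]);
  }
  return 0;
}
```
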
@ -166,6 +166,9 @@ public:
/// Warp-level Mma
using Operator0 = typename Policy0::Operator;

/// Fragment of Scale and Bias loaded from global memory
using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;

/// Fragment of accumulator tile
using FragmentC1 = typename Policy1::Operator::FragmentC;

@ -418,11 +421,15 @@ public:
int gemm_k_iterations_0,
///< destination accumulator tile
FragmentC1 &accum,
///< iterator over A operand in global memory
///< iterator over A0 operand in global memory
IteratorA0 iterator_A0,
///< iterator over B operand in global memory
///< iterator over B0 operand in global memory
IteratorB0 iterator_B0,
///< iterator over B operand in global memory
///< iterator over A1 operand scale vector in global memory
IteratorAccumulatorScaleBias iterator_accum0_scale,
///< iterator over A1 operand bias vector in global memory
IteratorAccumulatorScaleBias iterator_accum0_bias,
///< iterator over B1 operand in global memory
IteratorB1 iterator_B1,
///< initial value of accumulator
FragmentC0 const &src_accum,
@ -658,7 +665,7 @@ public:
/// Epilogue for the first Implicit Gemm
Epilogue0 epilogue0;

epilogue0(output_op_0, smem_iterator_D0_, accum0);
epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias);

__syncthreads();

@ -76,6 +76,11 @@ template <
/// Iterates over the intermediate accumulator tile
// (concept: MmaTensorOpFragmentIterator)
typename FragmentIteratorA1_,
/// Iterates over the scale and bias vectors in global memory
// (concept: VectorIterator)
typename IteratorAccumulatorScaleBias_,
/// FragmentIterator to load the Scale or Bias vector from the threadblock fragment
typename FragmentIteratorA1ScaleBias_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorB1_,

@ -129,6 +134,9 @@ public:

using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile
using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory
using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_; ///< WarpIterator to load the Scale or Bias vector from the threadblock fragment
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
using Policy1 = Policy1_; ///< Policy describing tuning details

@ -140,6 +148,9 @@ public:

using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm

static const bool PerChannelScale = (OutputOp::kScale ==
epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);

using TransformA0 = TransformA0_;
using TransformB0 = TransformB0_;
using TransformB1 = TransformB1_;
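`PerChannelScale` is resolved at compile time from the epilogue output op's scale mode, so the loads guarded by `if(PerChannelScale)` throughout the mainloop are dead-code-eliminated when only a scalar alpha is in use. A self-contained sketch of the mechanism (the `ScaleType` enumerator name matches the diff; every other name here is an illustrative stub):

```cpp
namespace epilogue { namespace thread {
// Stub mirroring the shape of the real ScaleType (struct wrapping an enum).
struct ScaleType { enum Kind { Default, OnlyAlphaScaling, OnlyAlphaPerChannelScaling, Nothing }; };
}}

template <epilogue::thread::ScaleType::Kind kScale>
struct OutputOpStub { static const epilogue::thread::ScaleType::Kind kScaleKind = kScale; };

template <typename OutputOp>
struct MainloopStub {
  // Mirrors the flag above: true only for the per-channel scaling mode.
  static const bool PerChannelScale =
      (OutputOp::kScaleKind == epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling);
};

int main() {
  using PerChannelOp = OutputOpStub<epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling>;
  static_assert(MainloopStub<PerChannelOp>::PerChannelScale, "per-channel path selected");
  return 0;
}
```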
@ -160,6 +171,9 @@ public:
/// Warp-level Mma
using Operator0 = typename Policy0::Operator;

/// Fragment of Scale and Bias loaded from global memory
using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment;

/// Fragment of operand B loaded from global memory
using FragmentB1 = typename IteratorB1::Fragment;

@ -190,6 +204,9 @@ private:
using WarpFragmentB0 = typename Operator0::FragmentB;
/// Warp Fragment of operand A1 loaded from the accumulator tile
using WarpFragmentA1 = typename FragmentIteratorA1::Fragment;
/// Warp Fragment of operand A1 scale and bias loaded from the threadblock fragment
using WarpFragmentA1ScaleBias = typename FragmentIteratorA1ScaleBias::Fragment;
using WarpFragmentB1 = typename Operator1::FragmentB;

protected:

@ -248,6 +265,8 @@ public:
FragmentC1 &accum, ///< destination accumulator tile
IteratorA0 iterator_A, ///< iterator over A operand in global memory
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory
IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
FragmentC0 const &src_accum, ///< source accumulator tile
OutputOp output_op_0, ///< epilogue operation after 1st Gemm

@ -387,13 +406,26 @@ public:
// Prologue
//

FragmentA1ScaleBias tb_frag_A1_scale;
FragmentA1ScaleBias tb_frag_A1_bias;
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale);
FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias);
FragmentB1 tb_frag_B1;

if(PerChannelScale)
tb_frag_A1_scale.clear();
tb_frag_A1_bias.clear();
tb_frag_B1.clear();

// The last kblock is loaded in the prologue
if(PerChannelScale)
iterator_A1_scale.load(tb_frag_A1_scale);
iterator_A1_bias.load(tb_frag_A1_bias);
iterator_B1.load(tb_frag_B1);

if(PerChannelScale)
++iterator_A1_scale;
++iterator_A1_bias;
++iterator_B1;

this->smem_iterator_B1_.store(transform_B1(tb_frag_B1));
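One subtlety worth calling out: each brace-less `if(PerChannelScale)` above guards only the single statement that follows it; the bias fragment is cleared, loaded, and advanced unconditionally, since the bias vector is consumed in both scaling modes. Regrouped with explicit braces (same independent statements, no behavior change), the scale/bias portion of the prologue reads:

```cpp
if (PerChannelScale) {
  tb_frag_A1_scale.clear();
  iterator_A1_scale.load(tb_frag_A1_scale);
  ++iterator_A1_scale;
}
tb_frag_A1_bias.clear();            // bias path is taken regardless of the mode
iterator_A1_bias.load(tb_frag_A1_bias);
++iterator_A1_bias;
```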
@ -403,15 +435,24 @@ public:
__syncthreads();

// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA1ScaleBias warp_frag_A1_scale[2];
WarpFragmentA1ScaleBias warp_frag_A1_bias[2];
WarpFragmentA1 warp_frag_A1[2];
WarpFragmentB1 warp_frag_B1[2];

this->warp_tile_iterator_B1_.set_kgroup_index(0);

warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0);
if(PerChannelScale)
warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]);
warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]);
warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0],
warp_frag_A1_bias[0], output_op_0);
this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);

++warp_tile_iterator_A1_;
if(PerChannelScale)
++warp_tile_iterator_A1_scale_;
++warp_tile_iterator_A1_bias_;
++this->warp_tile_iterator_B1_;

Operator1 warp_mma1;

@ -461,13 +502,31 @@ public:
}

smem_write_stage_idx ^= 1;

if(PerChannelScale) {
tb_frag_A1_scale.clear();
iterator_A1_scale.load(tb_frag_A1_scale);
++iterator_A1_scale;
}
tb_frag_A1_bias.clear();
iterator_A1_bias.load(tb_frag_A1_bias);
++iterator_A1_bias;
}

this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);

warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
if(PerChannelScale)
warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]);
warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]);
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2],
warp_frag_A1_scale[(warp_mma_k + 1) % 2],
warp_frag_A1_bias[(warp_mma_k + 1) % 2],
output_op_0);
this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);

if(PerChannelScale)
++warp_tile_iterator_A1_scale_;
++warp_tile_iterator_A1_bias_;
++warp_tile_iterator_A1_;
++this->warp_tile_iterator_B1_;

@ -286,6 +286,8 @@ public:
FragmentC1 &accum, ///< destination accumulator tile
IteratorA0 iterator_A, ///< iterator over A operand in global memory
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory
IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
FragmentC0 const &src_accum, ///< source accumulator tile
OutputOp output_op_0, ///< epilogue operation after 1st Gemm

@ -419,7 +421,7 @@ public:
/// Epilogue for the first Implicit Gemm
Epilogue0 epilogue0;

epilogue0(output_op_0, smem_iterator_D0_, accum0);
epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias);

__syncthreads();

@ -40,6 +40,10 @@

#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h"
#include "cutlass/transform/threadblock/vector_iterator.h"
#include "cutlass/transform/warp/vector_fragment_iterator.h"

#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"

@ -170,6 +174,22 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
MmaCore1::Shape::kK, //kBlocksColumn
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>;

using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
static int const kElementsPerAccess = 2;
using IteratorAccumulatorScaleBias =
cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
>;

// Warp-level iterators to load scale and bias vectors
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
LayoutScaleBias, InstructionShape, kElementsPerAccess>;

// Define iterators over tiles from the B operand
using IteratorB1 =
cutlass::transform::threadblock::PredicatedTileIterator<
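In every specialization the scale and bias operands are plain `1 x K` row-major vectors, one element per output channel of the first GEMM or convolution, which is why the comment notes that `LayoutScaleBias` barely matters. A host-side sketch of allocating and filling such vectors with standard CUTLASS utilities (the element type, extent, and fill values are illustrative):

```cpp
#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"

int main() {
  int const K = 64;  // illustrative: output channels of the first GEMM

  // One scale and one bias value per channel, stored as 1 x K vectors.
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_scale({1, K});
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor_bias({1, K});

  cutlass::reference::host::TensorFill(tensor_scale.host_view(), 1.0f);
  cutlass::reference::host::TensorFill(tensor_bias.host_view(), 0.0f);

  // Copy to device so the fused kernel's vector iterators can read them.
  tensor_scale.sync_device();
  tensor_bias.sync_device();
  return 0;
}
```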
@ -181,6 +201,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
IteratorB0, typename MmaCore0::SmemIteratorB,
typename MmaCore1::Shape, FragmentIteratorA1,
IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias,
IteratorB1, typename MmaCore1::SmemIteratorB,
ElementAccumulator, layout::RowMajor,
EpilogueOutputOp,

@ -276,6 +297,24 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
MmaCore1::Shape::kK, //kBlocksColumn
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>;

/// Define iterators over tiles from scale/bias vectors
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
static int const kElementsPerAccess = 2;
using IteratorAccumulatorScaleBias =
cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
>;

// Warp-level iterators to load scale and bias vectors
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
LayoutScaleBias, InstructionShape, kElementsPerAccess>;

// Define iterators over tiles from the B operand
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
using AccessTypeB1 = cutlass::Array<ElementB, kAlignmentB>;

@ -290,6 +329,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
MmaCore0::kCacheOpA,
IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
typename MmaCore1::Shape, FragmentIteratorA1,
IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias,
IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
ElementAccumulator, layout::RowMajor,
EpilogueOutputOp,

@ -377,6 +417,22 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
ElementAccumulator, ElementA, AccumulatorLayout,
InstructionShape, EpilogueOutputOp>;

using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
static int const kElementsPerAccess = 4;
using IteratorAccumulatorScaleBias =
cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
>;

// Warp-level iterators to load scale and bias vectors
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
LayoutScaleBias, InstructionShape, kElementsPerAccess>;

// Define iterators over tiles from the B operand
using IteratorB1 =
cutlass::transform::threadblock::PredicatedTileIterator<

@ -384,12 +440,12 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
IteratorB0, typename MmaCore0::SmemIteratorB,
typename MmaCore1::Shape, FragmentIteratorA1,
IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias,
IteratorB1, typename MmaCore1::SmemIteratorB,
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
EpilogueOutputOp,

@ -479,6 +535,23 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
ElementAccumulator, ElementA, AccumulatorLayout,
InstructionShape, EpilogueOutputOp>;

/// Define iterators over tiles from scale/bias vectors
using ElementScaleBias = typename EpilogueOutputOp::ElementCompute;
using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter
static int const kElementsPerAccess = 4;
using IteratorAccumulatorScaleBias =
cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
cutlass::MatrixShape<WarpShape1::kM, WarpShape1::kK>,
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
>;

// Warp-level iterators to load scale and bias vectors
using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator<
MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias,
LayoutScaleBias, InstructionShape, kElementsPerAccess>;

// Define iterators over tiles from the B operand
using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB;
using IteratorB1 =

@ -494,6 +567,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
MmaCore0::kCacheOpA,
IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB,
typename MmaCore1::Shape, FragmentIteratorA1,
IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias,
IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB,
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
EpilogueOutputOp,

@ -559,7 +559,7 @@ struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
cutlass::MatrixShape<ThreadblockShape0::kM, ThreadblockShape0::kN>,
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kK>,
cutlass::MatrixShape<WarpShape0::kM, WarpShape0::kN>,
ElementScaleBias, LayoutScaleBias, kElementsPerAccess>
>;

@ -162,6 +162,8 @@ public:

if (Scale == ScaleType::OnlyAlphaScaling) return false;

if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false;

if (Scale == ScaleType::Nothing) return false;

return beta_ != ElementCompute(0);

@ -389,6 +391,8 @@ public:

if (Scale == ScaleType::OnlyAlphaScaling) return false;

if (Scale == ScaleType::OnlyAlphaPerChannelScaling) return false;

if (Scale == ScaleType::Nothing) return false;

return beta_ != ElementCompute(0);
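Both hunks above extend what appears to be the epilogue's source-needed predicate: `OnlyAlphaPerChannelScaling`, like `OnlyAlphaScaling` and `Nothing`, never reads the source tensor C, so the epilogue can skip loading it. Restated as a standalone function (a hedged sketch; in the library this is a member function parameterized on the scale mode):

```cpp
enum class ScaleType { Default, OnlyAlphaScaling, OnlyAlphaPerChannelScaling, Nothing };

// True only when the linear combination actually reads the source tensor C.
bool is_source_needed(ScaleType scale, float beta) {
  if (scale == ScaleType::OnlyAlphaScaling) return false;
  if (scale == ScaleType::OnlyAlphaPerChannelScaling) return false;
  if (scale == ScaleType::Nothing) return false;
  return beta != 0.0f;
}
```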
@ -82,9 +82,7 @@ __global__ void Conv2dFprop(
TensorRef<ElementC, LayoutC> tensor_y_in,
TensorRef<ElementC, LayoutC> tensor_y_out,
ElementCompute alpha,
ElementCompute beta,
TensorRef<ElementCompute, layout::RowMajor> tensor_scale,
TensorRef<ElementCompute, layout::RowMajor> tensor_bias
ElementCompute beta
) {

ConvertOp convert_op;

@ -186,26 +184,13 @@ __global__ void Conv2dFprop(
int thread_k = k_start + n;
if (thread_k < problem_size.K) {

if(alpha == ElementCompute()) { // use per-channel scale and bias
ElementCompute scale = tensor_scale.at({0, thread_k});
ElementCompute bias = tensor_bias.at({0, thread_k});
tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
scale * ElementCompute(accum[m][n]) + bias);
ElementCompute c_ref = ElementCompute();
if (beta != ElementCompute()) {
c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}));
}
else if(tensor_bias.good()) { // use per-channel bias
ElementCompute bias = tensor_bias.at({0, thread_k});
tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
alpha * ElementCompute(accum[m][n]) + bias);
}
else {
ElementCompute c_ref = ElementCompute();
if (beta != ElementCompute()) {
c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}));
}

tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
}
tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
alpha * ElementCompute(accum[m][n]) + beta * c_ref);
}
}
}
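The hunk above strips the per-channel scale/bias branches out of the shared reference `Conv2dFprop` kernel, presumably superseded by a dedicated scale/bias reference added elsewhere in this commit. For the record, the element-wise math the removed branches computed, as a scalar sketch (names illustrative):

```cpp
// Scalar restatement of the three removed branches: alpha == 0 selects the
// per-channel scale-and-bias path, a valid bias vector selects bias-only,
// and otherwise the usual alpha/beta linear combination applies.
float reference_output(float accum, float c_ref, float alpha, float beta,
                       float scale_k, float bias_k, bool has_bias) {
  if (alpha == 0.0f) {
    return scale_k * accum + bias_k;      // per-channel scale and bias
  }
  if (has_bias) {
    return alpha * accum + bias_k;        // per-channel bias only
  }
  return alpha * accum + beta * c_ref;    // standard linear combination
}
```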
@ -1015,9 +1000,7 @@ Status Conv2dFprop(
TensorRef<ElementC, LayoutC> tensor_y_out,
ElementCompute alpha,
ElementCompute beta,
cudaStream_t stream = nullptr,
TensorRef<ElementCompute, layout::RowMajor> tensor_scale = TensorRef<ElementCompute, layout::RowMajor>(),
TensorRef<ElementCompute, layout::RowMajor> tensor_bias = TensorRef<ElementCompute, layout::RowMajor>() ) {
cudaStream_t stream = nullptr) {

//
// Blocking factors improve performance of reference implementation

@ -1056,9 +1039,7 @@ Status Conv2dFprop(
tensor_y_in,
tensor_y_out,
alpha,
beta,
tensor_scale,
tensor_bias
beta
);

cudaError_t result = cudaPeekAtLastError();

@ -1448,9 +1429,7 @@ Status Conv2d(
TensorRef<ElementC, LayoutC> tensor_D,
ElementCompute alpha,
ElementCompute beta,
cudaStream_t stream = nullptr,
TensorRef<ElementCompute, layout::RowMajor> tensor_scale = TensorRef<ElementCompute, layout::RowMajor>(),
TensorRef<ElementCompute, layout::RowMajor> tensor_bias = TensorRef<ElementCompute, layout::RowMajor>() ) {
cudaStream_t stream = nullptr) {

switch (convolutional_operator) {
case conv::Operator::kFprop:

@ -1461,7 +1440,7 @@ Status Conv2d(
ElementCompute,
ElementAccumulator,
ConvertOp, InnerProductOp
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream, tensor_scale, tensor_bias);
>(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
break;

case conv::Operator::kDgrad: