added support of b2b bmm (#849)
* added support of b2b bmm * fixed arguments and params structures * added batch_count argument * removed SplitKSerial and added new test case with b2b bmm * fixed support of Kbatched and added new test case with batch stride * added batch support for bias and scale * make test * small changes --------- Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
d572cc1aab
commit
4a68cf748e
@ -1,11 +1,11 @@
|
||||
# Introduction
|
||||
|
||||
This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.
|
||||
This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.
|
||||
|
||||
<p align="center"><img src=/media/images/13_example_fusion.png></p>
|
||||
|
||||
When running two unfused GEMM/Conv operations, each operation loads one input
|
||||
activation matrix, one weight matrix (or filter matrix) from the memory and then
|
||||
When running two unfused GEMM/Conv operations, each operation loads one input
|
||||
activation matrix, one weight matrix (or filter matrix) from the memory and then
|
||||
stores the result activation matrix back to the memory.
|
||||
|
||||
When the two GEMM/Conv operations are fused together, the mainloops of the two
|
||||
@ -27,10 +27,10 @@ In order to run two GEMM/Convs in a single kernel, the example requires the same
|
||||
threadblocks are used across 2 GEMMs/Convs. This also ensures the same threadblock tile M across
|
||||
2 GEMMs/Convs.
|
||||
|
||||
In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
|
||||
In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the
|
||||
input activation, the example enforces the following two constraints:
|
||||
|
||||
- thread_block_tile_N = problem_N
|
||||
- thread_block_tile_N = problem_N
|
||||
|
||||
<p align="center"><img src=/media/images/13_example_block_resident_fusion.png></p>
|
||||
|
||||
@ -39,7 +39,7 @@ addition to its own input activation tile. Therefore the input activation tile o
|
||||
2nd GEMM/Conv only depends on the output activation tile of the 1st GEMM/Conv, and the
|
||||
operation can be fully block-resident.
|
||||
|
||||
- warp_tile_N = thread_block_tile_N
|
||||
- warp_tile_N = thread_block_tile_N
|
||||
|
||||
<p align="center"><img src=/media/images/13_example_rf_resident_fusion.png></p>
|
||||
|
||||
@ -82,7 +82,7 @@ threadblock. Typically this requires the 2nd Convolution uses 1x1 filter without
|
||||
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_shmem`
|
||||
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_rf`
|
||||
- `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_shmem`
|
||||
|
||||
|
||||
|
||||
# Copyright
|
||||
|
||||
|
@ -42,6 +42,7 @@
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/gemm_complex.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "reference/device/tensor_scale_bias.h"
|
||||
@ -77,9 +78,9 @@ struct B2bNonFusedGemmRun
|
||||
//
|
||||
|
||||
B2bNonFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
@ -88,7 +89,7 @@ struct B2bNonFusedGemmRun
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
@ -96,7 +97,7 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
@ -129,62 +130,62 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementCompute,
|
||||
ElementCompute,
|
||||
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementCompute,
|
||||
ElementCompute,
|
||||
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
|
||||
|
||||
|
||||
@ -270,13 +271,13 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_1();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
@ -312,32 +313,32 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
|
||||
reference_D0.device_ref()
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
);
|
||||
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
@ -349,7 +350,7 @@ struct B2bNonFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
@ -362,7 +363,7 @@ struct B2bNonFusedGemmRun
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nC0 =\n" << tensor_C0.host_view()
|
||||
@ -399,9 +400,9 @@ struct B2bFusedGemmRun
|
||||
//
|
||||
|
||||
B2bFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
@ -412,7 +413,7 @@ struct B2bFusedGemmRun
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
@ -420,11 +421,11 @@ struct B2bFusedGemmRun
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
@ -453,70 +454,90 @@ struct B2bFusedGemmRun
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
|
||||
// batch_count is used as split-k when mode is kGemm according
|
||||
// to the GemmUniversal interface
|
||||
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_A0 = 0,
|
||||
int64_t batch_stride_B0 = 0,
|
||||
int64_t batch_stride_C0 = 0,
|
||||
int64_t batch_stride_B1 = 0,
|
||||
int64_t batch_stride_C1 = 0,
|
||||
int64_t batch_stride_D1 = 0,
|
||||
int64_t batch_stride_Bias0 = 0,
|
||||
int64_t batch_stride_Scale0 = 0,
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
|
||||
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
|
||||
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.resize({1, problem_size_0.n()});
|
||||
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementAccumulator,
|
||||
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
|
||||
ElementAccumulator,
|
||||
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementCompute,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
|
||||
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
@ -554,6 +575,7 @@ struct B2bFusedGemmRun
|
||||
//
|
||||
|
||||
typename B2bGemm::Arguments arguments{
|
||||
mode,
|
||||
problem_size_0,
|
||||
problem_size_1,
|
||||
tensor_A0.device_ref(),
|
||||
@ -564,8 +586,16 @@ struct B2bFusedGemmRun
|
||||
tensor_B1.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
tensor_D1.device_ref(),
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1,
|
||||
batch_stride_Bias0,
|
||||
batch_stride_Scale0,
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
batch_count,
|
||||
};
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
@ -618,32 +648,31 @@ struct B2bFusedGemmRun
|
||||
// Verify
|
||||
//
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator>
|
||||
reference_gemm_0;
|
||||
cutlass::reference::device::GemmComplex<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator
|
||||
>(
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename B2bGemm::Operator>
|
||||
reference_gemm_1;
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
ElementAccumulator(1), //intermediate alpha=1
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_A0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
tensor_B0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ElementAccumulator(0), //beta = 0
|
||||
reference_Z0.device_ref(),
|
||||
reference_Z0.device_ref(),
|
||||
ElementAccumulator(0)
|
||||
ElementAccumulator(0),
|
||||
int(batch_count),
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_C0,
|
||||
batch_stride_C0
|
||||
);
|
||||
|
||||
cutlass::reference::device::TensorScaleBiasGemm<
|
||||
cutlass::reference::device::TensorScaleBiasGemmBatched<
|
||||
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, typename B2bGemm::LayoutScaleBias
|
||||
> (
|
||||
@ -652,25 +681,45 @@ struct B2bFusedGemmRun
|
||||
reference_D0.device_ref(),
|
||||
alpha0,
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref()
|
||||
tensor_Bias0.device_ref(),
|
||||
int(batch_count),
|
||||
batch_stride_C0,
|
||||
batch_stride_C0,
|
||||
batch_stride_Scale0,
|
||||
batch_stride_Bias0
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
cutlass::reference::device::GemmComplex<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, ElementAccumulator
|
||||
>(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
alpha1, //intermediate alpha=1
|
||||
reference_D0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
tensor_B1.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
beta1, //beta = 0
|
||||
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
reference_D1.device_ref(),
|
||||
ElementAccumulator(0),
|
||||
int(batch_count),
|
||||
batch_stride_C0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
@ -680,7 +729,7 @@ struct B2bFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
@ -694,7 +743,7 @@ struct B2bFusedGemmRun
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nC0 =\n" << tensor_C0.host_view()
|
||||
|
@ -43,6 +43,7 @@
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/gemm_complex.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "reference/device/tensor_scale_bias.h"
|
||||
@ -76,9 +77,9 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
//
|
||||
|
||||
B2bInterleavedNonFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
@ -87,7 +88,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
@ -95,7 +96,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
@ -128,73 +129,72 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
|
||||
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
|
||||
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
||||
@ -285,13 +285,13 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
for(int i = 0; i < runs; i++) {
|
||||
status = gemm_op_1();
|
||||
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
@ -327,36 +327,36 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
{tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
|
||||
reference_D0.device_ref()
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
{tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
);
|
||||
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
||||
@ -364,7 +364,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
@ -377,7 +377,7 @@ struct B2bInterleavedNonFusedGemmRun
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
|
||||
@ -416,9 +416,9 @@ struct B2bInterleavedFusedGemmRun
|
||||
//
|
||||
|
||||
B2bInterleavedFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
@ -429,7 +429,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
@ -437,11 +437,11 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
@ -470,78 +470,99 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
|
||||
// batch_count is used as split-k when mode is kGemm according
|
||||
// to the GemmUniversal interface
|
||||
|
||||
int batch_count = 1,
|
||||
|
||||
int64_t batch_stride_A0 = 0,
|
||||
int64_t batch_stride_B0 = 0,
|
||||
int64_t batch_stride_C0 = 0,
|
||||
int64_t batch_stride_B1 = 0,
|
||||
int64_t batch_stride_C1 = 0,
|
||||
int64_t batch_stride_D1 = 0,
|
||||
int64_t batch_stride_Bias0 = 0,
|
||||
int64_t batch_stride_Scale0 = 0,
|
||||
bool relu = true,
|
||||
int warm_ups = 1,
|
||||
int runs = 100) {
|
||||
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
|
||||
cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
|
||||
cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Scale0;
|
||||
|
||||
if(alpha0 == ElementCompute(0)) //per-channel scale
|
||||
tensor_Scale0.resize({1, problem_size_0.n()});
|
||||
tensor_Scale0.resize({1, batch_count * problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
|
||||
typename B2bGemm::ElementScaleBias,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
ElementAccumulator,
|
||||
typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
|
||||
ElementAccumulator,
|
||||
typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn());
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());
|
||||
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
@ -556,9 +577,9 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
//Reorder B0
|
||||
cutlass::reorder_column<16>(
|
||||
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
|
||||
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0);
|
||||
cutlass::reorder_column<InterleavedK_>(
|
||||
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
|
||||
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1);
|
||||
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_D1.host_view());
|
||||
@ -581,12 +602,14 @@ struct B2bInterleavedFusedGemmRun
|
||||
tensor_D1.sync_device();
|
||||
reference_D0.sync_device();
|
||||
reference_D1.sync_device();
|
||||
// tensor_Bias0_batched.sync_device();
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
//
|
||||
|
||||
typename B2bGemm::Arguments arguments{
|
||||
mode,
|
||||
problem_size_0,
|
||||
problem_size_1,
|
||||
tensor_A0.device_ref(),
|
||||
@ -597,8 +620,16 @@ struct B2bInterleavedFusedGemmRun
|
||||
tensor_B1_reordered.device_ref(),
|
||||
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
tensor_D1.device_ref(),
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1,
|
||||
batch_stride_Bias0,
|
||||
batch_stride_Scale0,
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
batch_count,
|
||||
};
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
@ -651,32 +682,30 @@ struct B2bInterleavedFusedGemmRun
|
||||
// Verify
|
||||
//
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator>
|
||||
reference_gemm_0;
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename B2bGemm::Operator>
|
||||
reference_gemm_1;
|
||||
|
||||
reference_gemm_0(
|
||||
cutlass::reference::device::GemmComplex<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
ElementAccumulator, typename B2bGemm::LayoutC,
|
||||
ElementAccumulator, ElementAccumulator
|
||||
>(
|
||||
problem_size_0,
|
||||
ElementAccumulator(1), //intermediate alpha=1
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_A0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
tensor_B0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
ElementAccumulator(0), //beta = 0
|
||||
reference_Z0.device_ref(),
|
||||
reference_Z0.device_ref(),
|
||||
ElementAccumulator(0)
|
||||
ElementAccumulator(0),
|
||||
int(batch_count),
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_C0,
|
||||
batch_stride_C0
|
||||
);
|
||||
|
||||
cutlass::reference::device::TensorScaleBiasGemm<
|
||||
cutlass::reference::device::TensorScaleBiasGemmBatched<
|
||||
ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, typename B2bGemm::LayoutScaleBias
|
||||
> (
|
||||
@ -685,25 +714,45 @@ struct B2bInterleavedFusedGemmRun
|
||||
reference_D0.device_ref(),
|
||||
alpha0,
|
||||
tensor_Scale0.device_ref(),
|
||||
tensor_Bias0.device_ref()
|
||||
tensor_Bias0.device_ref(),
|
||||
int(batch_count),
|
||||
batch_stride_C0,
|
||||
batch_stride_C0,
|
||||
batch_stride_Scale0,
|
||||
batch_stride_Bias0
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
cutlass::reference::device::GemmComplex<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
|
||||
ElementCompute, ElementAccumulator
|
||||
>(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
alpha1, //intermediate alpha=1
|
||||
reference_D0.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
tensor_B1.device_ref(),
|
||||
cutlass::ComplexTransform::kNone,
|
||||
beta1, //beta = 0
|
||||
{tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
|
||||
reference_D1.device_ref()
|
||||
reference_D1.device_ref(),
|
||||
ElementAccumulator(0),
|
||||
int(batch_count),
|
||||
batch_stride_C0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
@ -713,7 +762,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
@ -727,7 +776,7 @@ struct B2bInterleavedFusedGemmRun
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
|
||||
|
@ -119,8 +119,6 @@ template <
|
||||
int AlignmentB =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kAlignmentB,
|
||||
/// If true, kernel supports split-K with serial reduction
|
||||
bool SplitKSerial = false,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
@ -154,7 +152,6 @@ class B2bGemm {
|
||||
static int const kAlignmentA = AlignmentA;
|
||||
static int const kAlignmentB = AlignmentB;
|
||||
static int const kAlignmentC = EpilogueOutputOp1::kCount;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
@ -184,7 +181,6 @@ class B2bGemm {
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
kStages,
|
||||
kSplitKSerial,
|
||||
Operator,
|
||||
SmemAccumulator
|
||||
>::B2bGemmKernel;
|
||||
@ -196,6 +192,7 @@ class B2bGemm {
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmUniversalMode mode;
|
||||
GemmCoord problem_size_0;
|
||||
GemmCoord problem_size_1;
|
||||
TensorRef<ElementA const, LayoutA> ref_A0;
|
||||
@ -206,9 +203,16 @@ class B2bGemm {
|
||||
TensorRef<ElementB const, LayoutB> ref_B1;
|
||||
TensorRef<ElementC const, LayoutC> ref_C1;
|
||||
TensorRef<ElementC, LayoutC> ref_D1;
|
||||
int64_t batch_stride_A0;
|
||||
int64_t batch_stride_B0;
|
||||
int64_t batch_stride_B1;
|
||||
int64_t batch_stride_C1;
|
||||
int64_t batch_stride_D1;
|
||||
int64_t batch_stride_Bias0;
|
||||
int64_t batch_stride_Scale0;
|
||||
typename EpilogueOutputOp0::Params epilogue0;
|
||||
typename EpilogueOutputOp1::Params epilogue1;
|
||||
int split_k_slices;
|
||||
int batch_count;
|
||||
|
||||
//
|
||||
// Methods
|
||||
@ -216,13 +220,14 @@ class B2bGemm {
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) {
|
||||
Arguments(): mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {
|
||||
|
||||
}
|
||||
|
||||
/// Constructs an Arguments structure
|
||||
/// Constructs an Arguments structure
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
GemmUniversalMode mode_,
|
||||
GemmCoord problem_size_0_,
|
||||
GemmCoord problem_size_1_,
|
||||
TensorRef<ElementA const, LayoutA> ref_A0_,
|
||||
@ -233,12 +238,20 @@ class B2bGemm {
|
||||
TensorRef<ElementB const, LayoutB> ref_B1_,
|
||||
TensorRef<ElementC const, LayoutC> ref_C1_,
|
||||
TensorRef<ElementC, LayoutC> ref_D1_,
|
||||
typename EpilogueOutputOp0::Params epilogue0_ =
|
||||
int64_t batch_stride_A0_,
|
||||
int64_t batch_stride_B0_,
|
||||
int64_t batch_stride_B1_,
|
||||
int64_t batch_stride_C1_,
|
||||
int64_t batch_stride_D1_,
|
||||
int64_t batch_stride_Bias0_,
|
||||
int64_t batch_stride_Scale0_,
|
||||
typename EpilogueOutputOp0::Params epilogue0_ =
|
||||
typename EpilogueOutputOp0::Params(),
|
||||
typename EpilogueOutputOp1::Params epilogue1_ =
|
||||
typename EpilogueOutputOp1::Params epilogue1_ =
|
||||
typename EpilogueOutputOp1::Params(),
|
||||
int split_k_slices_ = 1
|
||||
int batch_count_ = 1
|
||||
):
|
||||
mode(mode_),
|
||||
problem_size_0(problem_size_0_),
|
||||
problem_size_1(problem_size_1_),
|
||||
ref_A0(ref_A0_),
|
||||
@ -249,9 +262,16 @@ class B2bGemm {
|
||||
ref_B1(ref_B1_),
|
||||
ref_C1(ref_C1_),
|
||||
ref_D1(ref_D1_),
|
||||
batch_stride_A0(batch_stride_A0_),
|
||||
batch_stride_B0(batch_stride_B0_),
|
||||
batch_stride_B1(batch_stride_B1_),
|
||||
batch_stride_C1(batch_stride_C1_),
|
||||
batch_stride_D1(batch_stride_D1_),
|
||||
batch_stride_Bias0(batch_stride_Bias0_),
|
||||
batch_stride_Scale0(batch_stride_Scale0_),
|
||||
epilogue0(epilogue0_),
|
||||
epilogue1(epilogue1_),
|
||||
split_k_slices(split_k_slices_) {
|
||||
batch_count(batch_count_) {
|
||||
|
||||
}
|
||||
};
|
||||
@ -269,10 +289,6 @@ public:
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const &args) {
|
||||
|
||||
if (!kSplitKSerial && args.split_k_slices > 1) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
Status status = B2bGemmKernel::can_implement(
|
||||
args.problem_size_0,
|
||||
args.problem_size_1,
|
||||
@ -295,20 +311,14 @@ public:
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
|
||||
size_t bytes = 0;
|
||||
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size_0,
|
||||
args.problem_size_0,
|
||||
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
||||
args.split_k_slices);
|
||||
|
||||
if (kSplitKSerial && args.split_k_slices > 1) {
|
||||
|
||||
|
||||
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
|
||||
}
|
||||
args.batch_count);
|
||||
|
||||
return bytes;
|
||||
}
|
||||
@ -320,38 +330,17 @@ public:
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size_0,
|
||||
args.problem_size_0,
|
||||
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
||||
args.split_k_slices);
|
||||
args.batch_count);
|
||||
// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
|
||||
// args.problem_size_1,
|
||||
// args.problem_size_1,
|
||||
// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
|
||||
// args.split_k_slices);
|
||||
|
||||
if (kSplitKSerial) {
|
||||
if (args.split_k_slices > 1) {
|
||||
if (!workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
size_t bytes = get_workspace_size(args);
|
||||
|
||||
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
if (args.split_k_slices > 1) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
}
|
||||
// args.batch_count);
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = typename B2bGemmKernel::Params{
|
||||
args.mode,
|
||||
args.problem_size_0,
|
||||
args.problem_size_1,
|
||||
grid_shape,
|
||||
@ -363,6 +352,13 @@ public:
|
||||
args.ref_B1.non_const_ref(),
|
||||
args.ref_C1.non_const_ref(),
|
||||
args.ref_D1,
|
||||
args.batch_stride_A0,
|
||||
args.batch_stride_B0,
|
||||
args.batch_stride_B1,
|
||||
args.batch_stride_C1,
|
||||
args.batch_stride_D1,
|
||||
args.batch_stride_Bias0,
|
||||
args.batch_stride_Scale0,
|
||||
args.epilogue0,
|
||||
args.epilogue1,
|
||||
static_cast<int *>(workspace),
|
||||
@ -373,12 +369,6 @@ public:
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||
|
||||
if (kSplitKSerial && args.split_k_slices > 1) {
|
||||
if (!workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
}
|
||||
|
||||
params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
|
||||
params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
|
||||
@ -430,12 +420,12 @@ public:
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
|
||||
Status status = initialize(args, workspace, stream);
|
||||
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
@ -152,7 +152,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
8 * InstructionShape::kN / 32,
|
||||
@ -161,7 +161,7 @@ bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
@ -194,14 +194,21 @@ bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
SmemAccumulator,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n";
|
||||
bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
bool passed = fusedGemm.run(
|
||||
gemm_s8_sm80_problem_size_0,
|
||||
gemm_s8_sm80_problem_size_1,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1
|
||||
);
|
||||
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
@ -210,18 +217,123 @@ bool run_fused_gemm_s8_sm80_rf_res() {
|
||||
return passed;
|
||||
}
|
||||
|
||||
bool run_fused_gemm_s8_sm80_rf_res_batch() {
|
||||
|
||||
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(256, 64, 128);
|
||||
cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(256, 128, 64);
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
ElementCompute alpha0 = ElementCompute(1);
|
||||
//Fused kernel has built-in bias, setting beta=0
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(1);
|
||||
ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
8 * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
const bool SmemAccumulator = false;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
3,
|
||||
SmemAccumulator,
|
||||
16,
|
||||
16,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
int batch_count = 2;
|
||||
int64_t batch_stride_A0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.k();
|
||||
int64_t batch_stride_B0 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
|
||||
int64_t batch_stride_C0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.n();
|
||||
int64_t batch_stride_B1 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
|
||||
int64_t batch_stride_C1 = gemm_s8_sm80_problem_size_1.n();
|
||||
int64_t batch_stride_D1 = gemm_s8_sm80_problem_size_1.m() * gemm_s8_sm80_problem_size_1.n();
|
||||
int64_t batch_stride_Bias0 = gemm_s8_sm80_problem_size_0.n();
|
||||
int64_t batch_stride_Scale0 = 0;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved Batched GEMMs with RF residency...\n";
|
||||
bool passed = fusedGemm.run(
|
||||
gemm_s8_sm80_problem_size_0,
|
||||
gemm_s8_sm80_problem_size_1,
|
||||
alpha0,
|
||||
beta0,
|
||||
alpha1,
|
||||
beta1,
|
||||
cutlass::gemm::GemmUniversalMode::kBatched,
|
||||
batch_count,
|
||||
batch_stride_A0,
|
||||
batch_stride_B0,
|
||||
batch_stride_C0,
|
||||
batch_stride_B1,
|
||||
batch_stride_C1,
|
||||
batch_stride_D1,
|
||||
batch_stride_Bias0,
|
||||
batch_stride_Scale0
|
||||
);
|
||||
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
std::vector<bool (*)()>funcs = {
|
||||
&run_nonfused_gemm_s8_sm80,
|
||||
&run_fused_gemm_s8_sm80_rf_res
|
||||
&run_fused_gemm_s8_sm80_rf_res,
|
||||
&run_fused_gemm_s8_sm80_rf_res_batch
|
||||
};
|
||||
|
||||
return testRun(80, funcs, "gemm int8 RF residency");
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -151,7 +151,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
8 * InstructionShape::kN / 32,
|
||||
@ -160,7 +160,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
@ -168,7 +168,7 @@ bool run_fused_gemm_s8_sm80_shmem() {
|
||||
ElementCompute,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling
|
||||
>;
|
||||
|
||||
|
||||
const bool SmemAccumulator = true;
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
@ -193,7 +193,6 @@ bool run_fused_gemm_s8_sm80_shmem() {
|
||||
SmemAccumulator,
|
||||
16,
|
||||
16,
|
||||
false,
|
||||
cutlass::arch::OpMultiplyAddSaturate
|
||||
>;
|
||||
|
||||
|
@ -49,10 +49,9 @@ namespace kernel {
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
|
||||
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||
>
|
||||
struct B2bGemm {
|
||||
|
||||
@ -61,7 +60,17 @@ struct B2bGemm {
|
||||
using OutputOp0 = typename B2bMma::OutputOp;
|
||||
using OutputOp1 = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
|
||||
using ElementA0 = typename B2bMma::IteratorA0::Element;
|
||||
using LayoutA0 = typename B2bMma::IteratorA0::Layout;
|
||||
using ElementB0 = typename B2bMma::IteratorB0::Element;
|
||||
using LayoutB0 = typename B2bMma::IteratorB0::Layout;
|
||||
using ElementB1 = typename B2bMma::IteratorB1::Element;
|
||||
using LayoutB1 = typename B2bMma::IteratorB1::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
|
||||
using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount0 = typename B2bMma::WarpCount0;
|
||||
@ -69,6 +78,7 @@ struct B2bGemm {
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmUniversalMode mode;
|
||||
cutlass::gemm::GemmCoord problem_size_0;
|
||||
cutlass::gemm::GemmCoord problem_size_1;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
@ -89,6 +99,13 @@ struct B2bGemm {
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
|
||||
typename OutputOp0::Params output_op_0;
|
||||
typename OutputOp1::Params output_op_1;
|
||||
int64_t batch_stride_A0;
|
||||
int64_t batch_stride_B0;
|
||||
int64_t batch_stride_B1;
|
||||
int64_t batch_stride_C1;
|
||||
int64_t batch_stride_D1;
|
||||
int64_t batch_stride_Bias0;
|
||||
int64_t batch_stride_Scale0;
|
||||
int *semaphore;
|
||||
int gemm_k_iterations_0;
|
||||
int gemm_k_size_0;
|
||||
@ -100,11 +117,12 @@ struct B2bGemm {
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
|
||||
Params(): mode(mode), swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
|
||||
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
cutlass::gemm::GemmUniversalMode mode,
|
||||
cutlass::gemm::GemmCoord const & problem_size_0,
|
||||
cutlass::gemm::GemmCoord const & problem_size_1,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
@ -116,10 +134,18 @@ struct B2bGemm {
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
|
||||
int64_t batch_stride_A0,
|
||||
int64_t batch_stride_B0,
|
||||
int64_t batch_stride_B1,
|
||||
int64_t batch_stride_C1,
|
||||
int64_t batch_stride_D1,
|
||||
int64_t batch_stride_Bias0,
|
||||
int64_t batch_stride_Scale0,
|
||||
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
|
||||
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
|
||||
int *workspace = nullptr
|
||||
):
|
||||
mode(mode),
|
||||
problem_size_0(problem_size_0),
|
||||
problem_size_1(problem_size_1),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
@ -138,6 +164,13 @@ struct B2bGemm {
|
||||
ref_C1(ref_C1),
|
||||
params_D1(ref_D1.layout()),
|
||||
ref_D1(ref_D1),
|
||||
batch_stride_A0(batch_stride_A0),
|
||||
batch_stride_B0(batch_stride_B0),
|
||||
batch_stride_B1(batch_stride_B1),
|
||||
batch_stride_C1(batch_stride_C1),
|
||||
batch_stride_D1(batch_stride_D1),
|
||||
batch_stride_Bias0(batch_stride_Bias0),
|
||||
batch_stride_Scale0(batch_stride_Scale0),
|
||||
output_op_0(output_op_0),
|
||||
output_op_1(output_op_1) {
|
||||
|
||||
@ -163,7 +196,7 @@ struct B2bGemm {
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
B2bGemm() { }
|
||||
B2bGemm() { }
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(
|
||||
@ -223,7 +256,7 @@ struct B2bGemm {
|
||||
|
||||
if(problem_size_0.n() > B2bMma::Shape0::kN)
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
|
||||
if(problem_size_1.n() > B2bMma::Shape1::kN)
|
||||
return Status::kErrorInvalidProblem;
|
||||
|
||||
@ -247,37 +280,64 @@ struct B2bGemm {
|
||||
return;
|
||||
}
|
||||
|
||||
ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
|
||||
ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
|
||||
ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
|
||||
|
||||
ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
|
||||
ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
|
||||
|
||||
int offset_k_0 = 0;
|
||||
int offset_k_1 = 0;
|
||||
|
||||
int problem_size_k_0 = params.problem_size_0.k();
|
||||
int problem_size_k_1 = params.problem_size_1.k();
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm) {
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
problem_size_k_0 = min(
|
||||
problem_size_k_0,
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
problem_size_k_1 = min(
|
||||
problem_size_k_1,
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
|
||||
|
||||
offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
|
||||
offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
|
||||
}
|
||||
|
||||
else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
|
||||
ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
|
||||
ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
|
||||
ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
|
||||
ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
|
||||
}
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A0{
|
||||
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_0,
|
||||
offset_k_0,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B0{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_0,
|
||||
offset_k_0,
|
||||
threadblock_tile_offset.n() * B2bMma::Shape0::kN
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B1{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_1,
|
||||
offset_k_1,
|
||||
threadblock_tile_offset.n() * B2bMma::Shape1::kN
|
||||
};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k_0 = min(
|
||||
params.problem_size_0.k(),
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k_1 = min(
|
||||
params.problem_size_1.k(),
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
||||
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
||||
|
||||
|
||||
// Compute position within threadblock
|
||||
@ -286,26 +346,25 @@ struct B2bGemm {
|
||||
// Construct iterators to A and B operands
|
||||
typename B2bMma::IteratorA0 iterator_A0(
|
||||
params.params_A0,
|
||||
params.ref_A0.data(),
|
||||
ptr_A0,
|
||||
{params.problem_size_0.m(), problem_size_k_0},
|
||||
thread_idx,
|
||||
tb_offset_A0);
|
||||
|
||||
typename B2bMma::IteratorB0 iterator_B0(
|
||||
params.params_B0,
|
||||
params.ref_B0.data(),
|
||||
ptr_B0,
|
||||
{problem_size_k_0, params.problem_size_0.n()},
|
||||
thread_idx,
|
||||
tb_offset_B0);
|
||||
|
||||
typename B2bMma::IteratorB1 iterator_B1(
|
||||
params.params_B1,
|
||||
params.ref_B1.data(),
|
||||
ptr_B1,
|
||||
{problem_size_k_1, params.problem_size_1.n()},
|
||||
thread_idx,
|
||||
tb_offset_B1);
|
||||
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
@ -313,7 +372,7 @@ struct B2bGemm {
|
||||
|
||||
// Construct iterators to accumulator scale/bias vector
|
||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
|
||||
params.ref_Scale0.data(),
|
||||
ptr_Scale0,
|
||||
{1, params.problem_size_0.n()},
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
@ -323,7 +382,7 @@ struct B2bGemm {
|
||||
);
|
||||
|
||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
|
||||
params.ref_Bias0.data(),
|
||||
ptr_Bias0,
|
||||
{1, params.problem_size_0.n()},
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
@ -349,11 +408,9 @@ struct B2bGemm {
|
||||
src_accum.clear();
|
||||
accumulators.clear();
|
||||
|
||||
if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
||||
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
||||
}
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
||||
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
@ -376,23 +433,32 @@ struct B2bGemm {
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
|
||||
ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
if (params.mode == GemmUniversalMode::kGemm) {
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
if (params.grid_tiled_shape.k() > 1) {
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
}
|
||||
else if (params.mode == GemmUniversalMode::kBatched) {
|
||||
ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
|
||||
ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1;
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C1(
|
||||
params.params_C1,
|
||||
params.ref_C1.data(),
|
||||
ptr_C1,
|
||||
params.problem_size_1.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
@ -401,21 +467,21 @@ struct B2bGemm {
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D1(
|
||||
params.params_D1,
|
||||
params.ref_D1.data(),
|
||||
ptr_D1,
|
||||
params.problem_size_1.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_offset.k()) {
|
||||
iterator_C1 = iterator_D1;
|
||||
@ -427,14 +493,14 @@ struct B2bGemm {
|
||||
}
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
|
||||
|
||||
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
|
||||
|
||||
@ -457,4 +523,3 @@ struct B2bGemm {
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
|
@ -30,10 +30,10 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
@ -114,8 +114,6 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Stage accumulator in shared memory
|
||||
@ -161,22 +159,19 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
||||
Operator> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -188,7 +183,7 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
@ -228,8 +223,6 @@ template <
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
@ -249,7 +242,6 @@ struct DefaultB2bGemm<
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
SplitKSerial,
|
||||
Operator
|
||||
> {
|
||||
|
||||
@ -274,7 +266,7 @@ struct DefaultB2bGemm<
|
||||
Operator,
|
||||
EpilogueOutputOp0
|
||||
>::ThreadblockB2bMma;
|
||||
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
@ -287,7 +279,7 @@ struct DefaultB2bGemm<
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
@ -323,20 +315,16 @@ template <
|
||||
int Stages,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
||||
arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, Stages,
|
||||
SplitKSerial, Operator> {
|
||||
ThreadblockSwizzle, Stages, Operator> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -360,7 +348,7 @@ struct DefaultB2bGemm<
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -396,19 +384,16 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
kAlignmentA, ElementB,
|
||||
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, 2, SplitKSerial, Operator> {
|
||||
ThreadblockSwizzle, 2, Operator> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -418,7 +403,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
|
||||
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
|
||||
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -430,7 +415,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -30,10 +30,10 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
@ -112,22 +112,19 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
|
||||
layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
|
||||
arch::Sm80, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
|
||||
EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
|
||||
Operator, true> {
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -139,10 +136,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
|
||||
EpilogueOutputOp1::kCount>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Turing Architecture
|
||||
@ -179,8 +175,6 @@ template <
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
@ -200,7 +194,6 @@ struct DefaultB2bGemm<
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
SplitKSerial,
|
||||
Operator,
|
||||
true
|
||||
> {
|
||||
@ -228,7 +221,7 @@ struct DefaultB2bGemm<
|
||||
false,
|
||||
true
|
||||
>::ThreadblockB2bMma;
|
||||
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
@ -241,7 +234,7 @@ struct DefaultB2bGemm<
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
|
||||
@ -277,20 +270,17 @@ template <
|
||||
int Stages,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, layout::ColumnMajorInterleaved<InterleavedK>, kAlignmentA,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementB, layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>, int32_t,
|
||||
arch::OpClassTensorOp, arch::Sm80,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, Stages,
|
||||
SplitKSerial, Operator, true> {
|
||||
Operator, true> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -314,7 +304,7 @@ struct DefaultB2bGemm<
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -350,19 +340,16 @@ template <
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
kAlignmentA, ElementB,
|
||||
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, 2, SplitKSerial, Operator, true> {
|
||||
ThreadblockSwizzle, 2, Operator, true> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
@ -371,9 +358,9 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
@ -385,7 +372,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -69,7 +69,7 @@ __global__ void TensorScaleBiasGemm(
|
||||
TensorRefScalar tensor_scale, ///< scale tensor
|
||||
TensorRefScalar tensor_bias ///< bias tensor
|
||||
) {
|
||||
|
||||
|
||||
ConvertOp convert_op;
|
||||
|
||||
MatrixCoord output_coord(
|
||||
@ -89,7 +89,7 @@ __global__ void TensorScaleBiasGemm(
|
||||
|
||||
ScalarType bias = ScalarType(0);
|
||||
|
||||
if(tensor_bias.good())
|
||||
if(tensor_bias.good())
|
||||
bias = tensor_bias.at({0, coord.column()});
|
||||
|
||||
tensor_out.at(coord) = convert_op(
|
||||
@ -99,6 +99,70 @@ __global__ void TensorScaleBiasGemm(
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename TensorRefIn, ///< Input TensorRef Type
|
||||
typename TensorRefOut, ///< Output TensorRef Type
|
||||
typename ScalarType, ///< alpha Type
|
||||
typename TensorRefScalar, ///< Scale/Bias TensorRef Type
|
||||
typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
|
||||
int kMblock = 4,
|
||||
int kNblock = 4
|
||||
>
|
||||
__global__ void TensorScaleBiasGemmBatched(
|
||||
gemm::GemmCoord problem_size,
|
||||
TensorRefIn tensor_in, ///< input tensor
|
||||
TensorRefOut tensor_out, ///< output tensor
|
||||
ScalarType alpha, ///< alpha
|
||||
TensorRefScalar tensor_scale, ///< scale tensor
|
||||
TensorRefScalar tensor_bias, ///< bias tensor
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_tensor_in = 0,
|
||||
int64_t batch_stride_tensor_out = 0,
|
||||
int64_t batch_stride_tensor_scale = 0,
|
||||
int64_t batch_stride_tensor_bias = 0
|
||||
) {
|
||||
|
||||
ConvertOp convert_op;
|
||||
int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
|
||||
int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
|
||||
int batch_idx = blockIdx.z;
|
||||
|
||||
tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in);
|
||||
tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out);
|
||||
tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale);
|
||||
tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias);
|
||||
|
||||
for (; batch_idx < batch_count; batch_idx += gridDim.z) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < kNblock; j++) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kMblock; i++) {
|
||||
int row = row_block + i;
|
||||
int col = col_block + j;
|
||||
MatrixCoord coord = MatrixCoord(row, col);
|
||||
if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
|
||||
|
||||
ScalarType scale = alpha;
|
||||
if(tensor_scale.good())
|
||||
scale = tensor_scale.at({0, coord.column()});
|
||||
|
||||
ScalarType bias = ScalarType(0);
|
||||
|
||||
if(tensor_bias.good())
|
||||
bias = tensor_bias.at({0, coord.column()});
|
||||
|
||||
tensor_out.at(coord) = convert_op(
|
||||
scale * ScalarType(tensor_in.at(coord)) + bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z);
|
||||
tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z);
|
||||
tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z);
|
||||
tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z);
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename TensorRefIn, ///< Input TensorRef Type
|
||||
typename TensorRefOut, ///< Output TensorRef Type
|
||||
@ -118,7 +182,7 @@ __global__ void TensorScaleBiasConv2d(
|
||||
TensorRefScalar tensor_scale, ///< scale tensor
|
||||
TensorRefScalar tensor_bias ///< bias tensor
|
||||
) {
|
||||
|
||||
|
||||
ConvertOp convert_op;
|
||||
|
||||
int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
|
||||
@ -137,7 +201,7 @@ __global__ void TensorScaleBiasConv2d(
|
||||
int64_t npq = npq_start + m;
|
||||
|
||||
thread_n[m] = int(npq / PQ);
|
||||
|
||||
|
||||
int64_t residual = npq % PQ;
|
||||
thread_p[m] = int(residual / problem_size.Q);
|
||||
thread_q[m] = int(residual % problem_size.Q);
|
||||
@ -155,17 +219,17 @@ __global__ void TensorScaleBiasConv2d(
|
||||
ScalarType scale = alpha;
|
||||
if(tensor_scale.good())
|
||||
scale = tensor_scale.at({0, thread_k});
|
||||
|
||||
|
||||
ScalarType bias = ScalarType(0);
|
||||
if(tensor_bias.good())
|
||||
if(tensor_bias.good())
|
||||
bias = tensor_bias.at({0, thread_k});
|
||||
|
||||
|
||||
tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op(
|
||||
scale * ScalarType(
|
||||
tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})
|
||||
) + bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -217,6 +281,62 @@ void TensorScaleBiasGemm(
|
||||
);
|
||||
}
|
||||
|
||||
/// Apply scale and bias on a tensor
|
||||
template <
|
||||
typename ElementIn, ///< Input Type
|
||||
typename ElementOut, ///< Output Type
|
||||
typename Layout, ///< Layout of input/output tensor
|
||||
typename ScalarType, ///< alpha Type
|
||||
typename LayoutScaleBias, ///< Layout of scale and bias
|
||||
typename ConvertOp = NumericConverter<ElementOut, ScalarType>
|
||||
>
|
||||
void TensorScaleBiasGemmBatched(
|
||||
gemm::GemmCoord problem_size,
|
||||
TensorRef<ElementIn, Layout> tensor_in, ///< input tensor
|
||||
TensorRef<ElementOut, Layout> tensor_out, ///< output tensor
|
||||
ScalarType alpha, ///< alpha
|
||||
TensorRef<ScalarType, LayoutScaleBias> tensor_scale, ///< scale tensor
|
||||
TensorRef<ScalarType, LayoutScaleBias> tensor_bias, ///< bias tensor
|
||||
int batch_count = 1,
|
||||
int64_t batch_stride_tensor_in = 0,
|
||||
int64_t batch_stride_tensor_out = 0,
|
||||
int64_t batch_stride_tensor_scale = 0,
|
||||
int64_t batch_stride_tensor_bias = 0
|
||||
) {
|
||||
|
||||
int const kMblock = 4;
|
||||
int const kNblock = 4;
|
||||
|
||||
dim3 block(16, 8);
|
||||
dim3 grid(
|
||||
(problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
|
||||
(problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
|
||||
batch_count % std::numeric_limits<uint16_t>::max()
|
||||
);
|
||||
|
||||
kernel::TensorScaleBiasGemmBatched<
|
||||
TensorRef<ElementIn, Layout>,
|
||||
TensorRef<ElementOut, Layout>,
|
||||
ScalarType,
|
||||
TensorRef<ScalarType, LayoutScaleBias>,
|
||||
ConvertOp,
|
||||
kMblock,
|
||||
kNblock
|
||||
><<< grid, block >>> (
|
||||
problem_size,
|
||||
tensor_in,
|
||||
tensor_out,
|
||||
alpha,
|
||||
tensor_scale,
|
||||
tensor_bias,
|
||||
batch_count,
|
||||
batch_stride_tensor_in,
|
||||
batch_stride_tensor_out,
|
||||
batch_stride_tensor_scale,
|
||||
batch_stride_tensor_bias
|
||||
);
|
||||
}
|
||||
|
||||
/// Apply scale and bias on a tensor
|
||||
template <
|
||||
typename ElementIn, ///< Input Type
|
||||
|
Loading…
Reference in New Issue
Block a user