added support of b2b bmm (#849)

* added support of b2b bmm
* fixed the Arguments and Params structures
* added a batch_count argument
* removed SplitKSerial and added a new test case with b2b bmm
* fixed support of kBatched mode and added a new test case with batch strides
* added batch support for bias and scale
* make test
* small changes

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>

Parent: d572cc1aab
Commit: 4a68cf748e
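This commit lets the fused back-to-back GEMM run under cutlass::gemm::GemmUniversalMode::kBatched: the test-harness run() methods and B2bGemm::Arguments gain a mode, a batch_count (which keeps its old split-K meaning when the mode is kGemm), and per-operand batch strides. The sketch below is modeled on the run_fused_gemm_s8_sm80_rf_res_batch test added further down; fusedGemm stands for an already-constructed B2bInterleavedFusedGemmRun instance from that test, and the stride expressions are one plausible choice for densely packed batches rather than the only valid one.

  // Hedged sketch: driving the fused B2B GEMM in batched mode (assumes the
  // B2bGemm instantiation and the fusedGemm runner from the test file in this commit).
  cutlass::gemm::GemmCoord problem_size_0(256, 64, 128);   // GEMM0: M x N0 x K0
  cutlass::gemm::GemmCoord problem_size_1(256, 128, 64);   // GEMM1: M x N1 x K1, with K1 == N0

  float alpha0 = 1, beta0 = 0, alpha1 = 1, beta1 = 1;
  int batch_count = 2;

  // Element strides between consecutive batches of each operand.
  int64_t batch_stride_A0    = problem_size_0.m() * problem_size_0.k();
  int64_t batch_stride_B0    = problem_size_0.k() * problem_size_0.n();
  int64_t batch_stride_C0    = problem_size_0.m() * problem_size_0.n();
  int64_t batch_stride_B1    = problem_size_1.k() * problem_size_1.n();
  int64_t batch_stride_C1    = problem_size_1.n();   // C1 is consumed as a per-channel bias vector
  int64_t batch_stride_D1    = problem_size_1.m() * problem_size_1.n();
  int64_t batch_stride_Bias0 = problem_size_0.n();
  int64_t batch_stride_Scale0 = 0;                   // unused unless per-channel scaling is enabled

  bool passed = fusedGemm.run(
    problem_size_0, problem_size_1,
    alpha0, beta0, alpha1, beta1,
    cutlass::gemm::GemmUniversalMode::kBatched,
    batch_count,
    batch_stride_A0, batch_stride_B0, batch_stride_C0,
    batch_stride_B1, batch_stride_C1, batch_stride_D1,
    batch_stride_Bias0, batch_stride_Scale0);

Leaving the mode at its default (kGemm) preserves the old behavior, where batch_count is interpreted as the number of split-K slices.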
@@ -42,6 +42,7 @@
 #include "cutlass/util/reference/host/tensor_compare.h"
 #include "cutlass/util/reference/host/tensor_norm.h"
 #include "cutlass/util/reference/device/gemm.h"
+#include "cutlass/util/reference/device/gemm_complex.h"
 #include "cutlass/util/reference/device/tensor_relu.h"

 #include "reference/device/tensor_scale_bias.h"

@@ -459,6 +460,20 @@ struct B2bFusedGemmRun
    ElementCompute beta0 = ElementCompute(0),
    ElementCompute alpha1 = ElementCompute(1),
    ElementCompute beta1 = ElementCompute(0),
+    cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
+
+    // batch_count is used as split-k when mode is kGemm according
+    // to the GemmUniversal interface
+
+    int batch_count = 1,
+    int64_t batch_stride_A0 = 0,
+    int64_t batch_stride_B0 = 0,
+    int64_t batch_stride_C0 = 0,
+    int64_t batch_stride_B1 = 0,
+    int64_t batch_stride_C1 = 0,
+    int64_t batch_stride_D1 = 0,
+    int64_t batch_stride_Bias0 = 0,
+    int64_t batch_stride_Scale0 = 0,
    bool relu = true,
    int warm_ups = 1,
    int runs = 100) {

@@ -467,56 +482,62 @@ struct B2bFusedGemmRun
    // Allocate the GEMM workspace
    //

+    cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
+    cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
+    cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
+    cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
+    cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());

    cutlass::HostTensor<
      typename B2bGemm::ElementA,
-      typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
+      typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());

    cutlass::HostTensor<
      typename B2bGemm::ElementB,
-      typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
+      typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
-      typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
+      typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());

    cutlass::HostTensor<
      typename B2bGemm::ElementScaleBias,
      typename B2bGemm::LayoutScaleBias> tensor_Scale0;

    if(alpha0 == ElementCompute(0)) //per-channel scale
-      tensor_Scale0.resize({1, problem_size_0.n()});
+      tensor_Scale0.resize({1, batch_count * problem_size_0.n()});

    cutlass::HostTensor<
      typename B2bGemm::ElementScaleBias,
-      typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
+      typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});

    cutlass::HostTensor<
      ElementAccumulator,
-      typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
+      typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
-      typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
+      typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());

    cutlass::HostTensor<
      typename B2bGemm::ElementB,
-      typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
+      typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
-      typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
+      typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());

-    cutlass::HostTensor<
-      ElementCompute,
-      typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
-
    cutlass::HostTensor<
      typename B2bGemm::ElementC,
-      typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
+      typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
-      typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
+      typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());
+
+    cutlass::HostTensor<
+      typename B2bGemm::ElementC,
+      typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());

    CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));

@@ -554,6 +575,7 @@ struct B2bFusedGemmRun
    //

    typename B2bGemm::Arguments arguments{
+      mode,
      problem_size_0,
      problem_size_1,
      tensor_A0.device_ref(),

@@ -564,8 +586,16 @@ struct B2bFusedGemmRun
      tensor_B1.device_ref(),
      {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
      tensor_D1.device_ref(),
+      batch_stride_A0,
+      batch_stride_B0,
+      batch_stride_B1,
+      batch_stride_C1,
+      batch_stride_D1,
+      batch_stride_Bias0,
+      batch_stride_Scale0,
      {alpha0, beta0},
      {alpha1, beta1},
+      batch_count,
    };

    B2bGemm b2b_gemm_op;

@@ -618,32 +648,31 @@ struct B2bFusedGemmRun
    // Verify
    //

-    cutlass::reference::device::Gemm<
+    cutlass::reference::device::GemmComplex<
      typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
      typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
      ElementAccumulator, typename B2bGemm::LayoutC,
-      ElementAccumulator, ElementAccumulator>
-      reference_gemm_0;
-
-    cutlass::reference::device::Gemm<
-      typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
-      typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
-      typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
-      ElementAccumulator, typename B2bGemm::Operator>
-      reference_gemm_1;
-
-    reference_gemm_0(
+      ElementAccumulator, ElementAccumulator
+    >(
      problem_size_0,
      ElementAccumulator(1), //intermediate alpha=1
      tensor_A0.device_ref(),
+      cutlass::ComplexTransform::kNone,
      tensor_B0.device_ref(),
+      cutlass::ComplexTransform::kNone,
      ElementAccumulator(0), //beta = 0
      reference_Z0.device_ref(),
      reference_Z0.device_ref(),
-      ElementAccumulator(0)
+      ElementAccumulator(0),
+      int(batch_count),
+      batch_stride_A0,
+      batch_stride_B0,
+      batch_stride_C0,
+      batch_stride_C0
    );

-    cutlass::reference::device::TensorScaleBiasGemm<
+    cutlass::reference::device::TensorScaleBiasGemmBatched<
      ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
      ElementCompute, typename B2bGemm::LayoutScaleBias
    > (

@@ -652,25 +681,45 @@ struct B2bFusedGemmRun
      reference_D0.device_ref(),
      alpha0,
      tensor_Scale0.device_ref(),
-      tensor_Bias0.device_ref()
+      tensor_Bias0.device_ref(),
+      int(batch_count),
+      batch_stride_C0,
+      batch_stride_C0,
+      batch_stride_Scale0,
+      batch_stride_Bias0
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D0.device_view());
    }

-    reference_gemm_1(
+    cutlass::reference::device::GemmComplex<
+      typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
+      typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
+      typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
+      ElementCompute, ElementAccumulator
+    >(
      problem_size_1,
-      alpha1,
+      alpha1, //intermediate alpha=1
      reference_D0.device_ref(),
+      cutlass::ComplexTransform::kNone,
      tensor_B1.device_ref(),
-      beta1,
+      cutlass::ComplexTransform::kNone,
+      beta1, //beta = 0
      {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
-      reference_D1.device_ref()
+      reference_D1.device_ref(),
+      ElementAccumulator(0),
+      int(batch_count),
+      batch_stride_C0,
+      batch_stride_B1,
+      batch_stride_C1,
+      batch_stride_D1
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D1.device_view());
    }

    cudaDeviceSynchronize();
    reference_D0.sync_host();
    reference_D1.sync_host();
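Both test harnesses in this commit pack every batch into a single host tensor: the A and B operands are stacked along K (batch_count * K), while the C/D/Z tensors and the per-channel scale/bias vectors are stacked along N (batch_count * N), and verification switches to the batched device references (GemmComplex and TensorScaleBiasGemmBatched), which take batch_count and the batch strides directly. Below is a small, self-contained sketch of that extent bookkeeping; the helper names are illustrative and only the arithmetic mirrors the CoordA0/CoordC0 lines above.

  #include "cutlass/gemm_coord.h"

  // Extent of the packed host tensor holding all batches of an A-like operand:
  // batches are stacked along K, so the tensor is allocated as (M, batch_count * K).
  inline cutlass::gemm::GemmCoord packed_a_extent(cutlass::gemm::GemmCoord p, int batch_count) {
    return cutlass::gemm::GemmCoord(p.m(), p.n(), batch_count * p.k());
  }

  // Extent of the packed host tensor holding all batches of a C/D-like operand:
  // batches are stacked along N, so the tensor is allocated as (M, batch_count * N).
  inline cutlass::gemm::GemmCoord packed_c_extent(cutlass::gemm::GemmCoord p, int batch_count) {
    return cutlass::gemm::GemmCoord(p.m(), batch_count * p.n(), p.k());
  }

With this layout a batch stride of M*K (or M*N) elements selects one slice of the packed tensor, which is exactly what the new batch_stride_* arguments describe.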
@@ -43,6 +43,7 @@
 #include "cutlass/util/reference/host/tensor_norm.h"
 #include "cutlass/util/host_reorder.h"
 #include "cutlass/util/reference/device/gemm.h"
+#include "cutlass/util/reference/device/gemm_complex.h"
 #include "cutlass/util/reference/device/tensor_relu.h"

 #include "reference/device/tensor_scale_bias.h"

@@ -194,7 +195,6 @@ struct B2bInterleavedNonFusedGemmRun
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());

-
    CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
    CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
    CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));

@@ -476,6 +476,21 @@ struct B2bInterleavedFusedGemmRun
    ElementCompute beta0 = ElementCompute(0),
    ElementCompute alpha1 = ElementCompute(1),
    ElementCompute beta1 = ElementCompute(0),
+    cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
+
+    // batch_count is used as split-k when mode is kGemm according
+    // to the GemmUniversal interface
+
+    int batch_count = 1,
+
+    int64_t batch_stride_A0 = 0,
+    int64_t batch_stride_B0 = 0,
+    int64_t batch_stride_C0 = 0,
+    int64_t batch_stride_B1 = 0,
+    int64_t batch_stride_C1 = 0,
+    int64_t batch_stride_D1 = 0,
+    int64_t batch_stride_Bias0 = 0,
+    int64_t batch_stride_Scale0 = 0,
    bool relu = true,
    int warm_ups = 1,
    int runs = 100) {

@@ -484,64 +499,70 @@ struct B2bInterleavedFusedGemmRun
    // Allocate the GEMM workspace
    //

+    cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
+    cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k());
+    cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k());
+    cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k());
+    cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k());

    cutlass::HostTensor<
      typename B2bGemm::ElementA,
-      typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
+      typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk());

    cutlass::HostTensor<
      typename B2bGemm::ElementB,
-      typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
+      typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn());

    cutlass::HostTensor<
      typename B2bGemm::ElementB,
-      typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn());
+      typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn());

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
-      typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
+      typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn());

    cutlass::HostTensor<
      typename B2bGemm::ElementScaleBias,
      typename B2bGemm::LayoutScaleBias> tensor_Scale0;

    if(alpha0 == ElementCompute(0)) //per-channel scale
-      tensor_Scale0.resize({1, problem_size_0.n()});
+      tensor_Scale0.resize({1, batch_count * problem_size_0.n()});

    cutlass::HostTensor<
      typename B2bGemm::ElementScaleBias,
-      typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()});
+      typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()});

    cutlass::HostTensor<
      ElementAccumulator,
-      typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn());
+      typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn());

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
-      typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
+      typename B2bGemm::LayoutC> reference_D0(CoordC0.mn());

    cutlass::HostTensor<
      typename B2bGemm::ElementB,
-      typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
+      typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn());

    cutlass::HostTensor<
      typename B2bGemm::ElementB,
-      typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn());
+      typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn());

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
-      typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
+      typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn());

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
-      typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()});
+      typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()});

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
-      typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
+      typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn());

    cutlass::HostTensor<
      typename B2bGemm::ElementC,
-      typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
+      typename B2bGemm::LayoutC> reference_D1(CoordC1.mn());

    CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));

@@ -556,9 +577,9 @@ struct B2bInterleavedFusedGemmRun

    //Reorder B0
    cutlass::reorder_column<16>(
-      tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
+      tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0);
    cutlass::reorder_column<InterleavedK_>(
-      tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
+      tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1);

    cutlass::reference::host::TensorFill(
      tensor_D1.host_view());

@@ -581,12 +602,14 @@ struct B2bInterleavedFusedGemmRun
    tensor_D1.sync_device();
    reference_D0.sync_device();
    reference_D1.sync_device();
+    // tensor_Bias0_batched.sync_device();

    //
    // Initialize the GEMM operator
    //

    typename B2bGemm::Arguments arguments{
+      mode,
      problem_size_0,
      problem_size_1,
      tensor_A0.device_ref(),

@@ -597,8 +620,16 @@ struct B2bInterleavedFusedGemmRun
      tensor_B1_reordered.device_ref(),
      {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
      tensor_D1.device_ref(),
+      batch_stride_A0,
+      batch_stride_B0,
+      batch_stride_B1,
+      batch_stride_C1,
+      batch_stride_D1,
+      batch_stride_Bias0,
+      batch_stride_Scale0,
      {alpha0, beta0},
      {alpha1, beta1},
+      batch_count,
    };

    B2bGemm b2b_gemm_op;

@@ -651,32 +682,30 @@ struct B2bInterleavedFusedGemmRun
    // Verify
    //

-    cutlass::reference::device::Gemm<
+    cutlass::reference::device::GemmComplex<
      typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
      typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
      ElementAccumulator, typename B2bGemm::LayoutC,
-      ElementAccumulator, ElementAccumulator>
-      reference_gemm_0;
-
-    cutlass::reference::device::Gemm<
-      typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
-      typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
-      typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
-      ElementAccumulator, typename B2bGemm::Operator>
-      reference_gemm_1;
-
-    reference_gemm_0(
+      ElementAccumulator, ElementAccumulator
+    >(
      problem_size_0,
      ElementAccumulator(1), //intermediate alpha=1
      tensor_A0.device_ref(),
+      cutlass::ComplexTransform::kNone,
      tensor_B0.device_ref(),
+      cutlass::ComplexTransform::kNone,
      ElementAccumulator(0), //beta = 0
      reference_Z0.device_ref(),
      reference_Z0.device_ref(),
-      ElementAccumulator(0)
+      ElementAccumulator(0),
+      int(batch_count),
+      batch_stride_A0,
+      batch_stride_B0,
+      batch_stride_C0,
+      batch_stride_C0
    );

-    cutlass::reference::device::TensorScaleBiasGemm<
+    cutlass::reference::device::TensorScaleBiasGemmBatched<
      ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
      ElementCompute, typename B2bGemm::LayoutScaleBias
    > (

@@ -685,25 +714,45 @@ struct B2bInterleavedFusedGemmRun
      reference_D0.device_ref(),
      alpha0,
      tensor_Scale0.device_ref(),
-      tensor_Bias0.device_ref()
+      tensor_Bias0.device_ref(),
+      int(batch_count),
+      batch_stride_C0,
+      batch_stride_C0,
+      batch_stride_Scale0,
+      batch_stride_Bias0
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D0.device_view());
    }

-    reference_gemm_1(
+    cutlass::reference::device::GemmComplex<
+      typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
+      typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
+      typename B2bGemm::ElementC, typename B2bGemm::LayoutC,
+      ElementCompute, ElementAccumulator
+    >(
      problem_size_1,
-      alpha1,
+      alpha1, //intermediate alpha=1
      reference_D0.device_ref(),
+      cutlass::ComplexTransform::kNone,
      tensor_B1.device_ref(),
-      beta1,
+      cutlass::ComplexTransform::kNone,
+      beta1, //beta = 0
      {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)},
-      reference_D1.device_ref()
+      reference_D1.device_ref(),
+      ElementAccumulator(0),
+      int(batch_count),
+      batch_stride_C0,
+      batch_stride_B1,
+      batch_stride_C1,
+      batch_stride_D1
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D1.device_view());
    }

    cudaDeviceSynchronize();
    reference_D0.sync_host();
    reference_D1.sync_host();

@@ -119,8 +119,6 @@ template <
  int AlignmentB =
    DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                             ElementC_, ElementAccumulator_>::kAlignmentB,
-  /// If true, kernel supports split-K with serial reduction
-  bool SplitKSerial = false,
  /// Operation performed by GEMM
  typename Operator_ = typename DefaultGemmConfiguration<
    OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

@@ -154,7 +152,6 @@ class B2bGemm {
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp1::kCount;
-  static bool const kSplitKSerial = SplitKSerial;
  static ComplexTransform const kTransformA = ComplexTransform::kNone;
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

@@ -184,7 +181,6 @@ class B2bGemm {
    EpilogueOutputOp1,
    ThreadblockSwizzle,
    kStages,
-    kSplitKSerial,
    Operator,
    SmemAccumulator
  >::B2bGemmKernel;

@@ -196,6 +192,7 @@ class B2bGemm {
    // Data members
    //

+    GemmUniversalMode mode;
    GemmCoord problem_size_0;
    GemmCoord problem_size_1;
    TensorRef<ElementA const, LayoutA> ref_A0;

@@ -206,9 +203,16 @@ class B2bGemm {
    TensorRef<ElementB const, LayoutB> ref_B1;
    TensorRef<ElementC const, LayoutC> ref_C1;
    TensorRef<ElementC, LayoutC> ref_D1;
+    int64_t batch_stride_A0;
+    int64_t batch_stride_B0;
+    int64_t batch_stride_B1;
+    int64_t batch_stride_C1;
+    int64_t batch_stride_D1;
+    int64_t batch_stride_Bias0;
+    int64_t batch_stride_Scale0;
    typename EpilogueOutputOp0::Params epilogue0;
    typename EpilogueOutputOp1::Params epilogue1;
-    int split_k_slices;
+    int batch_count;

    //
    // Methods

@@ -216,13 +220,14 @@ class B2bGemm {

    /// Default ctor
    CUTLASS_HOST_DEVICE
-    Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) {
+    Arguments(): mode(mode), problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), batch_count(1) {

    }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
+      GemmUniversalMode mode_,
      GemmCoord problem_size_0_,
      GemmCoord problem_size_1_,
      TensorRef<ElementA const, LayoutA> ref_A0_,

@@ -233,12 +238,20 @@ class B2bGemm {
      TensorRef<ElementB const, LayoutB> ref_B1_,
      TensorRef<ElementC const, LayoutC> ref_C1_,
      TensorRef<ElementC, LayoutC> ref_D1_,
+      int64_t batch_stride_A0_,
+      int64_t batch_stride_B0_,
+      int64_t batch_stride_B1_,
+      int64_t batch_stride_C1_,
+      int64_t batch_stride_D1_,
+      int64_t batch_stride_Bias0_,
+      int64_t batch_stride_Scale0_,
      typename EpilogueOutputOp0::Params epilogue0_ =
        typename EpilogueOutputOp0::Params(),
      typename EpilogueOutputOp1::Params epilogue1_ =
        typename EpilogueOutputOp1::Params(),
-      int split_k_slices_ = 1
+      int batch_count_ = 1
    ):
+      mode(mode_),
      problem_size_0(problem_size_0_),
      problem_size_1(problem_size_1_),
      ref_A0(ref_A0_),

@@ -249,9 +262,16 @@ class B2bGemm {
      ref_B1(ref_B1_),
      ref_C1(ref_C1_),
      ref_D1(ref_D1_),
+      batch_stride_A0(batch_stride_A0_),
+      batch_stride_B0(batch_stride_B0_),
+      batch_stride_B1(batch_stride_B1_),
+      batch_stride_C1(batch_stride_C1_),
+      batch_stride_D1(batch_stride_D1_),
+      batch_stride_Bias0(batch_stride_Bias0_),
+      batch_stride_Scale0(batch_stride_Scale0_),
      epilogue0(epilogue0_),
      epilogue1(epilogue1_),
-      split_k_slices(split_k_slices_) {
+      batch_count(batch_count_) {

    }
  };

@@ -269,10 +289,6 @@ public:
  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

-    if (!kSplitKSerial && args.split_k_slices > 1) {
-      return Status::kErrorInvalidProblem;
-    }
-
    Status status = B2bGemmKernel::can_implement(
      args.problem_size_0,
      args.problem_size_1,

@@ -302,13 +318,7 @@ public:
    cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size_0,
      {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
-      args.split_k_slices);
+      args.batch_count);

-    if (kSplitKSerial && args.split_k_slices > 1) {
-
-      bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
-    }
-
    return bytes;
  }

@@ -322,36 +332,15 @@ public:
    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size_0,
      {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
-      args.split_k_slices);
+      args.batch_count);
    // cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
    //   args.problem_size_1,
    //   {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
-    //   args.split_k_slices);
-
-    if (kSplitKSerial) {
-      if (args.split_k_slices > 1) {
-        if (!workspace) {
-          return Status::kErrorWorkspaceNull;
-        }
-
-        size_t bytes = get_workspace_size(args);
-
-        cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
-
-        if (result != cudaSuccess) {
-          return Status::kErrorInternal;
-        }
-      }
-    }
-    else {
-
-      if (args.split_k_slices > 1) {
-        return Status::kErrorInvalidProblem;
-      }
-    }
+    //   args.batch_count);

    // Initialize the Params structure
    params_ = typename B2bGemmKernel::Params{
+      args.mode,
      args.problem_size_0,
      args.problem_size_1,
      grid_shape,

@@ -363,6 +352,13 @@ public:
      args.ref_B1.non_const_ref(),
      args.ref_C1.non_const_ref(),
      args.ref_D1,
+      args.batch_stride_A0,
+      args.batch_stride_B0,
+      args.batch_stride_B1,
+      args.batch_stride_C1,
+      args.batch_stride_D1,
+      args.batch_stride_Bias0,
+      args.batch_stride_Scale0,
      args.epilogue0,
      args.epilogue1,
      static_cast<int *>(workspace),

@@ -374,12 +370,6 @@ public:
  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

-    if (kSplitKSerial && args.split_k_slices > 1) {
-      if (!workspace) {
-        return Status::kErrorWorkspaceNull;
-      }
-    }
-
    params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
    params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
    params_.ref_C0.reset(args.ref_C0.non_const_ref().data());

@@ -194,14 +194,21 @@ bool run_fused_gemm_s8_sm80_rf_res() {
    SmemAccumulator,
    16,
    16,
-    false,
    cutlass::arch::OpMultiplyAddSaturate
  >;

  B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;

  std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n";
-  bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1);
+  bool passed = fusedGemm.run(
+    gemm_s8_sm80_problem_size_0,
+    gemm_s8_sm80_problem_size_1,
+    alpha0,
+    beta0,
+    alpha1,
+    beta1
+  );

  if(passed)
    std::cout << "Pass\n";
  else

@@ -210,18 +217,123 @@ bool run_fused_gemm_s8_sm80_rf_res() {
  return passed;
 }

+bool run_fused_gemm_s8_sm80_rf_res_batch() {
+
+  cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(256, 64, 128);
+  cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(256, 128, 64);
+
+  using ElementOutput = int8_t;
+  using ElementAccumulator = int32_t;
+  using ElementCompute = float;
+
+  ElementCompute alpha0 = ElementCompute(1);
+  //Fused kernel has built-in bias, setting beta=0
+  ElementCompute beta0 = ElementCompute(0);
+  ElementCompute alpha1 = ElementCompute(1);
+  ElementCompute beta1 = ElementCompute(1); //beta=1 for bias
+
+  using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
+  using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>;
+  using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>;
+
+  using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+
+  using EpilogueOutputOp0 =
+    cutlass::epilogue::thread::LinearCombinationRelu<
+      ElementOutput,
+      8 * InstructionShape::kN / 32,
+      ElementAccumulator,
+      ElementCompute,
+      cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
+    >;
+
+  using EpilogueOutputOp1 =
+    cutlass::epilogue::thread::LinearCombinationRelu<
+      ElementOutput,
+      64 / cutlass::sizeof_bits<ElementOutput>::value,
+      ElementAccumulator,
+      ElementCompute,
+      cutlass::epilogue::thread::ScaleType::NoBetaScaling
+    >;
+
+  const bool SmemAccumulator = false;
+
+  using B2bGemm = cutlass::gemm::device::B2bGemm<
+    int8_t,
+    cutlass::layout::ColumnMajorInterleaved<32>,
+    int8_t,
+    cutlass::layout::RowMajorInterleaved<32>,
+    ElementOutput,
+    cutlass::layout::ColumnMajorInterleaved<32>,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm80,
+    ThreadblockShape0,
+    ThreadblockShape1,
+    WarpShape0,
+    WarpShape1,
+    InstructionShape,
+    EpilogueOutputOp0,
+    EpilogueOutputOp1,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    3,
+    SmemAccumulator,
+    16,
+    16,
+    cutlass::arch::OpMultiplyAddSaturate
+  >;
+
+  B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
+
+  int batch_count = 2;
+  int64_t batch_stride_A0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.k();
+  int64_t batch_stride_B0 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
+  int64_t batch_stride_C0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.n();
+  int64_t batch_stride_B1 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n();
+  int64_t batch_stride_C1 = gemm_s8_sm80_problem_size_1.n();
+  int64_t batch_stride_D1 = gemm_s8_sm80_problem_size_1.m() * gemm_s8_sm80_problem_size_1.n();
+  int64_t batch_stride_Bias0 = gemm_s8_sm80_problem_size_0.n();
+  int64_t batch_stride_Scale0 = 0;
+
+  std::cout << "Running Fused back-to-back INT8 NT interleaved Batched GEMMs with RF residency...\n";
+  bool passed = fusedGemm.run(
+    gemm_s8_sm80_problem_size_0,
+    gemm_s8_sm80_problem_size_1,
+    alpha0,
+    beta0,
+    alpha1,
+    beta1,
+    cutlass::gemm::GemmUniversalMode::kBatched,
+    batch_count,
+    batch_stride_A0,
+    batch_stride_B0,
+    batch_stride_C0,
+    batch_stride_B1,
+    batch_stride_C1,
+    batch_stride_D1,
+    batch_stride_Bias0,
+    batch_stride_Scale0
+  );
+
+  if(passed)
+    std::cout << "Pass\n";
+  else
+    std::cout << "Fail\n";
+
+  return passed;
+}
+
 int main() {

  std::vector<bool (*)()>funcs = {
    &run_nonfused_gemm_s8_sm80,
-    &run_fused_gemm_s8_sm80_rf_res
+    &run_fused_gemm_s8_sm80_rf_res,
+    &run_fused_gemm_s8_sm80_rf_res_batch
  };

  return testRun(80, funcs, "gemm int8 RF residency");

 }

 ////////////////////////////////////////////////////////////////////////////////

@ -193,7 +193,6 @@ bool run_fused_gemm_s8_sm80_shmem() {
|
|||||||
SmemAccumulator,
|
SmemAccumulator,
|
||||||
16,
|
16,
|
||||||
16,
|
16,
|
||||||
false,
|
|
||||||
cutlass::arch::OpMultiplyAddSaturate
|
cutlass::arch::OpMultiplyAddSaturate
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
|||||||
@ -51,8 +51,7 @@ namespace kernel {
|
|||||||
template <
|
template <
|
||||||
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||||
typename Epilogue_, ///! Epilogue
|
typename Epilogue_, ///! Epilogue
|
||||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
|
||||||
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
|
|
||||||
>
|
>
|
||||||
struct B2bGemm {
|
struct B2bGemm {
|
||||||
|
|
||||||
@ -61,7 +60,17 @@ struct B2bGemm {
|
|||||||
using OutputOp0 = typename B2bMma::OutputOp;
|
using OutputOp0 = typename B2bMma::OutputOp;
|
||||||
using OutputOp1 = typename Epilogue::OutputOp;
|
using OutputOp1 = typename Epilogue::OutputOp;
|
||||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||||
static bool const kSplitKSerial = SplitKSerial;
|
|
||||||
|
using ElementA0 = typename B2bMma::IteratorA0::Element;
|
||||||
|
using LayoutA0 = typename B2bMma::IteratorA0::Layout;
|
||||||
|
using ElementB0 = typename B2bMma::IteratorB0::Element;
|
||||||
|
using LayoutB0 = typename B2bMma::IteratorB0::Layout;
|
||||||
|
using ElementB1 = typename B2bMma::IteratorB1::Element;
|
||||||
|
using LayoutB1 = typename B2bMma::IteratorB1::Layout;
|
||||||
|
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||||
|
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||||
|
|
||||||
|
using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element;
|
||||||
|
|
||||||
/// Warp count (concept: GemmShape)
|
/// Warp count (concept: GemmShape)
|
||||||
using WarpCount0 = typename B2bMma::WarpCount0;
|
using WarpCount0 = typename B2bMma::WarpCount0;
|
||||||
@ -69,6 +78,7 @@ struct B2bGemm {
|
|||||||
|
|
||||||
/// Parameters structure
|
/// Parameters structure
|
||||||
struct Params {
|
struct Params {
|
||||||
|
cutlass::gemm::GemmUniversalMode mode;
|
||||||
cutlass::gemm::GemmCoord problem_size_0;
|
cutlass::gemm::GemmCoord problem_size_0;
|
||||||
cutlass::gemm::GemmCoord problem_size_1;
|
cutlass::gemm::GemmCoord problem_size_1;
|
||||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||||
@ -89,6 +99,13 @@ struct B2bGemm {
|
|||||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
|
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
|
||||||
typename OutputOp0::Params output_op_0;
|
typename OutputOp0::Params output_op_0;
|
||||||
typename OutputOp1::Params output_op_1;
|
typename OutputOp1::Params output_op_1;
|
||||||
|
int64_t batch_stride_A0;
|
||||||
|
int64_t batch_stride_B0;
|
||||||
|
int64_t batch_stride_B1;
|
||||||
|
int64_t batch_stride_C1;
|
||||||
|
int64_t batch_stride_D1;
|
||||||
|
int64_t batch_stride_Bias0;
|
||||||
|
int64_t batch_stride_Scale0;
|
||||||
int *semaphore;
|
int *semaphore;
|
||||||
int gemm_k_iterations_0;
|
int gemm_k_iterations_0;
|
||||||
int gemm_k_size_0;
|
int gemm_k_size_0;
|
||||||
@ -100,11 +117,12 @@ struct B2bGemm {
|
|||||||
//
|
//
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
|
Params(): mode(mode), swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
|
||||||
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
|
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
|
||||||
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
Params(
|
Params(
|
||||||
|
cutlass::gemm::GemmUniversalMode mode,
|
||||||
cutlass::gemm::GemmCoord const & problem_size_0,
|
cutlass::gemm::GemmCoord const & problem_size_0,
|
||||||
cutlass::gemm::GemmCoord const & problem_size_1,
|
cutlass::gemm::GemmCoord const & problem_size_1,
|
||||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||||
@ -116,10 +134,18 @@ struct B2bGemm {
|
|||||||
typename B2bMma::IteratorB1::TensorRef ref_B1,
|
typename B2bMma::IteratorB1::TensorRef ref_B1,
|
||||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
|
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
|
||||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
|
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
|
||||||
|
int64_t batch_stride_A0,
|
||||||
|
int64_t batch_stride_B0,
|
||||||
|
int64_t batch_stride_B1,
|
||||||
|
int64_t batch_stride_C1,
|
||||||
|
int64_t batch_stride_D1,
|
||||||
|
int64_t batch_stride_Bias0,
|
||||||
|
int64_t batch_stride_Scale0,
|
||||||
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
|
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
|
||||||
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
|
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
|
||||||
int *workspace = nullptr
|
int *workspace = nullptr
|
||||||
):
|
):
|
||||||
|
mode(mode),
|
||||||
problem_size_0(problem_size_0),
|
problem_size_0(problem_size_0),
|
||||||
problem_size_1(problem_size_1),
|
problem_size_1(problem_size_1),
|
||||||
grid_tiled_shape(grid_tiled_shape),
|
grid_tiled_shape(grid_tiled_shape),
|
||||||
@ -138,6 +164,13 @@ struct B2bGemm {
|
|||||||
ref_C1(ref_C1),
|
ref_C1(ref_C1),
|
||||||
params_D1(ref_D1.layout()),
|
params_D1(ref_D1.layout()),
|
||||||
ref_D1(ref_D1),
|
ref_D1(ref_D1),
|
||||||
|
batch_stride_A0(batch_stride_A0),
|
||||||
|
batch_stride_B0(batch_stride_B0),
|
||||||
|
batch_stride_B1(batch_stride_B1),
|
||||||
|
batch_stride_C1(batch_stride_C1),
|
||||||
|
batch_stride_D1(batch_stride_D1),
|
||||||
|
batch_stride_Bias0(batch_stride_Bias0),
|
||||||
|
batch_stride_Scale0(batch_stride_Scale0),
|
||||||
output_op_0(output_op_0),
|
output_op_0(output_op_0),
|
||||||
output_op_1(output_op_1) {
|
output_op_1(output_op_1) {
|
||||||
|
|
||||||
@ -247,37 +280,64 @@ struct B2bGemm {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
|
||||||
|
ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
|
||||||
|
ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
|
||||||
|
|
||||||
|
ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
|
||||||
|
ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
|
||||||
|
|
||||||
|
int offset_k_0 = 0;
|
||||||
|
int offset_k_1 = 0;
|
||||||
|
|
||||||
|
int problem_size_k_0 = params.problem_size_0.k();
|
||||||
|
int problem_size_k_1 = params.problem_size_1.k();
|
||||||
|
|
||||||
|
if (params.mode == GemmUniversalMode::kGemm) {
|
||||||
|
|
||||||
|
// Problem size is a function of threadblock index in the K dimension
|
||||||
|
problem_size_k_0 = min(
|
||||||
|
problem_size_k_0,
|
||||||
|
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
|
||||||
|
|
||||||
|
// Problem size is a function of threadblock index in the K dimension
|
||||||
|
problem_size_k_1 = min(
|
||||||
|
problem_size_k_1,
|
||||||
|
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
|
||||||
|
|
||||||
|
offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
|
||||||
|
offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
|
||||||
|
}
|
||||||
|
|
||||||
|
else if (params.mode == GemmUniversalMode::kBatched) {
|
||||||
|
ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
|
||||||
|
ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
|
||||||
|
ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
|
||||||
|
ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
|
||||||
|
ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
|
||||||
|
}
|
||||||
|
|
||||||
// Compute initial location in logical coordinates
|
// Compute initial location in logical coordinates
|
||||||
cutlass::MatrixCoord tb_offset_A0{
|
cutlass::MatrixCoord tb_offset_A0{
|
||||||
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
|
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
|
||||||
threadblock_tile_offset.k() * params.gemm_k_size_0,
|
offset_k_0,
|
||||||
};
|
};
|
||||||
|
|
||||||
cutlass::MatrixCoord tb_offset_B0{
|
cutlass::MatrixCoord tb_offset_B0{
|
||||||
threadblock_tile_offset.k() * params.gemm_k_size_0,
|
offset_k_0,
|
||||||
threadblock_tile_offset.n() * B2bMma::Shape0::kN
|
threadblock_tile_offset.n() * B2bMma::Shape0::kN
|
||||||
};
|
};
|
||||||
|
|
||||||
cutlass::MatrixCoord tb_offset_B1{
|
cutlass::MatrixCoord tb_offset_B1{
|
||||||
threadblock_tile_offset.k() * params.gemm_k_size_1,
|
offset_k_1,
|
||||||
threadblock_tile_offset.n() * B2bMma::Shape1::kN
|
threadblock_tile_offset.n() * B2bMma::Shape1::kN
|
||||||
};
|
};
|
||||||
|
|
||||||
// Problem size is a function of threadblock index in the K dimension
|
|
||||||
int problem_size_k_0 = min(
|
|
||||||
params.problem_size_0.k(),
|
|
||||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
|
|
||||||
|
|
||||||
// Compute threadblock-scoped matrix multiply-add
|
// Compute threadblock-scoped matrix multiply-add
|
||||||
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
|
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
|
||||||
|
|
||||||
// Problem size is a function of threadblock index in the K dimension
|
|
||||||
int problem_size_k_1 = min(
|
|
||||||
params.problem_size_1.k(),
|
|
||||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
|
|
||||||
|
|
||||||
// Compute threadblock-scoped matrix multiply-add
|
// Compute threadblock-scoped matrix multiply-add
|
||||||
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
||||||
|
|
||||||
|
|
||||||
// Compute position within threadblock
|
// Compute position within threadblock
|
||||||
@ -286,26 +346,25 @@ struct B2bGemm {
|
|||||||
// Construct iterators to A and B operands
|
// Construct iterators to A and B operands
|
||||||
typename B2bMma::IteratorA0 iterator_A0(
|
typename B2bMma::IteratorA0 iterator_A0(
|
||||||
params.params_A0,
|
params.params_A0,
|
||||||
params.ref_A0.data(),
|
ptr_A0,
|
||||||
{params.problem_size_0.m(), problem_size_k_0},
|
{params.problem_size_0.m(), problem_size_k_0},
|
||||||
thread_idx,
|
thread_idx,
|
||||||
tb_offset_A0);
|
tb_offset_A0);
|
||||||
|
|
||||||
typename B2bMma::IteratorB0 iterator_B0(
|
typename B2bMma::IteratorB0 iterator_B0(
|
||||||
params.params_B0,
|
params.params_B0,
|
||||||
params.ref_B0.data(),
|
ptr_B0,
|
||||||
{problem_size_k_0, params.problem_size_0.n()},
|
{problem_size_k_0, params.problem_size_0.n()},
|
||||||
thread_idx,
|
thread_idx,
|
||||||
tb_offset_B0);
|
tb_offset_B0);
|
||||||
|
|
||||||
typename B2bMma::IteratorB1 iterator_B1(
|
typename B2bMma::IteratorB1 iterator_B1(
|
||||||
params.params_B1,
|
params.params_B1,
|
||||||
params.ref_B1.data(),
|
ptr_B1,
|
||||||
{problem_size_k_1, params.problem_size_1.n()},
|
{problem_size_k_1, params.problem_size_1.n()},
|
||||||
thread_idx,
|
thread_idx,
|
||||||
tb_offset_B1);
|
tb_offset_B1);
|
||||||
|
|
||||||
|
|
||||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||||
// is compiled as warp-uniform.
|
// is compiled as warp-uniform.
|
||||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||||
@ -313,7 +372,7 @@ struct B2bGemm {
|
|||||||
|
|
||||||
// Construct iterators to accumulator scale/bias vector
|
// Construct iterators to accumulator scale/bias vector
|
||||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
|
typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0(
|
||||||
params.ref_Scale0.data(),
|
ptr_Scale0,
|
||||||
{1, params.problem_size_0.n()},
|
{1, params.problem_size_0.n()},
|
||||||
thread_idx,
|
thread_idx,
|
||||||
warp_idx,
|
warp_idx,
|
||||||
@ -323,7 +382,7 @@ struct B2bGemm {
|
|||||||
);
|
);
|
||||||
|
|
||||||
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
|
typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0(
|
||||||
params.ref_Bias0.data(),
|
ptr_Bias0,
|
||||||
{1, params.problem_size_0.n()},
|
{1, params.problem_size_0.n()},
|
||||||
thread_idx,
|
thread_idx,
|
||||||
warp_idx,
|
warp_idx,
|
||||||
@ -349,11 +408,9 @@ struct B2bGemm {
|
|||||||
src_accum.clear();
|
src_accum.clear();
|
||||||
accumulators.clear();
|
accumulators.clear();
|
||||||
|
|
||||||
if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
|
|
||||||
// Compute threadblock-scoped matrix multiply-add
|
// Compute threadblock-scoped matrix multiply-add
|
||||||
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0,
|
||||||
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0);
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Epilogue
|
// Epilogue
|
||||||
@@ -376,23 +433,32 @@ struct B2bGemm {
 
     int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
 
+    ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
+    ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
 
     // Construct the semaphore.
     Semaphore semaphore(params.semaphore + block_idx, thread_idx);
 
+    if (params.mode == GemmUniversalMode::kGemm) {
       // If performing a reduction via split-K, fetch the initial synchronization
-      if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
+
+      if (params.grid_tiled_shape.k() > 1) {
        // Fetch the synchronization lock initially but do not block.
        semaphore.fetch();
 
        // Indicate which position in a serial reduction the output operator is currently updating
        output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
      }
+    }
+    else if (params.mode == GemmUniversalMode::kBatched) {
+      ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1;
+      ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1;
+    }
 
     // Tile iterator loading from source tensor.
     typename Epilogue::OutputTileIterator iterator_C1(
       params.params_C1,
-      params.ref_C1.data(),
+      ptr_C1,
       params.problem_size_1.mn(),
       thread_idx,
       threadblock_offset
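The same grid z-coordinate, threadblock_tile_offset.k(), now serves two purposes: under kBatched it selects which batch of C1/D1 the threadblock writes, hence the pointer bumps by batch_stride_C1 and batch_stride_D1, while under kGemm every slice targets the same C1/D1 and relies on the semaphore for ordering. A standalone sketch with illustrative strides:

// Standalone sketch (assumed values): how the third grid coordinate is
// interpreted in each mode when the epilogue picks its C1/D1 pointers.
#include <cstdint>
#include <cstdio>

int main() {
  int grid_k = 3;                        // range of threadblock_tile_offset.k()
  int64_t batch_stride_C1 = 128 * 64;    // elements between consecutive C1 batches
  int64_t batch_stride_D1 = 128 * 64;

  for (int tile_k = 0; tile_k < grid_k; ++tile_k) {
    // kBatched: every z-slice gets its own C1/D1 window.
    int64_t batched_C1 = tile_k * batch_stride_C1;
    int64_t batched_D1 = tile_k * batch_stride_D1;
    // kGemm (split-K): all slices share one C1/D1 (offset 0) and are
    // serialized through the semaphore instead.
    std::printf("slice %d: kBatched C1+=%lld D1+=%lld, kGemm C1+=0 D1+=0\n",
                tile_k, (long long)batched_C1, (long long)batched_D1);
  }
  return 0;
}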
@@ -401,7 +467,7 @@ struct B2bGemm {
     // Tile iterator writing to destination tensor.
     typename Epilogue::OutputTileIterator iterator_D1(
       params.params_D1,
-      params.ref_D1.data(),
+      ptr_D1,
       params.problem_size_1.mn(),
       thread_idx,
       threadblock_offset
@@ -414,7 +480,7 @@ struct B2bGemm {
       lane_idx);
 
     // Wait on the semaphore - this latency may have been covered by iterator construction
-    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
+    if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
 
       // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
       if (threadblock_tile_offset.k()) {
@@ -433,7 +499,7 @@ struct B2bGemm {
     // Release the semaphore
     //
 
-    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
+    if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
 
       int lock = 0;
       if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
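For reference, the serial split-K handshake that these two guards now key off params.mode follows the usual cutlass::Semaphore protocol: fetch early, wait before the epilogue reads the previous partial result, then release with the next slice's index (or 0 from the final slice). A paraphrased sketch of that pattern, not the literal kernel code:

// Paraphrased sketch of the serial split-K handshake (method names follow
// cutlass::Semaphore; the surrounding control flow is a summary, not a copy).
if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) {
  // Slice k blocks until slice k-1 has published its partial result in D1.
  semaphore.wait(threadblock_tile_offset.k());

  // ... epilogue accumulates onto D1 here ...

  // Hand the lock to the next slice, or reset it to 0 from the last slice.
  int lock = (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
                 ? 0
                 : threadblock_tile_offset.k() + 1;
  semaphore.release(lock);
}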
@@ -457,4 +523,3 @@ struct B2bGemm {
 } // namespace kernel
 } // namespace gemm
 } // namespace cutlass
-
@@ -114,8 +114,6 @@ template <
     typename ThreadblockSwizzle,
     /// Number of stages used in the pipelined mainloop
     int Stages,
-    /// If true, kernel is configured to support serial reduction in the epilogue
-    bool SplitKSerial,
     /// Operation performed by GEMM
     typename Operator,
     /// Stage accumulator in shared memory
@@ -161,16 +159,13 @@ template <
     typename ThreadblockSwizzle,
     /// Number of stages used in the pipelined mainloop
     int Stages,
-    /// If true, kernel is configured to support serial reduction in the
-    /// epilogue
-    bool SplitKSerial,
     /// Operation performed by GEMM
     typename Operator>
 struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
                       layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
                       arch::Sm80, ThreadblockShape0, ThreadblockShape1,
                       WarpShape0, WarpShape1, InstructionShape,
-                      EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
+                      EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
                       Operator> {
   /// Define the threadblock-scoped matrix multiply-accumulate
   using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
@@ -188,7 +183,7 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
       EpilogueOutputOp1::kCount>::Epilogue;
 
   /// Define the kernel-level GEMM operator.
-  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
 };
 
 
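Every DefaultB2bGemm partial specialization below repeats this same change: kernel::B2bGemm no longer takes a SplitKSerial flag, so the serial-reduction path is always compiled in and is enabled, or not, per launch. A hedged summary of how the two behaviors are selected at run time, using the argument names from the new run() signature rather than the example's exact Arguments struct:

// Hedged sketch: run-time selection replaces the removed compile-time flag.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;

// Serial split-K reduction over one problem:
//   mode = GemmUniversalMode::kGemm, batch_count = number of K slices.
// Batched B2B GEMM (b2b bmm):
//   mode = GemmUniversalMode::kBatched, batch_count = number of batches,
//   plus batch_stride_A0 / B0 / C0 / B1 / C1 / D1 / Bias0 / Scale0.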
@@ -228,8 +223,6 @@ template <
     typename EpilogueOutputOp1,
     /// Threadblock-level swizzling operator
     typename ThreadblockSwizzle,
-    /// If true, kernel is configured to support serial reduction in the epilogue
-    bool SplitKSerial,
     /// Operation performed by GEMM
     typename Operator
 >
@@ -249,7 +242,6 @@ struct DefaultB2bGemm<
     EpilogueOutputOp1,
     ThreadblockSwizzle,
     2,
-    SplitKSerial,
     Operator
 > {
 
@@ -287,7 +279,7 @@ struct DefaultB2bGemm<
   >::Epilogue;
 
   /// Define the kernel-level GEMM operator.
-  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
 };
 
 
@@ -323,9 +315,6 @@ template <
     int Stages,
     /// Number of Interleaved k
     int InterleavedK,
-    /// If true, kernel is configured to support serial reduction in the
-    /// epilogue
-    bool SplitKSerial,
     /// Operation performed by GEMM
     typename Operator>
 struct DefaultB2bGemm<
@@ -335,8 +324,7 @@ struct DefaultB2bGemm<
     arch::OpClassTensorOp, arch::Sm80,
     ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
     InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
-    ThreadblockSwizzle, Stages,
-    SplitKSerial, Operator> {
+    ThreadblockSwizzle, Stages, Operator> {
   using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
   using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
   using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@@ -360,7 +348,7 @@ struct DefaultB2bGemm<
       64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
 
   /// Define the kernel-level GEMM operator.
-  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
 };
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -396,9 +384,6 @@ template <
     typename ThreadblockSwizzle,
     /// Number of Interleaved k
     int InterleavedK,
-    /// If true, kernel is configured to support serial reduction in the
-    /// epilogue
-    bool SplitKSerial,
     /// Operation performed by GEMM
     typename Operator>
 struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
@@ -408,7 +393,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
     int32_t, arch::OpClassTensorOp, arch::Sm75,
     ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
     InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
-    ThreadblockSwizzle, 2, SplitKSerial, Operator> {
+    ThreadblockSwizzle, 2, Operator> {
   using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
   using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
   using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@@ -430,7 +415,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
       64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
 
   /// Define the kernel-level GEMM operator.
-  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
 };
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -112,16 +112,13 @@ template <
     typename ThreadblockSwizzle,
     /// Number of stages used in the pipelined mainloop
     int Stages,
-    /// If true, kernel is configured to support serial reduction in the
-    /// epilogue
-    bool SplitKSerial,
     /// Operation performed by GEMM
     typename Operator>
 struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
                       layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp,
                       arch::Sm80, ThreadblockShape0, ThreadblockShape1,
                       WarpShape0, WarpShape1, InstructionShape,
-                      EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, SplitKSerial,
+                      EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages,
                       Operator, true> {
   /// Define the threadblock-scoped matrix multiply-accumulate
   using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
@@ -139,10 +136,9 @@ struct DefaultB2bGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignm
       EpilogueOutputOp1::kCount>::Epilogue;
 
   /// Define the kernel-level GEMM operator.
-  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
 };
 
 
 ////////////////////////////////////////////////////////////////////////////////
 
 /// Partial specialization for Turing Architecture
@@ -179,8 +175,6 @@ template <
     typename EpilogueOutputOp1,
     /// Threadblock-level swizzling operator
     typename ThreadblockSwizzle,
-    /// If true, kernel is configured to support serial reduction in the epilogue
-    bool SplitKSerial,
     /// Operation performed by GEMM
     typename Operator
 >
@@ -200,7 +194,6 @@ struct DefaultB2bGemm<
     EpilogueOutputOp1,
     ThreadblockSwizzle,
     2,
-    SplitKSerial,
     Operator,
     true
 > {
@@ -241,7 +234,7 @@ struct DefaultB2bGemm<
   >::Epilogue;
 
   /// Define the kernel-level GEMM operator.
-  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
 };
 
 
@@ -277,9 +270,6 @@ template <
     int Stages,
     /// Number of Interleaved k
     int InterleavedK,
-    /// If true, kernel is configured to support serial reduction in the
-    /// epilogue
-    bool SplitKSerial,
     /// Operation performed by GEMM
     typename Operator>
 struct DefaultB2bGemm<
@@ -290,7 +280,7 @@ struct DefaultB2bGemm<
     ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
     InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
     ThreadblockSwizzle, Stages,
-    SplitKSerial, Operator, true> {
+    Operator, true> {
   using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
   using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
   using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@@ -314,7 +304,7 @@ struct DefaultB2bGemm<
       64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
 
   /// Define the kernel-level GEMM operator.
-  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
 };
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -350,9 +340,6 @@ template <
     typename ThreadblockSwizzle,
     /// Number of Interleaved k
     int InterleavedK,
-    /// If true, kernel is configured to support serial reduction in the
-    /// epilogue
-    bool SplitKSerial,
     /// Operation performed by GEMM
     typename Operator>
 struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
@@ -362,7 +349,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
     int32_t, arch::OpClassTensorOp, arch::Sm75,
     ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
     InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
-    ThreadblockSwizzle, 2, SplitKSerial, Operator, true> {
+    ThreadblockSwizzle, 2, Operator, true> {
   using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
   using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
   using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
@@ -385,7 +372,7 @@ struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
       64 / sizeof_bits<ElementC>::value, InterleavedK>::Epilogue;
 
   /// Define the kernel-level GEMM operator.
-  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+  using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle>;
 };
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -99,6 +99,70 @@ __global__ void TensorScaleBiasGemm(
     }
   }
 }
 
+template <
+  typename TensorRefIn,                   ///< Input TensorRef Type
+  typename TensorRefOut,                  ///< Output TensorRef Type
+  typename ScalarType,                    ///< alpha Type
+  typename TensorRefScalar,               ///< Scale/Bias TensorRef Type
+  typename ConvertOp = NumericConverter<typename TensorRefOut::Element, ScalarType>,
+  int kMblock = 4,
+  int kNblock = 4
+>
+__global__ void TensorScaleBiasGemmBatched(
+  gemm::GemmCoord problem_size,
+  TensorRefIn tensor_in,                  ///< input tensor
+  TensorRefOut tensor_out,                ///< output tensor
+  ScalarType alpha,                       ///< alpha
+  TensorRefScalar tensor_scale,           ///< scale tensor
+  TensorRefScalar tensor_bias,            ///< bias tensor
+  int batch_count = 1,
+  int64_t batch_stride_tensor_in = 0,
+  int64_t batch_stride_tensor_out = 0,
+  int64_t batch_stride_tensor_scale = 0,
+  int64_t batch_stride_tensor_bias = 0
+) {
+
+  ConvertOp convert_op;
+  int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
+  int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
+  int batch_idx = blockIdx.z;
+
+  tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in);
+  tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out);
+  tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale);
+  tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias);
+
+  for (; batch_idx < batch_count; batch_idx += gridDim.z) {
+    CUTLASS_PRAGMA_UNROLL
+    for (int j = 0; j < kNblock; j++) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kMblock; i++) {
+        int row = row_block + i;
+        int col = col_block + j;
+        MatrixCoord coord = MatrixCoord(row, col);
+        if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) {
+
+          ScalarType scale = alpha;
+          if (tensor_scale.good())
+            scale = tensor_scale.at({0, coord.column()});
+
+          ScalarType bias = ScalarType(0);
+
+          if (tensor_bias.good())
+            bias = tensor_bias.at({0, coord.column()});
+
+          tensor_out.at(coord) = convert_op(
+            scale * ScalarType(tensor_in.at(coord)) + bias);
+        }
+      }
+    }
+    tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z);
+    tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z);
+    tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z);
+    tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z);
+  }
+}
+
 template <
   typename TensorRefIn,                   ///< Input TensorRef Type
   typename TensorRefOut,                  ///< Output TensorRef Type
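The new device-side reference kernel above assigns one batch per blockIdx.z and then grid-strides over batches, advancing its pointers by batch_stride * gridDim.z on each step, so it stays correct even when batch_count exceeds the z extent of the launch. A standalone sketch of that batch-to-z mapping with made-up numbers:

// Standalone sketch (assumed values): which batches each z-slice visits
// under the grid-stride loop used by TensorScaleBiasGemmBatched.
#include <cstdio>

int main() {
  int batch_count = 10;
  int grid_z = 4;   // gridDim.z chosen at launch

  for (int z = 0; z < grid_z; ++z) {
    std::printf("blockIdx.z=%d handles batches:", z);
    for (int b = z; b < batch_count; b += grid_z) {
      std::printf(" %d", b);
    }
    std::printf("\n");
  }
  return 0;
}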
@@ -217,6 +281,62 @@ void TensorScaleBiasGemm(
   );
 }
 
+/// Apply scale and bias on a tensor
+template <
+  typename ElementIn,                      ///< Input Type
+  typename ElementOut,                     ///< Output Type
+  typename Layout,                         ///< Layout of input/output tensor
+  typename ScalarType,                     ///< alpha Type
+  typename LayoutScaleBias,                ///< Layout of scale and bias
+  typename ConvertOp = NumericConverter<ElementOut, ScalarType>
+>
+void TensorScaleBiasGemmBatched(
+  gemm::GemmCoord problem_size,
+  TensorRef<ElementIn, Layout> tensor_in,                 ///< input tensor
+  TensorRef<ElementOut, Layout> tensor_out,               ///< output tensor
+  ScalarType alpha,                                       ///< alpha
+  TensorRef<ScalarType, LayoutScaleBias> tensor_scale,    ///< scale tensor
+  TensorRef<ScalarType, LayoutScaleBias> tensor_bias,     ///< bias tensor
+  int batch_count = 1,
+  int64_t batch_stride_tensor_in = 0,
+  int64_t batch_stride_tensor_out = 0,
+  int64_t batch_stride_tensor_scale = 0,
+  int64_t batch_stride_tensor_bias = 0
+) {
+
+  int const kMblock = 4;
+  int const kNblock = 4;
+
+  dim3 block(16, 8);
+  dim3 grid(
+    (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock),
+    (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock),
+    batch_count % std::numeric_limits<uint16_t>::max()
+  );
+
+  kernel::TensorScaleBiasGemmBatched<
+    TensorRef<ElementIn, Layout>,
+    TensorRef<ElementOut, Layout>,
+    ScalarType,
+    TensorRef<ScalarType, LayoutScaleBias>,
+    ConvertOp,
+    kMblock,
+    kNblock
+  ><<< grid, block >>> (
+    problem_size,
+    tensor_in,
+    tensor_out,
+    alpha,
+    tensor_scale,
+    tensor_bias,
+    batch_count,
+    batch_stride_tensor_in,
+    batch_stride_tensor_out,
+    batch_stride_tensor_scale,
+    batch_stride_tensor_bias
+  );
+}
+
 /// Apply scale and bias on a tensor
 template <
   typename ElementIn,                      ///< Input Type
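A hedged usage sketch for the new batched reference helper. It assumes the helper lives in the same namespace as the existing TensorScaleBiasGemm (cutlass::reference::device in this example tree), that batches are packed back-to-back so the strides below are valid, and that the caller fills and syncs the host tensors; the wrapper function name and sizes are illustrative only.

// Hedged usage sketch; run_batched_scale_bias and all sizes are hypothetical.
#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"
#include "reference/device/tensor_scale_bias.h"

void run_batched_scale_bias() {
  int m = 128, n = 64, batch_count = 3;
  cutlass::gemm::GemmCoord problem(m, n, /*k, unused by the helper*/ 1);

  // Batches stacked along the row dimension: batch b starts m rows further
  // down, i.e. a pointer offset of m * n elements for a row-major tensor.
  cutlass::HostTensor<float, cutlass::layout::RowMajor> in({batch_count * m, n});
  cutlass::HostTensor<float, cutlass::layout::RowMajor> out({batch_count * m, n});
  cutlass::HostTensor<float, cutlass::layout::RowMajor> scale({batch_count, n});
  cutlass::HostTensor<float, cutlass::layout::RowMajor> bias({batch_count, n});

  // ... fill host data, then in.sync_device(), scale.sync_device(), bias.sync_device() ...

  // Namespace assumed to match the existing reference helpers in this header.
  cutlass::reference::device::TensorScaleBiasGemmBatched(
    problem,
    in.device_ref(),
    out.device_ref(),
    1.0f,                      // alpha, used only when no scale vector is given
    scale.device_ref(),
    bias.device_ref(),
    batch_count,
    int64_t(m) * n,            // batch_stride_tensor_in
    int64_t(m) * n,            // batch_stride_tensor_out
    int64_t(n),                // batch_stride_tensor_scale
    int64_t(n));               // batch_stride_tensor_bias

  out.sync_host();
}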