Add conversion from ElementBias to ElementCompute (#961)
This commit is contained in:
parent
6f47420213
commit
7dbf423763
@ -475,6 +475,12 @@ public:
|
|||||||
Tensor tRS_rT_frg = recast<typename ThreadEpilogueOp::FragmentT>(tRS_rT);
|
Tensor tRS_rT_frg = recast<typename ThreadEpilogueOp::FragmentT>(tRS_rT);
|
||||||
Tensor tRS_rBias_frg = recast<typename ThreadEpilogueOp::FragmentBias>(tRS_rBias);
|
Tensor tRS_rBias_frg = recast<typename ThreadEpilogueOp::FragmentBias>(tRS_rBias);
|
||||||
|
|
||||||
|
// thread::LinearCombinationBiasElementwise expects that the bias passed in is of
|
||||||
|
// type ElementCompute. Therefore, conversion from type ElementBias to ElementCompute
|
||||||
|
// is needed before calling the thread-level epilogue.
|
||||||
|
cutlass::NumericArrayConverter<ElementCompute, ElementBias,
|
||||||
|
ThreadEpilogueOp::FragmentBias::kElements> bias_converter;
|
||||||
|
|
||||||
// Partition for smem to register copy (tSR_)
|
// Partition for smem to register copy (tSR_)
|
||||||
TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom<CopyOpS2R,InternalElementC>{}, tiled_r2s);
|
TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom<CopyOpS2R,InternalElementC>{}, tiled_r2s);
|
||||||
ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx);
|
ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx);
|
||||||
@ -538,13 +544,15 @@ public:
|
|||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int i = 0; i < size(tRS_rD_frg); ++i) {
|
for (int i = 0; i < size(tRS_rD_frg); ++i) {
|
||||||
epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), tRS_rC_frg(i), tRS_rBias_frg(i));
|
typename ThreadEpilogueOp::FragmentCompute converted_bias = bias_converter(tRS_rBias_frg(i));
|
||||||
|
epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), tRS_rC_frg(i), converted_bias);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int i = 0; i < size(tRS_rD_frg); ++i) {
|
for (int i = 0; i < size(tRS_rD_frg); ++i) {
|
||||||
epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), tRS_rBias_frg(i));
|
typename ThreadEpilogueOp::FragmentCompute converted_bias = bias_converter(tRS_rBias_frg(i));
|
||||||
|
epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), converted_bias);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ public:
|
|||||||
using FragmentSource = FragmentC;
|
using FragmentSource = FragmentC;
|
||||||
using FragmentOutput = FragmentZ;
|
using FragmentOutput = FragmentZ;
|
||||||
using ElementBias = ElementVector;
|
using ElementBias = ElementVector;
|
||||||
using FragmentBias = FragmentCompute;
|
using FragmentBias = Array<ElementBias, kElementsPerAccess>;
|
||||||
using ActivationFunctor = ElementwiseOp;
|
using ActivationFunctor = ElementwiseOp;
|
||||||
static const ScaleType::Kind kScale = ScaleType::Default;
|
static const ScaleType::Kind kScale = ScaleType::Default;
|
||||||
|
|
||||||
|
@ -101,7 +101,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_ReLU) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU) {
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
using LayoutC = cutlass::layout::RowMajor;
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
@ -144,7 +144,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_GELU) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_GELU) {
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
using LayoutC = cutlass::layout::RowMajor;
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
@ -188,7 +188,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_ReLU_NoStoreT) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU_NoStoreT) {
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
using LayoutC = cutlass::layout::RowMajor;
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
@ -231,7 +231,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_Negate) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_Negate) {
|
||||||
|
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
@ -275,7 +275,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasMul_ReLU) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32Mul_ReLU) {
|
||||||
|
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
@ -319,7 +319,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasMul_ReLU) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32Mul_ReLU) {
|
||||||
|
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
@ -363,7 +363,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasMul_ReLU_VoidC) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32Mul_ReLU_VoidC) {
|
||||||
|
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
@ -407,4 +407,92 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF16Mul_ReLU_VoidC) {
|
||||||
|
|
||||||
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
|
using TileShape_MNK = Shape<_256,_128,_64>;
|
||||||
|
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||||
|
|
||||||
|
static constexpr bool StoreT = true;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise<
|
||||||
|
cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, cutlass::half_t>;
|
||||||
|
|
||||||
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
TileShape_MNK, ClusterShape_MNK,
|
||||||
|
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||||
|
float, float,
|
||||||
|
void, LayoutC, 8,
|
||||||
|
cutlass::half_t, LayoutC, 8,
|
||||||
|
EpilogueSchedule
|
||||||
|
>::CollectiveOp;
|
||||||
|
|
||||||
|
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
cutlass::half_t, LayoutA, 8,
|
||||||
|
cutlass::half_t, LayoutB, 8,
|
||||||
|
float,
|
||||||
|
TileShape_MNK, ClusterShape_MNK,
|
||||||
|
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||||
|
>::CollectiveOp;
|
||||||
|
|
||||||
|
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||||
|
Shape<int,int,int,int>,
|
||||||
|
CollectiveMainloop,
|
||||||
|
CollectiveEpilogue
|
||||||
|
>;
|
||||||
|
|
||||||
|
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||||
|
|
||||||
|
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
|
||||||
|
EXPECT_TRUE(passed);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasS8Mul_ReLU_VoidC) {
|
||||||
|
|
||||||
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
|
using TileShape_MNK = Shape<_256,_128,_64>;
|
||||||
|
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||||
|
|
||||||
|
static constexpr bool StoreT = true;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise<
|
||||||
|
cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, int8_t>;
|
||||||
|
|
||||||
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
TileShape_MNK, ClusterShape_MNK,
|
||||||
|
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||||
|
float, float,
|
||||||
|
void, LayoutC, 8,
|
||||||
|
cutlass::half_t, LayoutC, 8,
|
||||||
|
EpilogueSchedule
|
||||||
|
>::CollectiveOp;
|
||||||
|
|
||||||
|
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
cutlass::half_t, LayoutA, 8,
|
||||||
|
cutlass::half_t, LayoutB, 8,
|
||||||
|
float,
|
||||||
|
TileShape_MNK, ClusterShape_MNK,
|
||||||
|
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||||
|
>::CollectiveOp;
|
||||||
|
|
||||||
|
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||||
|
Shape<int,int,int,int>,
|
||||||
|
CollectiveMainloop,
|
||||||
|
CollectiveEpilogue
|
||||||
|
>;
|
||||||
|
|
||||||
|
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||||
|
|
||||||
|
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
|
||||||
|
EXPECT_TRUE(passed);
|
||||||
|
}
|
||||||
|
|
||||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
@ -100,7 +100,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_ReLU) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_ReLU) {
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
using LayoutC = cutlass::layout::RowMajor;
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
@ -143,7 +143,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_GELU) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_GELU) {
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
using LayoutC = cutlass::layout::RowMajor;
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
@ -187,7 +187,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_ReLU_NoStoreT) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_ReLU_NoStoreT) {
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
using LayoutC = cutlass::layout::RowMajor;
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
@ -230,7 +230,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_Negate) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_Negate) {
|
||||||
|
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
@ -274,7 +274,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasMul_ReLU) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32Mul_ReLU) {
|
||||||
|
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
@ -318,7 +318,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasMul_ReLU) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32Mul_ReLU) {
|
||||||
|
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
@ -362,7 +362,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasMul_ReLU_VoidC) {
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32Mul_ReLU_VoidC) {
|
||||||
|
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
@ -406,4 +406,92 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
|
|||||||
EXPECT_TRUE(passed);
|
EXPECT_TRUE(passed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF16Mul_ReLU_VoidC) {
|
||||||
|
|
||||||
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
|
using TileShape_MNK = Shape<_128,_128,_64>;
|
||||||
|
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||||
|
|
||||||
|
static constexpr bool StoreT = true;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise<
|
||||||
|
cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, cutlass::half_t>;
|
||||||
|
|
||||||
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
TileShape_MNK, ClusterShape_MNK,
|
||||||
|
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||||
|
float, float,
|
||||||
|
void, LayoutC, 8,
|
||||||
|
cutlass::half_t, LayoutC, 8,
|
||||||
|
EpilogueSchedule
|
||||||
|
>::CollectiveOp;
|
||||||
|
|
||||||
|
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
cutlass::half_t, LayoutA, 8,
|
||||||
|
cutlass::half_t, LayoutB, 8,
|
||||||
|
float,
|
||||||
|
TileShape_MNK, ClusterShape_MNK,
|
||||||
|
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecializedPingpong
|
||||||
|
>::CollectiveOp;
|
||||||
|
|
||||||
|
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||||
|
Shape<int,int,int,int>,
|
||||||
|
CollectiveMainloop,
|
||||||
|
CollectiveEpilogue
|
||||||
|
>;
|
||||||
|
|
||||||
|
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||||
|
|
||||||
|
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
|
||||||
|
EXPECT_TRUE(passed);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasS8Mul_ReLU_VoidC) {
|
||||||
|
|
||||||
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
|
using TileShape_MNK = Shape<_128,_128,_64>;
|
||||||
|
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||||
|
|
||||||
|
static constexpr bool StoreT = true;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise<
|
||||||
|
cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, int8_t>;
|
||||||
|
|
||||||
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
TileShape_MNK, ClusterShape_MNK,
|
||||||
|
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||||
|
float, float,
|
||||||
|
void, LayoutC, 8,
|
||||||
|
cutlass::half_t, LayoutC, 8,
|
||||||
|
EpilogueSchedule
|
||||||
|
>::CollectiveOp;
|
||||||
|
|
||||||
|
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||||
|
cutlass::half_t, LayoutA, 8,
|
||||||
|
cutlass::half_t, LayoutB, 8,
|
||||||
|
float,
|
||||||
|
TileShape_MNK, ClusterShape_MNK,
|
||||||
|
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||||
|
cutlass::gemm::KernelTmaWarpSpecializedPingpong
|
||||||
|
>::CollectiveOp;
|
||||||
|
|
||||||
|
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||||
|
Shape<int,int,int,int>,
|
||||||
|
CollectiveMainloop,
|
||||||
|
CollectiveEpilogue
|
||||||
|
>;
|
||||||
|
|
||||||
|
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||||
|
|
||||||
|
bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
|
||||||
|
EXPECT_TRUE(passed);
|
||||||
|
}
|
||||||
|
|
||||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
Loading…
Reference in New Issue
Block a user