From 7dbf42376330230b9c5f0fe2a0ac1c167d1f1889 Mon Sep 17 00:00:00 2001
From: Jack Kosaian
Date: Fri, 26 May 2023 23:08:36 -0400
Subject: [PATCH] Add conversion from ElementBias to ElementCompute (#961)

---
 ...e_tma_warpspecialized_bias_elementwise.hpp |  12 +-
 .../linear_combination_bias_elementwise.h     |   2 +-
 ...pecialized_cooperative_bias_elementwise.cu | 104 ++++++++++++++++--
 ...rpspecialized_pingpong_bias_elementwise.cu | 104 ++++++++++++++++--
 4 files changed, 203 insertions(+), 19 deletions(-)

diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp
index b8eea8e1..4a7978b2 100644
--- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp
+++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp
@@ -475,6 +475,12 @@ public:
     Tensor tRS_rT_frg    = recast<typename ThreadEpilogueOp::FragmentT>(tRS_rT);
     Tensor tRS_rBias_frg = recast<typename ThreadEpilogueOp::FragmentBias>(tRS_rBias);
 
+    // thread::LinearCombinationBiasElementwise expects that the bias passed in is of
+    // type ElementCompute. Therefore, conversion from type ElementBias to ElementCompute
+    // is needed before calling the thread-level epilogue.
+    cutlass::NumericArrayConverter<typename ThreadEpilogueOp::ElementCompute,
+        typename ThreadEpilogueOp::ElementBias, ThreadEpilogueOp::kElementsPerAccess> bias_converter;
+
     // Partition for smem to register copy (tSR_)
     TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom<CopyOpS2R, ElementC>{}, tiled_r2s);
     ThrCopy thread_s2r  = tiled_s2r.get_slice(thread_idx);
@@ -538,13 +544,15 @@ public:
 
       CUTLASS_PRAGMA_UNROLL
       for (int i = 0; i < size(tRS_rD_frg); ++i) {
-        epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), tRS_rC_frg(i), tRS_rBias_frg(i));
+        typename ThreadEpilogueOp::FragmentCompute converted_bias = bias_converter(tRS_rBias_frg(i));
+        epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), tRS_rC_frg(i), converted_bias);
       }
     }
     else {
       CUTLASS_PRAGMA_UNROLL
       for (int i = 0; i < size(tRS_rD_frg); ++i) {
-        epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), tRS_rBias_frg(i));
+        typename ThreadEpilogueOp::FragmentCompute converted_bias = bias_converter(tRS_rBias_frg(i));
+        epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), converted_bias);
       }
     }
 
diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h
index 416198c7..7970b5f7 100644
--- a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h
+++ b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h
@@ -95,7 +95,7 @@ public:
   using FragmentSource = FragmentC;
   using FragmentOutput = FragmentZ;
   using ElementBias = ElementVector;
-  using FragmentBias = FragmentCompute;
+  using FragmentBias = Array<ElementBias, kElementsPerAccess>;
   using ActivationFunctor = ElementwiseOp;
 
   static const ScaleType::Kind kScale = ScaleType::Default;
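The collective-level change above hinges on cutlass::NumericArrayConverter, which converts a whole fragment of ElementBias values to ElementCompute in one shot. Below is a minimal standalone sketch of that conversion, assuming only public CUTLASS headers; the fragment width of 8 and the half_t-to-float pairing are illustrative stand-ins for ThreadEpilogueOp::kElementsPerAccess and the types a given kernel would use.

#include <cutlass/array.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

int main() {
  // Illustrative fragment width; the collective epilogue uses the thread-level
  // epilogue's elements-per-access here.
  constexpr int kElementsPerAccess = 8;

  using ElementBias    = cutlass::half_t;  // bias type, as in the new BiasF16Mul tests below
  using ElementCompute = float;            // compute type expected by the thread epilogue

  cutlass::Array<ElementBias, kElementsPerAccess> bias_frag;
  bias_frag.fill(ElementBias(0.5f));

  // NumericArrayConverter<Destination, Source, N> converts a fragment element-wise,
  // which is exactly the role bias_converter plays in the patched epilogue.
  cutlass::NumericArrayConverter<ElementCompute, ElementBias, kElementsPerAccess> bias_converter;
  cutlass::Array<ElementCompute, kElementsPerAccess> converted_bias = bias_converter(bias_frag);

  return (converted_bias[0] == 0.5f) ? 0 : 1;
}

Before this patch, FragmentBias aliased FragmentCompute, so the bias fragment already held ElementCompute values and no conversion was needed; with FragmentBias now an Array of ElementBias, the conversion is required whenever the two types differ.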
diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu
index d07187bd..d95b14a2 100644
--- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu
+++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu
@@ -101,7 +101,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_ReLU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU) {
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
   using LayoutC = cutlass::layout::RowMajor;
@@ -144,7 +144,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_GELU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_GELU) {
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
   using LayoutC = cutlass::layout::RowMajor;
@@ -188,7 +188,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_ReLU_NoStoreT) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU_NoStoreT) {
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
   using LayoutC = cutlass::layout::RowMajor;
@@ -231,7 +231,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_Negate) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_Negate) {
 
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
@@ -275,7 +275,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasMul_ReLU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32Mul_ReLU) {
 
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
@@ -319,7 +319,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasMul_ReLU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32Mul_ReLU) {
 
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
@@ -363,7 +363,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasMul_ReLU_VoidC) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32Mul_ReLU_VoidC) {
 
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
@@ -407,4 +407,92 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
   EXPECT_TRUE(passed);
 }
 
-#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF16Mul_ReLU_VoidC) {
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using TileShape_MNK = Shape<_256,_128,_64>;
+  using ClusterShape_MNK = Shape<_2,_2,_1>;
+
+  static constexpr bool StoreT = true;
+  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise<
+    cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, cutlass::half_t>;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      TileShape_MNK, ClusterShape_MNK,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      float, float,
+      void, LayoutC, 8,
+      cutlass::half_t, LayoutC, 8,
+      EpilogueSchedule
+    >::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      cutlass::half_t, LayoutA, 8,
+      cutlass::half_t, LayoutB, 8,
+      float,
+      TileShape_MNK, ClusterShape_MNK,
+      cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
+      cutlass::gemm::KernelTmaWarpSpecializedCooperative
+    >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,
+      CollectiveMainloop,
+      CollectiveEpilogue
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
+  EXPECT_TRUE(passed);
+}
+
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasS8Mul_ReLU_VoidC) {
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using TileShape_MNK = Shape<_256,_128,_64>;
+  using ClusterShape_MNK = Shape<_2,_2,_1>;
+
+  static constexpr bool StoreT = true;
+  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise<
+    cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, int8_t>;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      TileShape_MNK, ClusterShape_MNK,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      float, float,
+      void, LayoutC, 8,
+      cutlass::half_t, LayoutC, 8,
+      EpilogueSchedule
+    >::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      cutlass::half_t, LayoutA, 8,
+      cutlass::half_t, LayoutB, 8,
+      float,
+      TileShape_MNK, ClusterShape_MNK,
+      cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
+      cutlass::gemm::KernelTmaWarpSpecializedCooperative
+    >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,
+      CollectiveMainloop,
+      CollectiveEpilogue
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
+  EXPECT_TRUE(passed);
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
\ No newline at end of file
diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu
index 8ca22703..16a063af 100644
--- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu
+++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu
@@ -100,7 +100,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_ReLU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_ReLU) {
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
   using LayoutC = cutlass::layout::RowMajor;
@@ -143,7 +143,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_GELU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_GELU) {
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
   using LayoutC = cutlass::layout::RowMajor;
@@ -187,7 +187,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_ReLU_NoStoreT) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_ReLU_NoStoreT) {
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
   using LayoutC = cutlass::layout::RowMajor;
@@ -230,7 +230,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_Negate) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_Negate) {
 
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
@@ -274,7 +274,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasMul_ReLU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32Mul_ReLU) {
 
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
@@ -318,7 +318,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasMul_ReLU) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32Mul_ReLU) {
 
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
@@ -362,7 +362,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
   EXPECT_TRUE(passed);
 }
 
-TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasMul_ReLU_VoidC) {
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32Mul_ReLU_VoidC) {
 
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
@@ -406,4 +406,92 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128
   EXPECT_TRUE(passed);
 }
 
-#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF16Mul_ReLU_VoidC) {
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using TileShape_MNK = Shape<_128,_128,_64>;
+  using ClusterShape_MNK = Shape<_2,_2,_1>;
+
+  static constexpr bool StoreT = true;
+  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise<
+    cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, cutlass::half_t>;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      TileShape_MNK, ClusterShape_MNK,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      float, float,
+      void, LayoutC, 8,
+      cutlass::half_t, LayoutC, 8,
+      EpilogueSchedule
+    >::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      cutlass::half_t, LayoutA, 8,
+      cutlass::half_t, LayoutB, 8,
+      float,
+      TileShape_MNK, ClusterShape_MNK,
+      cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
+      cutlass::gemm::KernelTmaWarpSpecializedPingpong
+    >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,
+      CollectiveMainloop,
+      CollectiveEpilogue
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
+  EXPECT_TRUE(passed);
+}
+
+TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasS8Mul_ReLU_VoidC) {
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using TileShape_MNK = Shape<_128,_128,_64>;
+  using ClusterShape_MNK = Shape<_2,_2,_1>;
+
+  static constexpr bool StoreT = true;
+  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise<
+    cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, int8_t>;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      TileShape_MNK, ClusterShape_MNK,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      float, float,
+      void, LayoutC, 8,
+      cutlass::half_t, LayoutC, 8,
+      EpilogueSchedule
+    >::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      cutlass::half_t, LayoutA, 8,
+      cutlass::half_t, LayoutB, 8,
+      float,
+      TileShape_MNK, ClusterShape_MNK,
+      cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
+      cutlass::gemm::KernelTmaWarpSpecializedPingpong
+    >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,
+      CollectiveMainloop,
+      CollectiveEpilogue
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
+  EXPECT_TRUE(passed);
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
\ No newline at end of file
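The two new tests in each file differ from the existing BiasF32Mul variants only in the trailing ElementBias template argument of the epilogue schedule. A type-alias sketch of the three bias configurations, assuming the schedules are declared in cutlass/epilogue/dispatch_policy.hpp as in CUTLASS 3.x (the float line mirrors what the renamed BiasF32 tests presumably instantiate):

#include <cutlass/epilogue/dispatch_policy.hpp>   // assumed location of the schedule declarations
#include <cutlass/epilogue/thread/activation.h>   // cutlass::epilogue::thread::ReLu
#include <cutlass/functional.h>                   // cutlass::multiplies
#include <cutlass/numeric_types.h>                // cutlass::half_t

static constexpr bool StoreT = true;

// Bias arrives as float == ElementCompute; the new converter is an identity pass-through.
using ScheduleBiasF32 = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise<
    cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>;

// Bias arrives as half_t; converted to float before the thread-level epilogue runs.
using ScheduleBiasF16 = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise<
    cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, cutlass::half_t>;

// Bias arrives as int8_t; same conversion path, from a narrower integer type.
using ScheduleBiasS8 = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise<
    cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, int8_t>;

Everything else in the kernel instantiation (collective builders, GemmUniversal, GemmUniversalAdapter) is unchanged across the three variants, which is why each new test body is otherwise identical to its BiasF32Mul counterpart.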