From 095cbba57c691568e293d29e3a598124efde0b12 Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Tue, 22 Mar 2022 09:27:34 -0700 Subject: [PATCH] Example 23 - Passing correct alpha and beta values with --parallel-split-k (#424) When split-k is enabled, we should set alpha to 1 and beta to 0 for the split-k gemm kernel. The fix was from hwu36. I only fixed some minor typos along with his fix. --- .../ampere_gemm_operand_reduction_fusion.cu | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu index 8ce91437..a31d9838 100644 --- a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu +++ b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu @@ -62,7 +62,7 @@ using ElementAccumulator = float; // Data type of accumulator using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation using ElementInputA = cutlass::half_t; // Data type of elements in input tensor using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -using ElementOutput = cutlass::half_t; // Data type of elements in output tensor +using ElementOutput = cutlass::half_t; // Data type of elements in output tensor using LayoutInputA = cutlass::layout::ColumnMajor; using LayoutInputB = cutlass::layout::RowMajor; @@ -122,7 +122,7 @@ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction< using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -// Below is the reduction kernel used in the case of parallel spiit-k +// Below is the reduction kernel used in the case of parallel split-k using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;; using ReduceOp = cutlass::reduction::thread::ReduceAdd< @@ -282,7 +282,7 @@ struct Options { << " --tag String to 
replicate across the first column in the results table\n"; out << "\n\nExamples:\n\n" - << "$ ./examples/28_ampere_gemm_bias_fusion_example/ampere_gemm_bias_fusion --m=1024 --n=1024 --k=1024 \n\n"; + << "$ ./examples/23_ampere_gemm_operand_reduction_fusion/23_ampere_gemm_operand_reduction_fusion --m=1024 --n=1024 --k=1024 \n\n"; return out; } @@ -398,8 +398,10 @@ Result profile(Options const &options) { tensor_reduction.sync_device(); // Initialize alpha for dot product computation - ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha); - ElementComputeEpilogue beta = ElementComputeEpilogue(options.beta); + ElementComputeEpilogue alpha = options.parallel_split_k ? ElementComputeEpilogue(1) + : ElementComputeEpilogue(options.alpha); + ElementComputeEpilogue beta = options.parallel_split_k ? ElementComputeEpilogue(0) + : ElementComputeEpilogue(options.beta); cutlass::gemm::GemmUniversalMode mode = options.parallel_split_k ? cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel : @@ -417,8 +419,8 @@ Result profile(Options const &options) { tensor_a.device_ref().data(), // <- reference to tensor A on device tensor_b.device_ref().data(), // <- reference to tensor B on device tensor_c.device_ref().data(), // <- reference to matrix C on device - tensor_d.device_ref().data(), // <- reference to matrix C on device - tensor_reduction.device_ref().data(), // <- reference to tensor B on device + tensor_d.device_ref().data(), // <- reference to matrix D on device + tensor_reduction.device_ref().data(), // <- reference to reduction tensor on device options.problem_size.m() * options.problem_size.k(), options.problem_size.n() * options.problem_size.k(), options.problem_size.m() * options.problem_size.n(), @@ -455,6 +457,10 @@ Result profile(Options const &options) { if (options.parallel_split_k && batch_count > 1) { // reduce gemm + + ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha); + ElementComputeEpilogue beta = 
ElementComputeEpilogue(options.beta); + int splitk_gemm_stride = options.problem_size.m(); cutlass::layout::RowMajor splitk_gemm_layout(splitk_gemm_stride); @@ -531,10 +537,10 @@ Result profile(Options const &options) { gemm_device ( options.problem_size, - alpha, + ElementComputeEpilogue(options.alpha), tensor_a.device_ref(), tensor_b.device_ref(), - beta, + ElementComputeEpilogue(options.beta), tensor_c.device_ref(), tensor_ref_d.device_ref() );