Example 23 - Passing correct alpha and beta values with --parallel-split-k (#424)

When split-k is enabled, we should set alpha to 1 and beta to 0 for the
split-k gemm kernel.

The fix was from hwu36. I only did fixed some minor typos along with his
fix.
This commit is contained in:
Yang Chen 2022-03-22 09:27:34 -07:00 committed by GitHub
parent 8f1fe7a132
commit 095cbba57c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -62,7 +62,7 @@ using ElementAccumulator = float; // Data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
using ElementOutput = cutlass::half_t; // Data type of elements in output tensor
using ElementOutput = cutlass::half_t; // Data type of elements in output tensor
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::RowMajor;
@ -122,7 +122,7 @@ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction<
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Below is the reduction kernel used in the case of parallel spiit-k
// Below is the reduction kernel used in the case of parallel split-k
using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;;
using ReduceOp = cutlass::reduction::thread::ReduceAdd<
@ -282,7 +282,7 @@ struct Options {
<< " --tag <string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/28_ampere_gemm_bias_fusion_example/ampere_gemm_bias_fusion --m=1024 --n=1024 --k=1024 \n\n";
<< "$ ./examples/23_ampere_gemm_operand_reduction_fusion/23_ampere_gemm_operand_reduction_fusion --m=1024 --n=1024 --k=1024 \n\n";
return out;
}
@ -398,8 +398,10 @@ Result profile(Options const &options) {
tensor_reduction.sync_device();
// Initialize alpha for dot product computation
ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha);
ElementComputeEpilogue beta = ElementComputeEpilogue(options.beta);
ElementComputeEpilogue alpha = options.parallel_split_k ? ElementComputeEpilogue(1)
: ElementComputeEpilogue(options.alpha);
ElementComputeEpilogue beta = options.parallel_split_k ? ElementComputeEpilogue(0)
: ElementComputeEpilogue(options.beta);
cutlass::gemm::GemmUniversalMode mode = options.parallel_split_k ?
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel :
@ -417,8 +419,8 @@ Result profile(Options const &options) {
tensor_a.device_ref().data(), // <- reference to tensor A on device
tensor_b.device_ref().data(), // <- reference to tensor B on device
tensor_c.device_ref().data(), // <- reference to matrix C on device
tensor_d.device_ref().data(), // <- reference to matrix C on device
tensor_reduction.device_ref().data(), // <- reference to tensor B on device
tensor_d.device_ref().data(), // <- reference to matrix D on device
tensor_reduction.device_ref().data(), // <- reference to reduction tensor on device
options.problem_size.m() * options.problem_size.k(),
options.problem_size.n() * options.problem_size.k(),
options.problem_size.m() * options.problem_size.n(),
@ -455,6 +457,10 @@ Result profile(Options const &options) {
if (options.parallel_split_k && batch_count > 1) {
// reduce gemm
ElementComputeEpilogue alpha = ElementComputeEpilogue(options.alpha);
ElementComputeEpilogue beta = ElementComputeEpilogue(options.beta);
int splitk_gemm_stride = options.problem_size.m();
cutlass::layout::RowMajor splitk_gemm_layout(splitk_gemm_stride);
@ -531,10 +537,10 @@ Result profile(Options const &options) {
gemm_device
(
options.problem_size,
alpha,
ElementComputeEpilogue(options.alpha),
tensor_a.device_ref(),
tensor_b.device_ref(),
beta,
ElementComputeEpilogue(options.beta),
tensor_c.device_ref(),
tensor_ref_d.device_ref()
);