diff --git a/tools/library/include/cutlass/library/handle.h b/tools/library/include/cutlass/library/handle.h
index b693adac..8a8a976a 100644
--- a/tools/library/include/cutlass/library/handle.h
+++ b/tools/library/include/cutlass/library/handle.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -338,6 +338,9 @@ using HandlePtr = std::unique_ptr<Handle>;
 /// Finds conv2d operation instances with Conv2d::ElementC = Reduction::ElementWorkspace
 Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation);
 /////////////////////////////////////////////////////////////////////////////////////////////////
+/// Finds gemm operation instances with ElementC = Reduction::ElementWorkspace
+Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation);
+/////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace library
 } // namespace cutlass
diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h
index 24f91256..37e2b89f 100644
--- a/tools/library/include/cutlass/library/library.h
+++ b/tools/library/include/cutlass/library/library.h
@@ -590,7 +590,8 @@ public:
     void const *configuration) const = 0;
 
   virtual uint64_t get_device_workspace_size(
-    void const *configuration) const = 0;
+    void const *configuration,
+    void const *arguments = nullptr) const = 0;
 
   virtual Status initialize(
     void const *configuration,
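The library.h change above is the interface-level piece of this patch: Operation::get_device_workspace_size() now also receives the operation's arguments (defaulted to nullptr), because for a universal GEMM run with parallel split-K the device workspace holds accumulator-typed partial products and its size depends on runtime values such as the partition count, which live in the arguments rather than in the configuration. A minimal caller sketch under that assumption follows; allocate_workspaces is a hypothetical helper, not part of the patch, and error handling is reduced to a status code.

    #include <cstdint>
    #include <vector>
    #include <cuda_runtime.h>
    #include "cutlass/library/library.h"

    // Hypothetical helper: size and allocate both workspaces for an operation,
    // forwarding the runtime arguments so size queries that depend on them
    // (e.g. parallel split-K partial buffers) stay accurate.
    inline cutlass::Status allocate_workspaces(
        cutlass::library::Operation const *op,
        void const *configuration,
        void const *arguments,
        std::vector<uint8_t> &host_workspace,
        void **device_workspace) {

      host_workspace.resize(op->get_host_workspace_size(configuration), 0);

      // Passing the arguments (may be nullptr) is what the new signature enables.
      uint64_t device_bytes = op->get_device_workspace_size(configuration, arguments);

      cudaError_t err = cudaMalloc(device_workspace, device_bytes);
      return (err == cudaSuccess) ? cutlass::Status::kSuccess
                                  : cutlass::Status::kErrorInternal;
    }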
diff --git a/tools/library/src/conv2d_operation.h b/tools/library/src/conv2d_operation.h
index a6ff4472..9f7fc89a 100644
--- a/tools/library/src/conv2d_operation.h
+++ b/tools/library/src/conv2d_operation.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -272,7 +272,8 @@ public:
   /// Gets the device-side workspace
   virtual uint64_t get_device_workspace_size(
-    void const *configuration_ptr) const {
+    void const *configuration_ptr,
+    void const *arguments_ptr = nullptr) const {
 
     OperatorArguments args;
 
diff --git a/tools/library/src/conv3d_operation.h b/tools/library/src/conv3d_operation.h
index 951a3dc1..704cd3ca 100644
--- a/tools/library/src/conv3d_operation.h
+++ b/tools/library/src/conv3d_operation.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -266,7 +266,8 @@ public:
   /// Gets the device-side workspace
   virtual uint64_t get_device_workspace_size(
-    void const *configuration_ptr) const {
+    void const *configuration_ptr,
+    void const *arguments_ptr = nullptr) const {
 
     OperatorArguments args;
 
diff --git a/tools/library/src/gemm_operation.h b/tools/library/src/gemm_operation.h
index 1cc1e8b5..31f76e9b 100644
--- a/tools/library/src/gemm_operation.h
+++ b/tools/library/src/gemm_operation.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -242,7 +242,8 @@ public:
   /// Gets the device-side workspace
   virtual uint64_t get_device_workspace_size(
-    void const *configuration_ptr) const {
+    void const *configuration_ptr,
+    void const *arguments_ptr = nullptr) const {
 
     OperatorArguments args;
 
@@ -443,7 +444,8 @@ public:
   /// Gets the device-side workspace
   virtual uint64_t get_device_workspace_size(
-    void const *configuration_ptr) const {
+    void const *configuration_ptr,
+    void const *arguments_ptr = nullptr) const {
 
     OperatorArguments args;
 
@@ -569,7 +571,7 @@ protected:
     operator_args.ldb = (configuration->ldb);
     operator_args.ldc = (configuration->ldc);
     operator_args.ldd = (configuration->ldd);
-    
+
     return Status::kSuccess;
   }
 
@@ -649,7 +651,8 @@ public:
   /// Gets the device-side workspace
   virtual uint64_t get_device_workspace_size(
-    void const *configuration_ptr) const {
+    void const *configuration_ptr,
+    void const *arguments_ptr) const {
 
     OperatorArguments args;
 
@@ -661,6 +664,14 @@ public:
       return 0;
     }
 
+    status = update_arguments_(
+      args,
+      static_cast<GemmUniversalArguments const *>(arguments_ptr));
+
+    if (status != Status::kSuccess) {
+      return 0;
+    }
+
     uint64_t size = Operator::get_workspace_size(args);
 
     return size;
@@ -855,7 +866,8 @@ public:
   /// Gets the device-side workspace
   virtual uint64_t get_device_workspace_size(
-    void const *configuration_ptr) const {
+    void const *configuration_ptr,
+    void const *arguments_ptr = nullptr) const {
 
     OperatorArguments args;
 
@@ -1055,7 +1067,8 @@ public:
   /// Gets the device-side workspace
   virtual uint64_t get_device_workspace_size(
-    void const *configuration_ptr) const {
+    void const *configuration_ptr,
+    void const *arguments_ptr = nullptr) const {
 
     OperatorArguments args;
diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu
index 3fb085ca..ea073c7b 100644
--- a/tools/library/src/handle.cu
+++ b/tools/library/src/handle.cu
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -1098,6 +1098,59 @@ Operation const* find_conv_operation_for_parallel_reduction(Operation const *ope
   return nullptr;
 }
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Finds gemm operation instances with Gemm::ElementC = Reduction::ElementWorkspace
+Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation) {
+
+  GemmDescription const &gemm_desc =
+    static_cast<GemmDescription const &>(operation->description());
+
+  // if the current gemm operation accumulator and output data type match return operation
+  if(gemm_desc.tile_description.math_instruction.element_accumulator == gemm_desc.C.element) {
+    return operation;
+  }
+
+  // find gemm operation to match gemm output and reduction workspace data type
+  GemmFunctionalKey key(
+    library::Provider::kCUTLASS,
+    gemm_desc.gemm_kind,
+    gemm_desc.tile_description.math_instruction.element_accumulator,
+    gemm_desc.element_epilogue,
+    gemm_desc.A.element,
+    gemm_desc.A.layout,
+    gemm_desc.transform_A,
+    gemm_desc.B.element,
+    gemm_desc.B.layout,
+    gemm_desc.transform_B,
+    gemm_desc.tile_description.math_instruction.element_accumulator);
+
+  // gemm operation table
+  auto gemm_operations = Singleton::get().operation_table.gemm_operations;
+
+  // find GemmFunctionalKey in gemm operation table
+  auto operators_it = gemm_operations.find(key);
+
+  if (operators_it == gemm_operations.end()) {
+    return nullptr;
+  }
+
+  if (operators_it->second.empty()) {
+    return nullptr;
+  }
+
+  // A and B use the same alignment in generator.py
+  int alignment = gemm_desc.A.alignment;
+
+  // gemm operation for same compute capability and alignment
+  GemmPreferenceKey preference_key(
+    gemm_desc.tile_description.minimum_compute_capability,
+    alignment);
+
+  return find_gemm_operation(operators_it, preference_key);
+}
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace library
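For parallel split-K the GEMM has to emit its k-partitioned partials in the accumulator type, so find_gemm_operation_for_parallel_reduction() above rebuilds a GemmFunctionalKey with ElementC replaced by the math instruction's accumulator element and re-queries the operation table. Below is a sketch of how a caller might pair that lookup with the matching reduction operation; find_parallel_splitk_pair is a hypothetical helper (the ReductionFunctionalKey fields mirror the ones the profiler uses later in this patch), and it assumes the singleton operation table has already been initialized.

    #include <utility>
    #include "cutlass/library/handle.h"
    #include "cutlass/library/singleton.h"

    // Hypothetical helper (illustration only): given a user-selected GEMM, find the
    // accumulator-typed GEMM and the matching reduction operation needed for
    // parallel split-K. Returns nullptrs if either piece is unavailable.
    inline std::pair<cutlass::library::Operation const *, cutlass::library::Operation const *>
    find_parallel_splitk_pair(cutlass::library::Operation const *gemm_op) {

      using namespace cutlass::library;

      GemmDescription const &desc =
        static_cast<GemmDescription const &>(gemm_op->description());

      Operation const *workspace_gemm = find_gemm_operation_for_parallel_reduction(gemm_op);
      if (!workspace_gemm) {
        return {nullptr, nullptr};
      }

      ReductionFunctionalKey reduction_key(
        Provider::kCUTLASS,
        desc.tile_description.math_instruction.element_accumulator,  // workspace element
        desc.tile_description.math_instruction.element_accumulator,  // accumulator element
        desc.C.element,                                               // output element
        desc.element_epilogue);                                       // epilogue compute type

      auto it = Singleton::get().operation_table.reduction_operations.find(reduction_key);
      if (it == Singleton::get().operation_table.reduction_operations.end()) {
        return {workspace_gemm, nullptr};
      }

      return {workspace_gemm, it->second};
    }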
diff --git a/tools/library/src/reduction/init_reduction_operations.cu b/tools/library/src/reduction/init_reduction_operations.cu
index 9e85034e..9b49a1f1 100644
--- a/tools/library/src/reduction/init_reduction_operations.cu
+++ b/tools/library/src/reduction/init_reduction_operations.cu
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
diff --git a/tools/library/src/reduction/reduction_device.cu b/tools/library/src/reduction/reduction_device.cu
index ddd056a2..5c3776c0 100644
--- a/tools/library/src/reduction/reduction_device.cu
+++ b/tools/library/src/reduction/reduction_device.cu
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -46,7 +46,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) {
   using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
     ElementOutput,
-    128 / cutlass::sizeof_bits<ElementOutput>::value,
+    128 / cutlass::sizeof_bits<ElementOutput>::value,
     ElementAccumulator,
     ElementCompute
   >;
@@ -81,7 +81,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest) {
   using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
     ElementOutput,
-    128 / cutlass::sizeof_bits<ElementOutput>::value,
+    128 / cutlass::sizeof_bits<ElementOutput>::value,
     ElementAccumulator,
     ElementCompute
   >;
@@ -115,7 +115,7 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
   using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
     ElementOutput,
-    128 / cutlass::sizeof_bits<ElementOutput>::value,
+    128 / cutlass::sizeof_bits<ElementOutput>::value,
     ElementAccumulator,
     ElementCompute
   >;
@@ -140,6 +140,5 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
   ));
 }
-
 }
 }
diff --git a/tools/library/src/reduction/reduction_operation.h b/tools/library/src/reduction/reduction_operation.h
index 8dafff3b..03996c91 100644
--- a/tools/library/src/reduction/reduction_operation.h
+++ b/tools/library/src/reduction/reduction_operation.h
@@ -30,6 +30,7 @@
 #include
 #include "cutlass/cutlass.h"
 #include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/epilogue/thread/linear_combination_clamp.h"
 #include "cutlass/reduction/thread/reduction_operators.h"
 #include "cutlass/reduction/device/reduce_split_k.h"
@@ -180,7 +181,8 @@ public:
   /// Gets the device-side workspace
   virtual uint64_t get_device_workspace_size(
-    void const *configuration_ptr) const {
+    void const *configuration_ptr,
+    void const *arguments_ptr = nullptr) const {
 
     OperatorArguments args;
 
diff --git a/tools/library/src/reference/conv_reference_operation.h b/tools/library/src/reference/conv_reference_operation.h
index 8d7d10a2..9837ed07 100644
--- a/tools/library/src/reference/conv_reference_operation.h
+++ b/tools/library/src/reference/conv_reference_operation.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -403,7 +403,8 @@ public:
   }
 
   virtual uint64_t get_device_workspace_size(
-    void const *configuration) const {
+    void const *configuration,
+    void const *arguments = nullptr) const {
 
     return 0;
   }
diff --git a/tools/library/src/reference/gemm_reference_operation.h b/tools/library/src/reference/gemm_reference_operation.h
index 422ac352..3acf5817 100644
--- a/tools/library/src/reference/gemm_reference_operation.h
+++ b/tools/library/src/reference/gemm_reference_operation.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -161,7 +161,8 @@ public:
   }
 
   virtual uint64_t get_device_workspace_size(
-    void const *configuration) const {
+    void const *configuration,
+    void const *arguments = nullptr) const {
 
     return 0;
   }
diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu
index 7afea996..298785b2 100644
--- a/tools/profiler/src/gemm_operation_profiler.cu
+++ b/tools/profiler/src/gemm_operation_profiler.cu
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -37,6 +37,7 @@
 #include "gemm_operation_profiler.h"
 #include "gpu_timer.h"
 
+#include "cutlass/library/singleton.h"
 #include "cutlass/library/library.h"
 #include "cutlass/library/handle.h"
 
@@ -55,6 +56,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
     library::OperationKind::kGemm,
     {
       {ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"},
+      {ArgumentTypeID::kEnumerated, {"split_k_mode"}, "Variant of split K mode (serial, parallel)"},
       {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"},
       {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"},
       {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"},
@@ -100,6 +102,9 @@ void GemmOperationProfiler::print_examples(std::ostream &out) const {
     << "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n"
     << " $ cutlass_profiler --operation=Gemm --A=f16:column --B=*:row\n\n"
 
+    << "Profile a particular problem size with split K and parallel reduction:\n"
+    << " $ cutlass_profiler --operation=Gemm --split_k_mode=parallel --split_k_slices=2 --m=1024 --n=1024 --k=128\n\n"
+
     << "Using various input value distribution:\n"
     << " $ cutlass_profiler --operation=Gemm --dist=uniform,min:0,max:3\n"
     << " $ cutlass_profiler --operation=Gemm --dist=gaussian,mean:0,stddev:3\n"
@@ -155,8 +160,17 @@ Status GemmOperationProfiler::GemmProblem::parse(
     // default value
     this->k = 1024;
   }
+
+  if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) {
+    // default value
+    this->split_k_mode = library::SplitKMode::kSerial;
+  }
 
   this->mode = library::GemmUniversalMode::kGemm;
+  if(this->split_k_mode == library::SplitKMode::kParallel) {
+    this->mode = library::GemmUniversalMode::kGemmSplitKParallel;
+  }
+
   if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) {
     // default value
     this->split_k_slices = 1;
@@ -165,8 +179,7 @@ Status GemmOperationProfiler::GemmProblem::parse(
   if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) {
     // default value
     this->batch_count = 1;
-  }
-  else if (this->batch_count > 1) {
+  } else if (this->batch_count > 1) {
     this->mode = library::GemmUniversalMode::kBatched;
   }
 
@@ -210,7 +223,7 @@
       return Status::kErrorInternal;
     }
   }
-  
+
   this->lda = DeviceAllocation::get_packed_layout(
     operation_desc.A.layout, {int(this->m), int(this->k)}).front();
 
@@ -279,6 +292,8 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
 
   set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind));
 
+  set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
+
   set_argument(result, "A", problem_space,
     std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout));
 
@@ -321,7 +336,7 @@ Status GemmOperationProfiler::initialize_configuration(
   }
 
   Status status = problem_.parse(operation_desc, problem_space, problem);
-  
+
   if (status != Status::kSuccess) {
     return status;
   }
@@ -350,6 +365,13 @@
   gemm_workspace_.arguments.beta = problem_.beta.data();
   gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost;
 
+  // initialize reduction operation for parallel split-K mode
+  if (problem_.split_k_mode == library::SplitKMode::kParallel) {
+    if (!initialize_reduction_configuration_(operation, problem)) {
+      return Status::kErrorInternal;
+    }
+  }
+
   initialize_result_(this->model_result_, options, operation_desc, problem_space);
 
   return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments);
@@ -366,7 +388,7 @@ void GemmOperationProfiler::initialize_result_(
   result.disposition = Disposition::kNotRun;
   result.status = Status::kSuccess;
   result.operation_name = operation_desc.name;
-  
+
   problem_.initialize_result(result, operation_desc, problem_space);
 
   OperationProfiler::initialize_result_(result, operation_desc, problem_space);
@@ -377,6 +399,51 @@
 
 }
 
+/// Initialize reduction problem dimensions and library::Operation
+bool GemmOperationProfiler::initialize_reduction_configuration_(
+  library::Operation const *operation,
+  ProblemSpace::Problem const &problem) {
+
+  library::GemmDescription const &gemm_desc =
+    static_cast<library::GemmDescription const &>(operation->description());
+
+  if (!cast_from_double(problem_.alpha_one, gemm_desc.element_epilogue, 1)) {
+    return false;
+  }
+
+  if (!cast_from_double(problem_.beta_zero, gemm_desc.element_epilogue, 0)) {
+    return false;
+  }
+
+  /// initialize library::ReductionConfiguration
+  gemm_workspace_.reduction_configuration.problem_size     = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn();
+  gemm_workspace_.reduction_configuration.partitions       = int(problem_.split_k_slices);
+  gemm_workspace_.reduction_configuration.partition_stride = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn().product();
+  gemm_workspace_.reduction_configuration.ldw              = problem_.ldc;
+  gemm_workspace_.reduction_configuration.lds              = problem_.ldc;
+  gemm_workspace_.reduction_configuration.ldd              = problem_.ldc;
+
+  // find reduction operation
+  library::ReductionFunctionalKey reduction_key(
+    library::Provider::kCUTLASS,
+    gemm_desc.tile_description.math_instruction.element_accumulator,  // element workspace
+    gemm_desc.tile_description.math_instruction.element_accumulator,  // element accumulator
+    gemm_desc.C.element,                                              // element output
+    gemm_desc.element_epilogue                                        // element compute
+  );
+
+  auto reduction_it = library::Singleton::get().operation_table.reduction_operations.find(reduction_key);
+
+  if (reduction_it == library::Singleton::get().operation_table.reduction_operations.end()) {
+    return false;
+  }
+
+  // initialize reduction operation required for parallel split-k operator
+  reduction_op_ = reduction_it->second;
+
+  // reduction operation found and initialized
+  return true;
+}
+
 /// Initializes workspace
 Status GemmOperationProfiler::initialize_workspace(
   Options const &options,
@@ -385,7 +452,15 @@ Status GemmOperationProfiler::initialize_workspace(
   library::Operation const *operation,
   ProblemSpace const &problem_space,
   ProblemSpace::Problem const &problem) {
-  
+
+  library::Operation const* underlying_operation = operation;
+
+  if (problem_.split_k_mode == library::SplitKMode::kParallel) {
+    if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) {
+      return Status::kErrorNotSupported;
+    }
+  }
+
   library::GemmDescription const &operation_desc =
     static_cast<library::GemmDescription const &>(operation->description());
 
@@ -455,8 +530,12 @@
     );
 
     gemm_workspace_.Reference->copy_from_device(gemm_workspace_.C->data());
-  }
 
+    gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride();
+    gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride();
+    gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
+    gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
+  }
 
   //
   // Initialize the CUTLASS operation
@@ -467,16 +546,35 @@
 
   if (options.execution_mode != ExecutionMode::kDryRun) {
 
-    uint64_t workspace_size = operation->get_host_workspace_size(&gemm_workspace_.configuration);
+    uint64_t workspace_size = underlying_operation->get_host_workspace_size(&gemm_workspace_.configuration);
     gemm_workspace_.host_workspace.resize(workspace_size, 0);
 
-    workspace_size = operation->get_device_workspace_size(&gemm_workspace_.configuration);
+    workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration,
+                                                                     &gemm_workspace_.arguments);
     gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size);
 
-    status = operation->initialize(
+    status = underlying_operation->initialize(
      &gemm_workspace_.configuration,
      gemm_workspace_.host_workspace.data(),
      gemm_workspace_.device_workspace.data());
+
+    if (status != Status::kSuccess) {
+      return status;
+    }
+
+    if (problem_.split_k_mode == library::SplitKMode::kParallel) {
+      workspace_size = reduction_op_->get_host_workspace_size(&gemm_workspace_.reduction_configuration);
+      gemm_workspace_.reduction_host_workspace.resize(workspace_size, 0);
+
+      status = reduction_op_->initialize(
+        &gemm_workspace_.reduction_configuration,
+        gemm_workspace_.reduction_host_workspace.data(),
+        nullptr);
+
+      if (status != Status::kSuccess) {
+        return status;
+      }
+    }
   }
 
   //
@@ -527,11 +625,34 @@ bool GemmOperationProfiler::verify_cutlass(
   gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
   gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
 
+  if (problem_.split_k_mode == library::SplitKMode::kParallel) {
+    gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
+    gemm_workspace_.arguments.alpha = problem_.alpha_one.data();
+    gemm_workspace_.arguments.beta = problem_.beta_zero.data();
+
+    gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
+    gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data();
+    gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data();
+    gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data();
+    gemm_workspace_.reduction_arguments.beta = problem_.beta.data();
+    gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost;
+  }
+
   //
   // Run the CUTLASS operation
   //
 
-  results_.back().status = operation->run(
+  // initialize gemm underlying operation to handle parallel reduction
+  library::Operation const * underlying_operation = operation;
+
+  if (problem_.split_k_mode == library::SplitKMode::kParallel) {
+    if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) {
+      results_.back().disposition = Disposition::kFailed;
+      return false;
+    }
+  }
+
+  results_.back().status = underlying_operation->run(
    &gemm_workspace_.arguments,
    gemm_workspace_.host_workspace.data(),
    gemm_workspace_.device_workspace.data());
@@ -541,6 +662,19 @@
     return false;
   }
 
+  // Run parallel reduction kernel for parallel split_k_mode
+  if (problem_.split_k_mode == library::SplitKMode::kParallel) {
+    results_.back().status = reduction_op_->run(
+      &gemm_workspace_.reduction_arguments,
+      gemm_workspace_.reduction_host_workspace.data(),
+      nullptr);
+
+    if (results_.back().status != Status::kSuccess) {
+      results_.back().disposition = Disposition::kFailed;
+      return false;
+    }
+  }
+
   cudaError_t result = cudaDeviceSynchronize();
 
   if (result != cudaSuccess) {
     results_.back().disposition = Disposition::kFailed;
@@ -896,6 +1030,19 @@ bool GemmOperationProfiler::profile(
   gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
   gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
 
+  if (problem_.split_k_mode == library::SplitKMode::kParallel) {
+    gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
+    gemm_workspace_.arguments.alpha = problem_.alpha_one.data();
+    gemm_workspace_.arguments.beta = problem_.beta_zero.data();
+
+    gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
+    gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data();
+    gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data();
+    gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data();
+    gemm_workspace_.reduction_arguments.beta = problem_.beta.data();
+    gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost;
+  }
+
   results_.back().status = profile_cutlass_(
     results_.back().runtime,
     options,
@@ -921,6 +1068,15 @@ Status GemmOperationProfiler::profile_cutlass_(
 
   GpuTimer timer;
 
+  // initialize gemm underlying operation to handle parallel reduction
+  library::Operation const * underlying_operation = operation;
+
+  if (problem_.split_k_mode == library::SplitKMode::kParallel) {
+    if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) {
+      return Status::kErrorNotSupported;
+    }
+  }
+
   //
   // Optional sleep to limit power consumption and thermals
   //
@@ -942,8 +1098,16 @@
       gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx);
       gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx);
 
+      if (problem_.split_k_mode == library::SplitKMode::kParallel) {
+        gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
+
+        gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
+        gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx);
+        gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx);
+      }
+
       // Execute the CUTLASS operation
-      status = operation->run(
+      status = underlying_operation->run(
        &gemm_workspace_.arguments,
        host_workspace,
        device_workspace);
@@ -951,6 +1115,18 @@
       if (status != Status::kSuccess) {
         return status;
       }
+
+      // Run parallel reduction kernel for parallel split_k_mode
+      if (problem_.split_k_mode == library::SplitKMode::kParallel) {
+        status = reduction_op_->run(
+          &gemm_workspace_.reduction_arguments,
+          gemm_workspace_.reduction_host_workspace.data(),
+          nullptr);
+
+        if (status != Status::kSuccess) {
+          return status;
+        }
+      }
     }
 
   //
@@ -977,7 +1153,15 @@
       gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx);
       gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx);
 
-      status = operation->run(
+      if (problem_.split_k_mode == library::SplitKMode::kParallel) {
+        gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
+
+        gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
+        gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx);
+        gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx);
+      }
+
+      status = underlying_operation->run(
        arguments,
        host_workspace,
        device_workspace);
@@ -985,6 +1169,18 @@
       if (status != Status::kSuccess) {
         return status;
       }
+
+      // Run parallel reduction kernel for parallel split_k_mode
+      if (problem_.split_k_mode == library::SplitKMode::kParallel) {
+        status = reduction_op_->run(
+          &gemm_workspace_.reduction_arguments,
+          gemm_workspace_.reduction_host_workspace.data(),
+          nullptr);
+
+        if (status != Status::kSuccess) {
+          return status;
+        }
+      }
     }
 
   //
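Across verify_cutlass(), profile(), and profile_cutlass_() the parallel split-K path added above is the same two-kernel sequence: the accumulator-typed underlying GEMM writes its k-partitions into device_workspace with (alpha, beta) = (1, 0), and the reduction kernel then folds the partitions into the user-visible output while applying the requested (alpha, beta) against the original C. A condensed sketch of that sequence, using the names from the patch (status checks elided):

    // Pass 1: partial GEMMs into the split-K workspace, epilogue (1, 0).
    gemm_workspace_.arguments.D     = gemm_workspace_.device_workspace.data();
    gemm_workspace_.arguments.alpha = problem_.alpha_one.data();
    gemm_workspace_.arguments.beta  = problem_.beta_zero.data();

    underlying_operation->run(
      &gemm_workspace_.arguments,
      gemm_workspace_.host_workspace.data(),
      gemm_workspace_.device_workspace.data());

    // Pass 2: reduce the partitions into D, applying the user's (alpha, beta) with C.
    gemm_workspace_.reduction_arguments.workspace   = gemm_workspace_.device_workspace.data();
    gemm_workspace_.reduction_arguments.source      = gemm_workspace_.C->data();
    gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data();
    gemm_workspace_.reduction_arguments.alpha       = problem_.alpha.data();
    gemm_workspace_.reduction_arguments.beta        = problem_.beta.data();

    reduction_op_->run(
      &gemm_workspace_.reduction_arguments,
      gemm_workspace_.reduction_host_workspace.data(),
      nullptr);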
diff --git a/tools/profiler/src/gemm_operation_profiler.h b/tools/profiler/src/gemm_operation_profiler.h
index 4db81f94..0f3096f7 100644
--- a/tools/profiler/src/gemm_operation_profiler.h
+++ b/tools/profiler/src/gemm_operation_profiler.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -45,6 +45,7 @@
 #include "operation_profiler.h"
 #include "performance_result.h"
 #include "problem_space.h"
+#include "reduction_operation_profiler.h"
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -61,6 +62,7 @@ public:
   struct GemmProblem {
 
     cutlass::library::GemmUniversalMode mode;
+    cutlass::library::SplitKMode split_k_mode;
 
     int64_t m;
     int64_t n;
     int64_t k;
@@ -72,6 +74,12 @@ public:
     int split_k_slices;
     int batch_count;
 
+    // gemm with parallel interleaved reduction
+    // gemm epilogue (alpha, beta) = (1.0, 0.0)
+    // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)
+    std::vector<uint8_t> alpha_one;
+    std::vector<uint8_t> beta_zero;
+
     //
     // Methods
     //
@@ -121,6 +129,13 @@ public:
     /// Buffer used for the operations' device workspace
     DeviceAllocation device_workspace;
 
+    /// Library configuration and arguments for reduction operator
+    library::ReductionConfiguration reduction_configuration;
+    library::ReductionArguments reduction_arguments;
+
+    /// Buffer used for the cutlass reduction operations' host workspace
+    std::vector<uint8_t> reduction_host_workspace;
+
     //
     // Methods
     //
@@ -141,6 +156,8 @@ protected:
   /// Device memory allocations
   GemmWorkspace gemm_workspace_;
 
+  /// CUTLASS parallel reduction operation to follow this* gemm operation
+  library::Operation const *reduction_op_;
 
 public:
   //
@@ -231,6 +248,10 @@ protected:
     void *host_workspace,
     void *device_workspace);
 
+  /// Initialize reduction problem dimensions and library::Operation
+  bool initialize_reduction_configuration_(
+    library::Operation const *operation,
+    ProblemSpace::Problem const &problem);
 };
 
/////////////////////////////////////////////////////////////////////////////////////////////////
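The alpha_one/beta_zero members document the algebraic split the profiler relies on: with K divided into S slices, D = alpha * (sum over s of A_s * B_s) + beta * C is computed as W_s = 1 * (A_s * B_s) + 0 per slice into the accumulator-typed workspace, followed by D = alpha * (sum over s of W_s) + beta * C in the reduction epilogue. A self-contained host-side check of that identity (plain C++, independent of CUTLASS; sizes and values chosen arbitrarily):

    // Toy check that the split-K decomposition reproduces the fused epilogue.
    #include <cstdio>
    #include <vector>
    #include <cmath>

    int main() {
      int const M = 2, N = 3, K = 8, slices = 4;
      float alpha = 1.5f, beta = 0.5f;

      std::vector<float> A(M * K), B(K * N), C(M * N), D_ref(M * N), D_split(M * N);
      for (int i = 0; i < M * K; ++i) A[i] = float(i % 5) - 2.0f;
      for (int i = 0; i < K * N; ++i) B[i] = float(i % 3) - 1.0f;
      for (int i = 0; i < M * N; ++i) C[i] = float(i);

      // Reference: single GEMM with fused (alpha, beta) epilogue.
      for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) {
          float acc = 0.0f;
          for (int k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
          D_ref[m * N + n] = alpha * acc + beta * C[m * N + n];
        }

      // Split-K: per-slice partial GEMMs with (1, 0), then a reduction epilogue.
      std::vector<float> workspace(size_t(slices) * M * N, 0.0f);
      int k_per_slice = K / slices;
      for (int s = 0; s < slices; ++s)
        for (int m = 0; m < M; ++m)
          for (int n = 0; n < N; ++n) {
            float acc = 0.0f;
            for (int k = s * k_per_slice; k < (s + 1) * k_per_slice; ++k)
              acc += A[m * K + k] * B[k * N + n];
            workspace[(size_t(s) * M + m) * N + n] = acc;
          }
      for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) {
          float sum = 0.0f;
          for (int s = 0; s < slices; ++s) sum += workspace[(size_t(s) * M + m) * N + n];
          D_split[m * N + n] = alpha * sum + beta * C[m * N + n];
        }

      float max_err = 0.0f;
      for (int i = 0; i < M * N; ++i)
        max_err = std::fmax(max_err, std::fabs(D_ref[i] - D_split[i]));
      std::printf("max |D_ref - D_split| = %f\n", max_err);  // expected: 0
      return 0;
    }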