From ff6e733fe117b3201c8c984729aa6905cca9e20d Mon Sep 17 00:00:00 2001 From: Haicheng Wu <57973641+hwu36@users.noreply.github.com> Date: Wed, 4 Jan 2023 11:02:55 -0500 Subject: [PATCH] restore the old epilogue for everything except streamk (#749) Co-authored-by: Haicheng Wu <haicheng@nvidia.com> --- .../cutlass/epilogue/threadblock/epilogue.h | 203 ++++++++++--- .../threadblock/interleaved_epilogue.h | 286 +++++++++--------- .../gemm/kernel/gemm_universal_streamk.h | 4 +- 3 files changed, 306 insertions(+), 187 deletions(-) diff --git a/include/cutlass/epilogue/threadblock/epilogue.h b/include/cutlass/epilogue/threadblock/epilogue.h index 993de892..f7c9c825 100644 --- a/include/cutlass/epilogue/threadblock/epilogue.h +++ b/include/cutlass/epilogue/threadblock/epilogue.h @@ -173,8 +173,8 @@ public: static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; -public: +public: static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, "Mismatch between shared load iterator and output tile iterator."); @@ -186,6 +186,102 @@ public: static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); + +public: + + /// Aspect for when epilogue source is not needed + struct SourceAspectNotNeeded + { + /// Constructor + CUTLASS_DEVICE + SourceAspectNotNeeded() + {} + + /// Invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, + typename SharedLoadIterator::Fragment const &aligned_accum_fragment) + { + OutputAccessType *output_frag_ptr = + reinterpret_cast<OutputAccessType *>(&output_fragment); + + AccumulatorAccessType const *compute_frag_ptr = + reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 
kOutputOpIterations; ++i) + { + // Call the output operator + output_frag_ptr[i] = output_op(compute_frag_ptr[i]); + } + } + }; + + + /// Aspect for when epilogue source is needed + struct SourceAspectNeeded + { + OutputTileIterator source_iterator; + + typename OutputTileIterator::Fragment source_fragment; + + /// Invoke the output functor over each vector of output + CUTLASS_DEVICE + static void apply_output_operator( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, + typename SharedLoadIterator::Fragment const &aligned_accum_fragment, + typename OutputTileIterator::Fragment const &source_fragment) + { + OutputAccessType *output_frag_ptr = + reinterpret_cast<OutputAccessType *>(&output_fragment); + + AccumulatorAccessType const *compute_frag_ptr = + reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment); + + OutputAccessType const *source_frag_ptr = + reinterpret_cast<OutputAccessType const *>(&source_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) + { + // Call the output operator + output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); + } + } + + /// Constructor + CUTLASS_DEVICE + SourceAspectNeeded(OutputTileIterator source_iterator) : + source_iterator(source_iterator) + { + source_fragment.clear(); + } + + /// Invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, + typename SharedLoadIterator::Fragment const &aligned_accum_fragment) + { + // Load addend source fragment from global memory + source_iterator.load(source_fragment); + ++source_iterator; + + apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment); + } + }; + + private: /// Loads fragment from shared memory aligned with output tensor @@ -268,7 +364,11 @@ public: typename 
OutputTileIterator::Fragment output_fragment; // Apply the output operator - apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment); + SourceAspectNeeded::apply_output_operator( + output_fragment, + output_op, + aligned_accum_fragment, + source_fragment); // Store the final result destination_iterator += reduce_fragment_idx; @@ -276,13 +376,45 @@ public: } - /// Streams the result to global memory + /// Perform the epilogue computations and stream the result to global memory. + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile + { + operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); + } + + + /// Perform the epilogue computations and stream the result to global memory. Implements + /// two alternative codepaths, depending on whether the output op requires addend data to be loaded. CUTLASS_DEVICE void operator()( OutputOp const &output_op, ///< Output operator OutputTileIterator destination_iterator, ///< Tile iterator for destination AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator source_iterator ) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + OutputTileIterator source_iterator ) ///< Tile iterator for addend source + { + if (output_op.is_source_needed()) + { + operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); + } + else + { + operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); + } + } + + + /// Perform the epilogue computations and stream the result to global memory. 
Implements a + /// single codepath, regardless of whether the output op requires addend data to be loaded + CUTLASS_DEVICE + void unified( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator ) ///< Tile iterator for addend source { if (!output_op.is_source_needed()) { @@ -290,11 +422,19 @@ public: __syncthreads(); // Dummy (CUDA 11.0) } - // Source-fragment data (zero-initialized for scenarios where the - // output operator allows us to skip loading it from global input) - typename OutputTileIterator::Fragment source_fragment; - source_fragment.clear(); + operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); + } + + /// Streams the result to global memory + template <typename SourceAspect> + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + SourceAspect source) + { // Iterator over warp-level accumulator fragment AccumulatorFragmentIterator accum_fragment_iterator(accumulators); @@ -341,13 +481,8 @@ public: CUTLASS_PRAGMA_UNROLL for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { - // Load addend source fragment from global memory - source_iterator.load(source_fragment); - ++source_iterator; - - typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; - - shared_load_iterator_.load(aligned_accum_fragment[0]); + typename SharedLoadIterator::Fragment aligned_accum_fragment; + shared_load_iterator_.load(aligned_accum_fragment); if (p < Base::kFragmentsPerIteration - 1) { @@ -359,9 +494,10 @@ public: CUTLASS_PRAGMA_UNROLL for ( int i = 1; i < kPartitionsK; ++i) { + typename SharedLoadIterator::Fragment aligned_accum_fragment_addend; 
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); - shared_load_iterator_.load(aligned_accum_fragment[i]); - aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + shared_load_iterator_.load(aligned_accum_fragment_addend); + aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_accum_fragment_addend); } shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); @@ -372,7 +508,7 @@ public: // typename OutputTileIterator::Fragment output_fragment; - apply_output_operator(output_fragment, output_op, aligned_accum_fragment[0], source_fragment); + source.apply_output_operator(output_fragment, output_op, aligned_accum_fragment); // // Store the final result @@ -388,37 +524,6 @@ public: } } -private: - - /// Helper to invoke the output functor over each vector of output - CUTLASS_DEVICE - void apply_output_operator( - typename OutputTileIterator::Fragment &output_fragment, - OutputOp const &output_op, ///< Output operator - typename SharedLoadIterator::Fragment const &aligned_accum_fragment, - typename OutputTileIterator::Fragment const &source_fragment) - { - - OutputAccessType *output_frag_ptr = - reinterpret_cast<OutputAccessType *>(&output_fragment); - - AccumulatorAccessType const *compute_frag_ptr = - reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment); - - OutputAccessType const *source_frag_ptr = - reinterpret_cast<OutputAccessType const *>(&source_fragment); - - int const kOutputOpIterations = - OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kOutputOpIterations; ++i) - { - // Call the output operator - output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); - } - } - }; //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/interleaved_epilogue.h b/include/cutlass/epilogue/threadblock/interleaved_epilogue.h index 84d469e9..bfcdc295 100644 --- 
a/include/cutlass/epilogue/threadblock/interleaved_epilogue.h +++ b/include/cutlass/epilogue/threadblock/interleaved_epilogue.h @@ -147,6 +147,101 @@ public: OutputTileIterator::kElementsPerAccess), "Divisibility"); +public: + + /// Aspect for when epilogue source is not needed + struct SourceAspectNotNeeded + { + /// Constructor + CUTLASS_DEVICE + SourceAspectNotNeeded() + {} + + /// Invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, + typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment) + { + OutputAccessType *output_frag_ptr = + reinterpret_cast<OutputAccessType *>(&output_fragment); + + AccumulatorAccessType const *compute_frag_ptr = + reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) + { + // Call the output operator + output_frag_ptr[i] = output_op(compute_frag_ptr[i]); + } + } + }; + + + /// Aspect for when epilogue source is needed + struct SourceAspectNeeded + { + OutputTileIterator source_iterator; + + typename OutputTileIterator::Fragment source_fragment; + + /// Invoke the output functor over each vector of output + CUTLASS_DEVICE + static void apply_output_operator( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, + typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment, + typename OutputTileIterator::Fragment const &source_fragment) + { + OutputAccessType *output_frag_ptr = + reinterpret_cast<OutputAccessType *>(&output_fragment); + + AccumulatorAccessType const *compute_frag_ptr = + reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment); + + OutputAccessType const *source_frag_ptr = + reinterpret_cast<OutputAccessType const *>(&source_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / 
OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) + { + // Call the output operator + output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); + } + } + + /// Constructor + CUTLASS_DEVICE + SourceAspectNeeded(OutputTileIterator source_iterator) : + source_iterator(source_iterator) + { + source_fragment.clear(); + } + + /// Invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, + typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment) + { + // Load addend source fragment from global memory + source_iterator.load(source_fragment); + ++source_iterator; + + apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment); + } + }; + + /// Shared storage allocation needed by the epilogue struct SharedStorage {}; @@ -196,7 +291,7 @@ public: typename OutputTileIterator::Fragment output_fragment; // Apply the output operator - apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment); + SourceAspectNeeded::apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment); // Store the final result destination_iterator += reduce_fragment_idx; @@ -204,29 +299,65 @@ public: } - /// Streams the result to global memory + /// Perform the epilogue computations and stream the result to global memory. 
CUTLASS_DEVICE void operator()( - OutputOp const &output_op, ///< Output operator - OutputTileIterator destination_iterator, ///< Tile iterator for destination - AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) - if (!output_op.is_source_needed()) { - compute_source_not_needed_(output_op, destination_iterator, accumulators); + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile + { + operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); + } + + + /// Perform the epilogue computations and stream the result to global memory. Implements + /// two alternative codepaths, depending on whether the output op requires addend data to be loaded. + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator ) ///< Tile iterator for addend source + { + if (output_op.is_source_needed()) + { + operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); } - else { - compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator); + else + { + operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded()); } } - - /// Streams the result to global memory - CUTLASS_DEVICE - void compute_source_not_needed_( - OutputOp const &output_op, ///< Output operator - OutputTileIterator destination_iterator, ///< Tile iterator for destination - AccumulatorTile const &accumulators ///< Complete warp-level accumulator tile - ) { + + /// Perform the epilogue computations and stream 
the result to global memory. Implements a + /// single codepath, regardless of whether the output op requires addend data to be loaded + CUTLASS_DEVICE + void unified( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator ) ///< Tile iterator for addend source + { + if (!output_op.is_source_needed()) + { + source_iterator.clear_mask(); + __syncthreads(); // Dummy (CUDA 11.0) + } + + operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator)); + } + + + /// Streams the result to global memory + template <typename SourceAspect> + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + SourceAspect source) + { // // Iterator over warp-level accumulator fragment // @@ -254,7 +385,7 @@ public: // typename OutputTileIterator::Fragment output_fragment; - apply_output_operator_source_not_needed(output_fragment, output_op, accum_fragment); + source.apply_output_operator(output_fragment, output_op, accum_fragment); // // Store the final result @@ -264,123 +395,6 @@ public: destination_iterator.store(output_fragment); ++destination_iterator; } - } - - /// Streams the result to global memory - CUTLASS_DEVICE - void compute_source_needed_( - OutputOp const &output_op, ///< Output operator - OutputTileIterator destination_iterator, ///< Tile iterator for destination - AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator source_iterator ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) - ) { - - // - // Predicated tile iterators constructed from members - // - - typename OutputTileIterator::Fragment source_fragment; - - 
source_fragment.clear(); - - // - // Iterator over warp-level accumulator fragment - // - - AccumulatorFragmentIterator accum_fragment_iterator(accumulators); - - // - // Iterate over accumulator tile - // - - CUTLASS_PRAGMA_UNROLL - for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { - // - // Load the source - // - - source_iterator.set_iteration_index(iter); - source_iterator.load(source_fragment); - ++source_iterator; - - // - // Convert fragment - // - - typename AccumulatorFragmentIterator::Fragment accum_fragment; - - accum_fragment_iterator.load(accum_fragment); - ++accum_fragment_iterator; - - // - // Compute the output result - // - - typename OutputTileIterator::Fragment output_fragment; - apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment); - - // - // Store the final result - // - - destination_iterator.set_iteration_index(iter); - destination_iterator.store(output_fragment); - ++destination_iterator; - } - } - -protected: - - /// Helper to invoke the output functor over each vector of output - CUTLASS_DEVICE - void apply_output_operator( - typename OutputTileIterator::Fragment &output_fragment, - OutputOp const &output_op, - typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment, - typename OutputTileIterator::Fragment const &source_fragment) - { - OutputAccessType *output_frag_ptr = - reinterpret_cast<OutputAccessType *>(&output_fragment); - - AccumulatorAccessType const *compute_frag_ptr = - reinterpret_cast<AccumulatorAccessType const *>( - &aligned_accum_fragment); - - OutputAccessType const *source_frag_ptr = - reinterpret_cast<OutputAccessType const *>(&source_fragment); - - int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / - OutputTileIterator::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kOutputOpIterations; ++i) { - // Call the output operator - output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); - } - } - - /// Helper to invoke the output functor over each vector of output - 
CUTLASS_DEVICE - void apply_output_operator_source_not_needed( - typename OutputTileIterator::Fragment &output_fragment, - OutputOp const &output_op, - typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment) - { - OutputAccessType *output_frag_ptr = - reinterpret_cast<OutputAccessType *>(&output_fragment); - - AccumulatorAccessType const *compute_frag_ptr = - reinterpret_cast<AccumulatorAccessType const *>( - &aligned_accum_fragment); - - int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / - OutputTileIterator::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kOutputOpIterations; ++i) { - // Call the output operator - output_frag_ptr[i] = output_op(compute_frag_ptr[i]); - } } }; diff --git a/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_streamk.h index 831b5a5f..a354ee01 100644 --- a/include/cutlass/gemm/kernel/gemm_universal_streamk.h +++ b/include/cutlass/gemm/kernel/gemm_universal_streamk.h @@ -551,7 +551,7 @@ public: static Status can_implement( cutlass::gemm::GemmCoord const & problem_size) { - CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); + CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()"); static int const kAlignmentA = (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value) @@ -851,7 +851,7 @@ protected: threadblock_item_begin); // Execute the epilogue operator to update the destination tensor. - epilogue( + epilogue.unified( EpilogueOutputOp(params.output_op), iterator_D, accumulator_tile,