restore the old epilogue for everything except streamk (#749)

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Haicheng Wu, 2023-01-04 11:02:55 -05:00, committed by GitHub
parent 5989b7e1d7
commit ff6e733fe1
3 changed files with 306 additions and 187 deletions

File 1 of 3

@@ -173,8 +173,8 @@ public:
   static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;

-public:
+public:

   static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
     "Mismatch between shared load iterator and output tile iterator.");
@@ -186,6 +186,102 @@ public:
   static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");

+public:
+
+  /// Aspect for when epilogue source is not needed
+  struct SourceAspectNotNeeded
+  {
+    /// Constructor
+    CUTLASS_DEVICE
+    SourceAspectNotNeeded()
+    {}
+
+    /// Invoke the output functor over each vector of output
+    CUTLASS_DEVICE
+    void apply_output_operator(
+      typename OutputTileIterator::Fragment &output_fragment,
+      OutputOp const &output_op,
+      typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
+    {
+      OutputAccessType *output_frag_ptr =
+        reinterpret_cast<OutputAccessType *>(&output_fragment);
+
+      AccumulatorAccessType const *compute_frag_ptr =
+        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
+
+      int const kOutputOpIterations =
+        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kOutputOpIterations; ++i)
+      {
+        // Call the output operator
+        output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
+      }
+    }
+  };
+
+  /// Aspect for when epilogue source is needed
+  struct SourceAspectNeeded
+  {
+    OutputTileIterator source_iterator;
+
+    typename OutputTileIterator::Fragment source_fragment;
+
+    /// Invoke the output functor over each vector of output
+    CUTLASS_DEVICE
+    static void apply_output_operator(
+      typename OutputTileIterator::Fragment &output_fragment,
+      OutputOp const &output_op,
+      typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
+      typename OutputTileIterator::Fragment const &source_fragment)
+    {
+      OutputAccessType *output_frag_ptr =
+        reinterpret_cast<OutputAccessType *>(&output_fragment);
+
+      AccumulatorAccessType const *compute_frag_ptr =
+        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
+
+      OutputAccessType const *source_frag_ptr =
+        reinterpret_cast<OutputAccessType const *>(&source_fragment);
+
+      int const kOutputOpIterations =
+        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kOutputOpIterations; ++i)
+      {
+        // Call the output operator
+        output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
+      }
+    }
+
+    /// Constructor
+    CUTLASS_DEVICE
+    SourceAspectNeeded(OutputTileIterator source_iterator) :
+      source_iterator(source_iterator)
+    {
+      source_fragment.clear();
+    }
+
+    /// Invoke the output functor over each vector of output
+    CUTLASS_DEVICE
+    void apply_output_operator(
+      typename OutputTileIterator::Fragment &output_fragment,
+      OutputOp const &output_op,
+      typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
+    {
+      // Load addend source fragment from global memory
+      source_iterator.load(source_fragment);
+      ++source_iterator;
+
+      apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
+    }
+  };

 private:

   /// Loads fragment from shared memory aligned with output tensor
@@ -268,7 +364,11 @@ public:
       typename OutputTileIterator::Fragment output_fragment;

       // Apply the output operator
-      apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
+      SourceAspectNeeded::apply_output_operator(
+        output_fragment,
+        output_op,
+        aligned_accum_fragment,
+        source_fragment);

       // Store the final result
       destination_iterator += reduce_fragment_idx;
@@ -276,13 +376,45 @@ public:
     }

-  /// Streams the result to global memory
+  /// Perform the epilogue computations and stream the result to global memory.
+  CUTLASS_DEVICE
+  void operator()(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators)          ///< Complete warp-level accumulator tile
+  {
+    operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
+  }
+
+  /// Perform the epilogue computations and stream the result to global memory. Implements
+  /// two alternative codepaths, depending on whether the output op requires addend data to be loaded.
   CUTLASS_DEVICE
   void operator()(
     OutputOp const &output_op,                    ///< Output operator
     OutputTileIterator destination_iterator,      ///< Tile iterator for destination
     AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
-    OutputTileIterator source_iterator )          ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
+    OutputTileIterator source_iterator )          ///< Tile iterator for addend source
+  {
+    if (output_op.is_source_needed())
+    {
+      operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
+    }
+    else
+    {
+      operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
+    }
+  }
+
+  /// Perform the epilogue computations and stream the result to global memory. Implements a
+  /// single codepath, regardless of whether the output op requires addend data to be loaded
+  CUTLASS_DEVICE
+  void unified(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
+    OutputTileIterator source_iterator )          ///< Tile iterator for addend source
   {
     if (!output_op.is_source_needed())
     {
@@ -290,11 +422,19 @@ public:
       __syncthreads();  // Dummy (CUDA 11.0)
     }

-    // Source-fragment data (zero-initialized for scenarios where the
-    // output operator allows us to skip loading it from global input)
-    typename OutputTileIterator::Fragment source_fragment;
-    source_fragment.clear();
+    operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
+  }
+
+  /// Streams the result to global memory
+  template <typename SourceAspect>
+  CUTLASS_DEVICE
+  void operator()(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
+    SourceAspect source)
+  {
     // Iterator over warp-level accumulator fragment
     AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
@@ -341,13 +481,8 @@ public:
     CUTLASS_PRAGMA_UNROLL
     for (int p = 0; p < Base::kFragmentsPerIteration; ++p)
     {
-      // Load addend source fragment from global memory
-      source_iterator.load(source_fragment);
-      ++source_iterator;
-
-      typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
-      shared_load_iterator_.load(aligned_accum_fragment[0]);
+      typename SharedLoadIterator::Fragment aligned_accum_fragment;
+      shared_load_iterator_.load(aligned_accum_fragment);

       if (p < Base::kFragmentsPerIteration - 1)
       {
@@ -359,9 +494,10 @@ public:
       CUTLASS_PRAGMA_UNROLL
       for ( int i = 1; i < kPartitionsK; ++i) {
+        typename SharedLoadIterator::Fragment aligned_accum_fragment_addend;
         shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
-        shared_load_iterator_.load(aligned_accum_fragment[i]);
-        aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
+        shared_load_iterator_.load(aligned_accum_fragment_addend);
+        aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_accum_fragment_addend);
       }

       shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
@@ -372,7 +508,7 @@ public:
       //
       typename OutputTileIterator::Fragment output_fragment;

-      apply_output_operator(output_fragment, output_op, aligned_accum_fragment[0], source_fragment);
+      source.apply_output_operator(output_fragment, output_op, aligned_accum_fragment);

       //
       // Store the final result
@@ -388,37 +524,6 @@ public:
     }
   }

-private:
-
-  /// Helper to invoke the output functor over each vector of output
-  CUTLASS_DEVICE
-  void apply_output_operator(
-    typename OutputTileIterator::Fragment &output_fragment,
-    OutputOp const &output_op,                    ///< Output operator
-    typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
-    typename OutputTileIterator::Fragment const &source_fragment)
-  {
-    OutputAccessType *output_frag_ptr =
-      reinterpret_cast<OutputAccessType *>(&output_fragment);
-
-    AccumulatorAccessType const *compute_frag_ptr =
-      reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
-
-    OutputAccessType const *source_frag_ptr =
-      reinterpret_cast<OutputAccessType const *>(&source_fragment);
-
-    int const kOutputOpIterations =
-      OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
-
-    CUTLASS_PRAGMA_UNROLL
-    for (int i = 0; i < kOutputOpIterations; ++i)
-    {
-      // Call the output operator
-      output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
-    }
-  }
-
 };

 ////////////////////////////////////////////////////////////////////////////////
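The aspect structs restored above fold the old pair of private helpers into a single templated codepath: SourceAspectNotNeeded applies the unary form of the output functor, while SourceAspectNeeded loads an addend fragment and applies the binary form. The following is a minimal standalone sketch of that dispatch pattern; the float fragments, the pointer-based source "iterator", and the toy LinearCombination functor are illustrative assumptions, not CUTLASS types.

#include <array>
#include <cstddef>

// Toy output functor: D = alpha * accum (+ beta * source when present).
struct LinearCombination {
  float alpha = 1.0f;
  float beta = 0.0f;
  bool is_source_needed() const { return beta != 0.0f; }
  float operator()(float accum) const { return alpha * accum; }
  float operator()(float accum, float source) const { return alpha * accum + beta * source; }
};

using Fragment = std::array<float, 8>;

// Aspect used when no addend is read: applies the unary form of the functor.
struct SourceAspectNotNeeded {
  void apply_output_operator(Fragment &out, LinearCombination const &op, Fragment const &accum) {
    for (std::size_t i = 0; i < out.size(); ++i) {
      out[i] = op(accum[i]);
    }
  }
};

// Aspect used when the addend is read: loads a source fragment from "global
// memory" (a raw pointer here), then applies the binary form of the functor.
struct SourceAspectNeeded {
  float const *source;                 // stand-in for OutputTileIterator
  Fragment source_fragment{};
  void apply_output_operator(Fragment &out, LinearCombination const &op, Fragment const &accum) {
    for (std::size_t i = 0; i < out.size(); ++i) {
      source_fragment[i] = source[i];  // "load" the addend fragment
    }
    source += source_fragment.size();  // advance the "iterator"
    for (std::size_t i = 0; i < out.size(); ++i) {
      out[i] = op(accum[i], source_fragment[i]);
    }
  }
};

// One epilogue body serves both cases; the aspect hides the difference.
template <typename SourceAspect>
void epilogue(Fragment &out, LinearCombination const &op, Fragment const &accum, SourceAspect source) {
  source.apply_output_operator(out, op, accum);
}

int main() {
  Fragment out{}, accum{};
  accum.fill(2.0f);
  float source_mem[8] = {1, 1, 1, 1, 1, 1, 1, 1};
  epilogue(out, LinearCombination{1.5f, 0.0f}, accum, SourceAspectNotNeeded{});
  epilogue(out, LinearCombination{1.5f, 2.0f}, accum, SourceAspectNeeded{source_mem});
}

Passing the aspect by value mirrors the templated operator() above, which is also why the stream-K-only unified() entry point can take the SourceAspectNeeded path unconditionally.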

File 2 of 3

@@ -147,6 +147,101 @@ public:
       OutputTileIterator::kElementsPerAccess),
     "Divisibility");

+public:
+
+  /// Aspect for when epilogue source is not needed
+  struct SourceAspectNotNeeded
+  {
+    /// Constructor
+    CUTLASS_DEVICE
+    SourceAspectNotNeeded()
+    {}
+
+    /// Invoke the output functor over each vector of output
+    CUTLASS_DEVICE
+    void apply_output_operator(
+      typename OutputTileIterator::Fragment &output_fragment,
+      OutputOp const &output_op,
+      typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
+    {
+      OutputAccessType *output_frag_ptr =
+        reinterpret_cast<OutputAccessType *>(&output_fragment);
+
+      AccumulatorAccessType const *compute_frag_ptr =
+        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
+
+      int const kOutputOpIterations =
+        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kOutputOpIterations; ++i)
+      {
+        // Call the output operator
+        output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
+      }
+    }
+  };
+
+  /// Aspect for when epilogue source is needed
+  struct SourceAspectNeeded
+  {
+    OutputTileIterator source_iterator;
+
+    typename OutputTileIterator::Fragment source_fragment;
+
+    /// Invoke the output functor over each vector of output
+    CUTLASS_DEVICE
+    static void apply_output_operator(
+      typename OutputTileIterator::Fragment &output_fragment,
+      OutputOp const &output_op,
+      typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment,
+      typename OutputTileIterator::Fragment const &source_fragment)
+    {
+      OutputAccessType *output_frag_ptr =
+        reinterpret_cast<OutputAccessType *>(&output_fragment);
+
+      AccumulatorAccessType const *compute_frag_ptr =
+        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
+
+      OutputAccessType const *source_frag_ptr =
+        reinterpret_cast<OutputAccessType const *>(&source_fragment);
+
+      int const kOutputOpIterations =
+        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kOutputOpIterations; ++i)
+      {
+        // Call the output operator
+        output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
+      }
+    }
+
+    /// Constructor
+    CUTLASS_DEVICE
+    SourceAspectNeeded(OutputTileIterator source_iterator) :
+      source_iterator(source_iterator)
+    {
+      source_fragment.clear();
+    }
+
+    /// Invoke the output functor over each vector of output
+    CUTLASS_DEVICE
+    void apply_output_operator(
+      typename OutputTileIterator::Fragment &output_fragment,
+      OutputOp const &output_op,
+      typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
+    {
+      // Load addend source fragment from global memory
+      source_iterator.load(source_fragment);
+      ++source_iterator;
+
+      apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
+    }
+  };

   /// Shared storage allocation needed by the epilogue
   struct SharedStorage {};
@@ -196,7 +291,7 @@ public:
       typename OutputTileIterator::Fragment output_fragment;

       // Apply the output operator
-      apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);
+      SourceAspectNeeded::apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);

       // Store the final result
       destination_iterator += reduce_fragment_idx;
@@ -204,85 +299,65 @@ public:
     }

-  /// Streams the result to global memory
+  /// Perform the epilogue computations and stream the result to global memory.
+  CUTLASS_DEVICE
+  void operator()(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators)          ///< Complete warp-level accumulator tile
+  {
+    operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
+  }
+
+  /// Perform the epilogue computations and stream the result to global memory. Implements
+  /// two alternative codepaths, depending on whether the output op requires addend data to be loaded.
   CUTLASS_DEVICE
   void operator()(
     OutputOp const &output_op,                    ///< Output operator
     OutputTileIterator destination_iterator,      ///< Tile iterator for destination
     AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
-    OutputTileIterator source_iterator) {         ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
-
-    if (!output_op.is_source_needed()) {
-      compute_source_not_needed_(output_op, destination_iterator, accumulators);
-    }
-    else {
-      compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
-    }
-  }
-
-  /// Streams the result to global memory
-  CUTLASS_DEVICE
-  void compute_source_not_needed_(
-    OutputOp const &output_op,                    ///< Output operator
-    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
-    AccumulatorTile const &accumulators           ///< Complete warp-level accumulator tile
-  ) {
-    //
-    // Iterator over warp-level accumulator fragment
-    //
-    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
-
-    //
-    // Iterate over accumulator tile
-    //
-    CUTLASS_PRAGMA_UNROLL
-    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
-      //
-      // Convert fragment
-      //
-      typename AccumulatorFragmentIterator::Fragment accum_fragment;
-      accum_fragment_iterator.load(accum_fragment);
-      ++accum_fragment_iterator;
-
-      //
-      // Compute the output result
-      //
-      typename OutputTileIterator::Fragment output_fragment;
-      apply_output_operator_source_not_needed(output_fragment, output_op, accum_fragment);
-
-      //
-      // Store the final result
-      //
-      destination_iterator.set_iteration_index(iter);
-      destination_iterator.store(output_fragment);
-      ++destination_iterator;
-    }
-  }
-
-  /// Streams the result to global memory
-  CUTLASS_DEVICE
-  void compute_source_needed_(
-    OutputOp const &output_op,                    ///< Output operator
-    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
-    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
-    OutputTileIterator source_iterator            ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
-  ) {
-    //
-    // Predicated tile iterators constructed from members
-    //
-    typename OutputTileIterator::Fragment source_fragment;
-    source_fragment.clear();
+    OutputTileIterator source_iterator )          ///< Tile iterator for addend source
+  {
+    if (output_op.is_source_needed())
+    {
+      operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
+    }
+    else
+    {
+      operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
+    }
+  }
+
+  /// Perform the epilogue computations and stream the result to global memory. Implements a
+  /// single codepath, regardless of whether the output op requires addend data to be loaded
+  CUTLASS_DEVICE
+  void unified(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
+    OutputTileIterator source_iterator )          ///< Tile iterator for addend source
+  {
+    if (!output_op.is_source_needed())
+    {
+      source_iterator.clear_mask();
+      __syncthreads();  // Dummy (CUDA 11.0)
+    }
+
+    operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
+  }
+
+  /// Streams the result to global memory
+  template <typename SourceAspect>
+  CUTLASS_DEVICE
+  void operator()(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
+    SourceAspect source)
+  {
     //
     // Iterator over warp-level accumulator fragment
     //
@@ -295,13 +370,6 @@ public:
     CUTLASS_PRAGMA_UNROLL
     for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

-      //
-      // Load the source
-      //
-      source_iterator.set_iteration_index(iter);
-      source_iterator.load(source_fragment);
-      ++source_iterator;
-
       //
       // Convert fragment
@@ -317,7 +385,7 @@ public:
       //
       typename OutputTileIterator::Fragment output_fragment;

-      apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);
+      source.apply_output_operator(output_fragment, output_op, accum_fragment);

       //
       // Store the final result
@@ -328,60 +396,6 @@ public:
       ++destination_iterator;
     }
   }

-protected:
-
-  /// Helper to invoke the output functor over each vector of output
-  CUTLASS_DEVICE
-  void apply_output_operator(
-    typename OutputTileIterator::Fragment &output_fragment,
-    OutputOp const &output_op,
-    typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment,
-    typename OutputTileIterator::Fragment const &source_fragment)
-  {
-    OutputAccessType *output_frag_ptr =
-      reinterpret_cast<OutputAccessType *>(&output_fragment);
-
-    AccumulatorAccessType const *compute_frag_ptr =
-      reinterpret_cast<AccumulatorAccessType const *>(
-        &aligned_accum_fragment);
-
-    OutputAccessType const *source_frag_ptr =
-      reinterpret_cast<OutputAccessType const *>(&source_fragment);
-
-    int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
-                                    OutputTileIterator::kElementsPerAccess;
-
-    CUTLASS_PRAGMA_UNROLL
-    for (int i = 0; i < kOutputOpIterations; ++i) {
-      // Call the output operator
-      output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
-    }
-  }
-
-  /// Helper to invoke the output functor over each vector of output
-  CUTLASS_DEVICE
-  void apply_output_operator_source_not_needed(
-    typename OutputTileIterator::Fragment &output_fragment,
-    OutputOp const &output_op,
-    typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
-  {
-    OutputAccessType *output_frag_ptr =
-      reinterpret_cast<OutputAccessType *>(&output_fragment);
-
-    AccumulatorAccessType const *compute_frag_ptr =
-      reinterpret_cast<AccumulatorAccessType const *>(
-        &aligned_accum_fragment);
-
-    int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
-                                    OutputTileIterator::kElementsPerAccess;
-
-    CUTLASS_PRAGMA_UNROLL
-    for (int i = 0; i < kOutputOpIterations; ++i) {
-      // Call the output operator
-      output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
-    }
-  }
-
 };

 ////////////////////////////////////////////////////////////////////////////////
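In both files, unified() funnels every case through SourceAspectNeeded and relies on clear_mask() to make the addend loads inert when the output op does not need them, so all thread blocks execute one instruction stream. Below is a minimal sketch of that mechanism; the PredicatedIterator, the std::vector buffers, and the scalar loop are stand-ins assumed for illustration, not the real iterator API.

#include <vector>
#include <cstddef>

// Toy predicated iterator: once the mask is cleared, load() returns without
// touching memory, mirroring clear_mask() on the real output tile iterator.
struct PredicatedIterator {
  float const *ptr = nullptr;
  bool mask = true;
  void clear_mask() { mask = false; }
  void load(float &dst) {
    if (mask) {
      dst = *ptr++;
    }
  }
};

struct OutputOp {
  float alpha = 1.0f;
  float beta = 0.0f;
  bool is_source_needed() const { return beta != 0.0f; }
  float operator()(float accum, float source) const { return alpha * accum + beta * source; }
};

// unified()-style epilogue: always runs the source-needed path; when the mask
// is cleared, each source value stays at its zero-initialized default, just as
// source_fragment.clear() guarantees in the real code.
void unified_epilogue(OutputOp const &op, std::vector<float> &dest,
                      std::vector<float> const &accum, PredicatedIterator source) {
  if (!op.is_source_needed()) {
    source.clear_mask();  // addend loads become no-ops
  }
  for (std::size_t i = 0; i < dest.size(); ++i) {
    float source_value = 0.0f;  // like source_fragment.clear()
    source.load(source_value);  // no-op when the mask is cleared
    dest[i] = op(accum[i], source_value);
  }
}

int main() {
  std::vector<float> dest(4, 0.0f);
  std::vector<float> accum{1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<float> addend(4, 10.0f);
  unified_epilogue(OutputOp{2.0f, 0.0f}, dest, accum, PredicatedIterator{addend.data()});
}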

File 3 of 3

@@ -551,7 +551,7 @@ public:
   static Status can_implement(
     cutlass::gemm::GemmCoord const & problem_size)
   {
-    CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");
+    CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()");

     static int const kAlignmentA = (platform::is_same<LayoutA,
                                                       layout::ColumnMajorInterleaved<32>>::value)
@@ -851,7 +851,7 @@ protected:
       threadblock_item_begin);

     // Execute the epilogue operator to update the destination tensor.
-    epilogue(
+    epilogue.unified(
       EpilogueOutputOp(params.output_op),
       iterator_D,
       accumulator_tile,
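Per the commit title, only the stream-K kernel keeps the single-codepath entry point: this call site switches from epilogue(...) to epilogue.unified(...), while every other kernel continues through the dispatching operator(). A toy contrast of the two entry points; the Epilogue and OutputOp types below are illustrative stand-ins, not the kernel's real signatures.

#include <iostream>

struct OutputOp {
  float beta = 0.0f;
  bool is_source_needed() const { return beta != 0.0f; }
};

struct Epilogue {
  // operator(): picks one of two codepaths per output op (restored behavior).
  void operator()(OutputOp const &op) const {
    std::cout << (op.is_source_needed() ? "source-needed path\n"
                                        : "source-not-needed path\n");
  }
  // unified(): one codepath for every thread block (stream-K behavior).
  void unified(OutputOp const &) const {
    std::cout << "single path; source mask cleared when unused\n";
  }
};

int main() {
  Epilogue epilogue;
  epilogue(OutputOp{0.0f});          // classic kernels branch on the op
  epilogue.unified(OutputOp{0.0f});  // GemmUniversalStreamk always calls unified
}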