restore the old epilogue for everything except streamk (#749)

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Haicheng Wu 2023-01-04 11:02:55 -05:00 committed by GitHub
parent 5989b7e1d7
commit ff6e733fe1
3 changed files with 306 additions and 187 deletions
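
The structural change in both epilogue headers below is the same: the separate source-needed and source-not-needed codepaths (e.g. `compute_source_needed_` / `compute_source_not_needed_` in the direct-store epilogue) are folded into two small policy structs, `SourceAspectNotNeeded` and `SourceAspectNeeded`, consumed by a single `operator()` templated on the aspect. Non-stream-k kernels dispatch between the two at runtime via `output_op.is_source_needed()`; the stream-k kernel calls the new `unified()` entry point instead, which always takes the source-loading path. Here is a minimal host-side sketch of that dispatch pattern, not CUTLASS code: `Fragment`, `OutputOp`, and the free function `epilogue` are simplified stand-ins for the real iterator and functor machinery.

#include <array>
#include <cstdio>

using Fragment = std::array<float, 4>;

// Stand-in for an epilogue functor such as LinearCombination:
// D = alpha * accumulator + beta * source.
struct OutputOp {
  float alpha = 2.0f;
  float beta = 1.0f;
  bool source_needed = true;

  bool is_source_needed() const { return source_needed; }
  float operator()(float accum) const { return alpha * accum; }
  float operator()(float accum, float source) const {
    return alpha * accum + beta * source;
  }
};

// Aspect used when the output op ignores the source tensor (e.g. beta == 0).
struct SourceAspectNotNeeded {
  void apply_output_operator(Fragment &out, OutputOp const &op,
                             Fragment const &accum) const {
    for (size_t i = 0; i < out.size(); ++i) {
      out[i] = op(accum[i]);
    }
  }
};

// Aspect used when the addend source must be loaded and combined.
struct SourceAspectNeeded {
  Fragment source{};  // in the real epilogue, loaded from global memory

  void apply_output_operator(Fragment &out, OutputOp const &op,
                             Fragment const &accum) const {
    for (size_t i = 0; i < out.size(); ++i) {
      out[i] = op(accum[i], source[i]);
    }
  }
};

// Single templated codepath; the aspect supplies the per-vector behavior.
template <typename SourceAspect>
void epilogue(OutputOp const &op, Fragment &out, Fragment const &accum,
              SourceAspect source) {
  source.apply_output_operator(out, op, accum);
}

int main() {
  OutputOp op;
  Fragment accum{1, 2, 3, 4};
  Fragment out{};

  // Runtime dispatch, mirroring the two-codepath operator() in the diff.
  if (op.is_source_needed()) {
    epilogue(op, out, accum, SourceAspectNeeded{{10, 10, 10, 10}});
  } else {
    epilogue(op, out, accum, SourceAspectNotNeeded{});
  }
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 12 14 16 18
}

The real epilogues keep the aspects as nested structs, so each instantiation of the templated operator() inlines the per-vector behavior with no runtime branching inside the tile loop.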

View File

@@ -173,8 +173,8 @@ public:
static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
public:
public:
static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
"Mismatch between shared load iterator and output tile iterator.");
@@ -186,6 +186,102 @@ public:
static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");
public:
/// Aspect for when epilogue source is not needed
struct SourceAspectNotNeeded
{
/// Constructor
CUTLASS_DEVICE
SourceAspectNotNeeded()
{}
/// Invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op,
typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
{
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i)
{
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
}
}
};
/// Aspect for when epilogue source is needed
struct SourceAspectNeeded
{
OutputTileIterator source_iterator;
typename OutputTileIterator::Fragment source_fragment;
/// Invoke the output functor over each vector of output
CUTLASS_DEVICE
static void apply_output_operator(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op,
typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
typename OutputTileIterator::Fragment const &source_fragment)
{
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
OutputAccessType const *source_frag_ptr =
reinterpret_cast<OutputAccessType const *>(&source_fragment);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i)
{
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
}
}
/// Constructor
CUTLASS_DEVICE
SourceAspectNeeded(OutputTileIterator source_iterator) :
source_iterator(source_iterator)
{
source_fragment.clear();
}
/// Invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op,
typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
{
// Load addend source fragment from global memory
source_iterator.load(source_fragment);
++source_iterator;
apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
}
};
private:
/// Loads fragment from shared memory aligned with output tensor
@@ -268,7 +364,11 @@ public:
typename OutputTileIterator::Fragment output_fragment;
// Apply the output operator
apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
SourceAspectNeeded::apply_output_operator(
output_fragment,
output_op,
aligned_accum_fragment,
source_fragment);
// Store the final result
destination_iterator += reduce_fragment_idx;
@@ -276,13 +376,45 @@ public:
}
/// Streams the result to global memory
/// Perform the epilogue computations and stream the result to global memory.
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile
{
operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
}
/// Perform the epilogue computations and stream the result to global memory. Implements
/// two alternative codepaths, depending on whether the output op requires addend data to be loaded.
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator ) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
OutputTileIterator source_iterator ) ///< Tile iterator for addend source
{
if (output_op.is_source_needed())
{
operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
}
else
{
operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
}
}
/// Perform the epilogue computations and stream the result to global memory. Implements a
/// single codepath, regardless of whether the output op requires addend data to be loaded
CUTLASS_DEVICE
void unified(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator ) ///< Tile iterator for addend source
{
if (!output_op.is_source_needed())
{
@@ -290,11 +422,19 @@ public:
__syncthreads(); // Dummy (CUDA 11.0)
}
// Source-fragment data (zero-initialized for scenarios where the
// output operator allows us to skip loading it from global input)
typename OutputTileIterator::Fragment source_fragment;
source_fragment.clear();
operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
}
/// Streams the result to global memory
template <typename SourceAspect>
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
SourceAspect source)
{
// Iterator over warp-level accumulator fragment
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
@@ -341,13 +481,8 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p)
{
// Load addend source fragment from global memory
source_iterator.load(source_fragment);
++source_iterator;
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
typename SharedLoadIterator::Fragment aligned_accum_fragment;
shared_load_iterator_.load(aligned_accum_fragment);
if (p < Base::kFragmentsPerIteration - 1)
{
@@ -359,9 +494,10 @@ public:
CUTLASS_PRAGMA_UNROLL
for ( int i = 1; i < kPartitionsK; ++i) {
typename SharedLoadIterator::Fragment aligned_accum_fragment_addend;
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
shared_load_iterator_.load(aligned_accum_fragment_addend);
aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_accum_fragment_addend);
}
shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
@@ -372,7 +508,7 @@ public:
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator(output_fragment, output_op, aligned_accum_fragment[0], source_fragment);
source.apply_output_operator(output_fragment, output_op, aligned_accum_fragment);
//
// Store the final result
@@ -388,37 +524,6 @@ public:
}
}
private:
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op, ///< Output operator
typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
typename OutputTileIterator::Fragment const &source_fragment)
{
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
OutputAccessType const *source_frag_ptr =
reinterpret_cast<OutputAccessType const *>(&source_fragment);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i)
{
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
}
}
};
////////////////////////////////////////////////////////////////////////////////
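
A second detail in this file, in the hunk at `@@ -359,9 +494,10 @@`: the split-K reduction (kPartitionsK > 1) no longer indexes an array of per-partition fragments. One running fragment is loaded first, then each remaining partition's partial sum is loaded into a temporary (`aligned_accum_fragment_addend`) and folded in with `add_fragments`, and the shared-memory pointer is finally rewound by `(1 - kPartitionsK) * kSmemPointerOffset`. A simplified host-side sketch of that loop shape, with stand-in types rather than the CUTLASS iterators:

#include <array>

constexpr int kPartitionsK = 4;
using Fragment = std::array<float, 8>;

// Stand-in for SharedLoadIterator: each K partition's partial tile sits at a
// fixed offset in (what would be) shared memory.
struct SharedLoadIterator {
  Fragment const *partitions;
  int offset = 0;

  void load(Fragment &frag) const { frag = partitions[offset]; }
  void add_pointer_offset(int delta) { offset += delta; }
};

Fragment add_fragments(Fragment a, Fragment const &b) {
  for (size_t i = 0; i < a.size(); ++i) {
    a[i] += b[i];
  }
  return a;
}

// Mirrors the rewritten loop: one live fragment plus one addend temporary per
// partition, then the iterator is rewound for reuse by the next tile iteration.
Fragment reduce_partitions(SharedLoadIterator it) {
  Fragment accum;
  it.load(accum);
  for (int i = 1; i < kPartitionsK; ++i) {
    Fragment addend;
    it.add_pointer_offset(1);  // kSmemPointerOffset in the real code
    it.load(addend);
    accum = add_fragments(accum, addend);
  }
  it.add_pointer_offset(1 - kPartitionsK);  // rewind
  return accum;
}

int main() {
  std::array<Fragment, kPartitionsK> parts{};
  for (int p = 0; p < kPartitionsK; ++p) {
    parts[p].fill(float(p + 1));
  }
  SharedLoadIterator it{parts.data()};
  Fragment total = reduce_partitions(it);  // each element == 1+2+3+4 == 10
  return total[0] == 10.0f ? 0 : 1;
}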

View File

@@ -147,6 +147,101 @@ public:
OutputTileIterator::kElementsPerAccess),
"Divisibility");
public:
/// Aspect for when epilogue source is not needed
struct SourceAspectNotNeeded
{
/// Constructor
CUTLASS_DEVICE
SourceAspectNotNeeded()
{}
/// Invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op,
typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
{
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i)
{
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
}
}
};
/// Aspect for when epilogue source is needed
struct SourceAspectNeeded
{
OutputTileIterator source_iterator;
typename OutputTileIterator::Fragment source_fragment;
/// Invoke the output functor over each vector of output
CUTLASS_DEVICE
static void apply_output_operator(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op,
typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment,
typename OutputTileIterator::Fragment const &source_fragment)
{
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
OutputAccessType const *source_frag_ptr =
reinterpret_cast<OutputAccessType const *>(&source_fragment);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i)
{
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
}
}
/// Constructor
CUTLASS_DEVICE
SourceAspectNeeded(OutputTileIterator source_iterator) :
source_iterator(source_iterator)
{
source_fragment.clear();
}
/// Invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op,
typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
{
// Load addend source fragment from global memory
source_iterator.load(source_fragment);
++source_iterator;
apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
}
};
/// Shared storage allocation needed by the epilogue
struct SharedStorage {};
@@ -196,7 +291,7 @@ public:
typename OutputTileIterator::Fragment output_fragment;
// Apply the output operator
apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);
SourceAspectNeeded::apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);
// Store the final result
destination_iterator += reduce_fragment_idx;
@@ -204,29 +299,65 @@ public:
}
/// Streams the result to global memory
/// Perform the epilogue computations and stream the result to global memory.
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
if (!output_op.is_source_needed()) {
compute_source_not_needed_(output_op, destination_iterator, accumulators);
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile
{
operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
}
/// Perform the epilogue computations and stream the result to global memory. Implements
/// two alternative codepaths, depending on whether the output op requires addend data to be loaded.
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator ) ///< Tile iterator for addend source
{
if (output_op.is_source_needed())
{
operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
}
else {
compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
else
{
operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
}
}
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_not_needed_(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators ///< Complete warp-level accumulator tile
) {
/// Perform the epilogue computations and stream the result to global memory. Implements a
/// single codepath, regardless of whether the output op requires addend data to be loaded
CUTLASS_DEVICE
void unified(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator ) ///< Tile iterator for addend source
{
if (!output_op.is_source_needed())
{
source_iterator.clear_mask();
__syncthreads(); // Dummy (CUDA 11.0)
}
operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
}
/// Streams the result to global memory
template <typename SourceAspect>
CUTLASS_DEVICE
void operator()(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
SourceAspect source)
{
//
// Iterator over warp-level accumulator fragment
//
@@ -254,7 +385,7 @@ public:
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_source_not_needed(output_fragment, output_op, accum_fragment);
source.apply_output_operator(output_fragment, output_op, accum_fragment);
//
// Store the final result
@@ -264,123 +395,6 @@ public:
destination_iterator.store(output_fragment);
++destination_iterator;
}
}
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_needed_(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
) {
//
// Predicated tile iterators constructed from members
//
typename OutputTileIterator::Fragment source_fragment;
source_fragment.clear();
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
//
// Load the source
//
source_iterator.set_iteration_index(iter);
source_iterator.load(source_fragment);
++source_iterator;
//
// Convert fragment
//
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);
//
// Store the final result
//
destination_iterator.set_iteration_index(iter);
destination_iterator.store(output_fragment);
++destination_iterator;
}
}
protected:
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op,
typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment,
typename OutputTileIterator::Fragment const &source_fragment)
{
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(
&aligned_accum_fragment);
OutputAccessType const *source_frag_ptr =
reinterpret_cast<OutputAccessType const *>(&source_fragment);
int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_source_not_needed(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op,
typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
{
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(
&aligned_accum_fragment);
int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
}
}
};
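
The `unified()` entry point above leans on a property of the predicated source iterator: `clear_mask()` turns every subsequent load into a no-op, so the zero-cleared source fragment passes through the two-operand output op unchanged (the `beta * source` term adds nothing). That is what lets stream-k keep a single instruction stream whether or not the source is actually needed. A simplified sketch of the masking idea; `PredicatedIterator` here is a stand-in, not the CUTLASS `PredicatedTileIterator`:

#include <array>

using Fragment = std::array<float, 4>;

// Stand-in for a predicated global-memory iterator: clearing the mask makes
// every subsequent load a no-op.
struct PredicatedIterator {
  float const *ptr;
  bool mask = true;

  void clear_mask() { mask = false; }
  void load(Fragment &frag) const {
    if (!mask) {
      return;  // fragment keeps whatever it held (zeros, after clear())
    }
    for (size_t i = 0; i < frag.size(); ++i) {
      frag[i] = ptr[i];
    }
  }
};

// Two-operand output op; with the mask cleared, source is zero and the
// beta * source term contributes nothing.
float output_op(float accum, float source, float alpha, float beta) {
  return alpha * accum + beta * source;
}

int main() {
  float gmem[4] = {5, 5, 5, 5};
  PredicatedIterator source_iterator{gmem};

  bool source_needed = false;  // e.g. output_op.is_source_needed()
  if (!source_needed) {
    source_iterator.clear_mask();  // same codepath runs either way
  }

  Fragment source{};  // zero-cleared, as in SourceAspectNeeded's constructor
  source_iterator.load(source);
  float out = output_op(/*accum=*/2.0f, source[0], /*alpha=*/2.0f, /*beta=*/1.0f);
  return out == 4.0f ? 0 : 1;  // 2*2 + 1*0
}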

View File

@@ -551,7 +551,7 @@ public:
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size)
{
CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");
CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()");
static int const kAlignmentA = (platform::is_same<LayoutA,
layout::ColumnMajorInterleaved<32>>::value)
@@ -851,7 +851,7 @@ protected:
threadblock_item_begin);
// Execute the epilogue operator to update the destination tensor.
epilogue(
epilogue.unified(
EpilogueOutputOp(params.output_op),
iterator_D,
accumulator_tile,
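
The net effect at the kernel level: everything except stream-k keeps the dispatching `operator()` shown in the epilogue headers, while the stream-k kernel pins the unified codepath. Schematically (the trailing source-iterator argument is truncated in the diff above, so `iterator_C` is an assumed name):

// Classic universal kernels: runtime dispatch on output_op.is_source_needed().
epilogue(EpilogueOutputOp(params.output_op), iterator_D, accumulator_tile, iterator_C);

// Stream-k: one codepath for every participating threadblock.
epilogue.unified(EpilogueOutputOp(params.output_op), iterator_D, accumulator_tile, iterator_C);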