restore the old epilogue for everything except streamk (#749)
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
parent 5989b7e1d7
commit ff6e733fe1
@@ -173,8 +173,8 @@ public:
 
   static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
 
-public:
 
+public:
 
   static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
     "Mismatch between shared load iterator and output tile iterator.");
@@ -186,6 +186,102 @@ public:
 
   static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");
 
+public:
+
+  /// Aspect for when epilogue source is not needed
+  struct SourceAspectNotNeeded
+  {
+    /// Constructor
+    CUTLASS_DEVICE
+    SourceAspectNotNeeded()
+    {}
+
+    /// Invoke the output functor over each vector of output
+    CUTLASS_DEVICE
+    void apply_output_operator(
+      typename OutputTileIterator::Fragment &output_fragment,
+      OutputOp const &output_op,
+      typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
+    {
+      OutputAccessType *output_frag_ptr =
+        reinterpret_cast<OutputAccessType *>(&output_fragment);
+
+      AccumulatorAccessType const *compute_frag_ptr =
+        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
+
+      int const kOutputOpIterations =
+        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kOutputOpIterations; ++i)
+      {
+        // Call the output operator
+        output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
+      }
+    }
+  };
+
+
+  /// Aspect for when epilogue source is needed
+  struct SourceAspectNeeded
+  {
+    OutputTileIterator source_iterator;
+
+    typename OutputTileIterator::Fragment source_fragment;
+
+    /// Invoke the output functor over each vector of output
+    CUTLASS_DEVICE
+    static void apply_output_operator(
+      typename OutputTileIterator::Fragment &output_fragment,
+      OutputOp const &output_op,
+      typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
+      typename OutputTileIterator::Fragment const &source_fragment)
+    {
+      OutputAccessType *output_frag_ptr =
+        reinterpret_cast<OutputAccessType *>(&output_fragment);
+
+      AccumulatorAccessType const *compute_frag_ptr =
+        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
+
+      OutputAccessType const *source_frag_ptr =
+        reinterpret_cast<OutputAccessType const *>(&source_fragment);
+
+      int const kOutputOpIterations =
+        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kOutputOpIterations; ++i)
+      {
+        // Call the output operator
+        output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
+      }
+    }
+
+    /// Constructor
+    CUTLASS_DEVICE
+    SourceAspectNeeded(OutputTileIterator source_iterator) :
+      source_iterator(source_iterator)
+    {
+      source_fragment.clear();
+    }
+
+    /// Invoke the output functor over each vector of output
+    CUTLASS_DEVICE
+    void apply_output_operator(
+      typename OutputTileIterator::Fragment &output_fragment,
+      OutputOp const &output_op,
+      typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
+    {
+      // Load addend source fragment from global memory
+      source_iterator.load(source_fragment);
+      ++source_iterator;
+
+      apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
+    }
+  };
+
+
 private:
 
   /// Loads fragment from shared memory aligned with output tensor
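The aspect structs above fold the addend load into the source-needed path, so the epilogue body itself no longer branches on whether a source tensor exists. A minimal standalone sketch of the idiom, in plain host C++ with std::array standing in for the CUTLASS fragment types (all names below are hypothetical, not the library's):

    #include <array>
    #include <cstddef>
    #include <cstdio>

    // Hypothetical stand-ins for the CUTLASS fragment and output-op types.
    using Fragment = std::array<float, 4>;

    struct OutputOp {
      float alpha, beta;
      float operator()(float accum) const               { return alpha * accum; }
      float operator()(float accum, float source) const { return alpha * accum + beta * source; }
    };

    // Aspect used when no addend (source) tensor is read.
    struct SourceNotNeeded {
      void apply(Fragment &out, OutputOp const &op, Fragment const &accum) const {
        for (std::size_t i = 0; i < out.size(); ++i)
          out[i] = op(accum[i]);
      }
    };

    // Aspect that owns the source fragment; the real epilogue loads it from
    // global memory before applying the operator.
    struct SourceNeeded {
      Fragment source{};
      void apply(Fragment &out, OutputOp const &op, Fragment const &accum) const {
        for (std::size_t i = 0; i < out.size(); ++i)
          out[i] = op(accum[i], source[i]);
      }
    };

    // One epilogue body, parameterized by the aspect: compile-time dispatch,
    // no per-element branching on whether a source exists.
    template <typename SourceAspect>
    void epilogue(Fragment &out, OutputOp const &op, Fragment const &accum, SourceAspect src) {
      src.apply(out, op, accum);
    }

    int main() {
      Fragment accum{1, 2, 3, 4}, out{};
      epilogue(out, OutputOp{2.0f, 0.0f}, accum, SourceNotNeeded{});
      std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 2 4 6 8
      return 0;
    }

Because the aspect is a template parameter of the shared body, the no-source instantiation compiles to code with no source loads at all rather than a runtime-predicated load.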
@@ -268,7 +364,11 @@ public:
     typename OutputTileIterator::Fragment output_fragment;
 
     // Apply the output operator
-    apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
+    SourceAspectNeeded::apply_output_operator(
+      output_fragment,
+      output_op,
+      aligned_accum_fragment,
+      source_fragment);
 
     // Store the final result
     destination_iterator += reduce_fragment_idx;
@@ -276,13 +376,45 @@ public:
   }
 
 
-  /// Streams the result to global memory
+  /// Perform the epilogue computations and stream the result to global memory.
+  CUTLASS_DEVICE
+  void operator()(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators)          ///< Complete warp-level accumulator tile
+  {
+    operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
+  }
+
+
+  /// Perform the epilogue computations and stream the result to global memory. Implements
+  /// two alternative codepaths, depending on whether the output op requires addend data to be loaded.
   CUTLASS_DEVICE
   void operator()(
     OutputOp const &output_op,                    ///< Output operator
     OutputTileIterator destination_iterator,      ///< Tile iterator for destination
     AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
-    OutputTileIterator source_iterator )          ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
+    OutputTileIterator source_iterator )          ///< Tile iterator for addend source
+  {
+    if (output_op.is_source_needed())
+    {
+      operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
+    }
+    else
+    {
+      operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
+    }
+  }
+
+
+  /// Perform the epilogue computations and stream the result to global memory. Implements a
+  /// single codepath, regardless of whether the output op requires addend data to be loaded
+  CUTLASS_DEVICE
+  void unified(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
+    OutputTileIterator source_iterator )          ///< Tile iterator for addend source
   {
     if (!output_op.is_source_needed())
     {
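The four-argument operator() above turns one runtime predicate into a choice of statically-typed aspect, so each instantiation of the shared body is branch-free. In CUTLASS, is_source_needed() on a linear-combination output op reports (roughly) whether beta is nonzero; a simplified sketch of that contract, using a hypothetical stand-in type rather than the library's LinearCombination:

    #include <cstdio>

    // Simplified stand-in for a CUTLASS linear-combination output op:
    // D = alpha * accumulator + beta * C. Loading C can be skipped when
    // beta == 0, which is roughly what is_source_needed() reports.
    struct LinearCombinationSketch {
      float alpha, beta;
      bool is_source_needed() const { return beta != 0.0f; }
    };

    int main() {
      LinearCombinationSketch scale_only{1.0f, 0.0f}, axpby{1.0f, 1.0f};
      std::printf("beta=0 needs source: %d, beta=1 needs source: %d\n",
                  scale_only.is_source_needed(), axpby.is_source_needed());
      return 0;
    }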
@@ -290,11 +422,19 @@ public:
       __syncthreads();  // Dummy (CUDA 11.0)
     }
 
-    // Source-fragment data (zero-initialized for scenarios where the
-    // output operator allows us to skip loading it from global input)
-    typename OutputTileIterator::Fragment source_fragment;
-    source_fragment.clear();
+    operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
+  }
+
+
+  /// Streams the result to global memory
+  template <typename SourceAspect>
+  CUTLASS_DEVICE
+  void operator()(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
+    SourceAspect source)
+  {
     // Iterator over warp-level accumulator fragment
     AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
 
@@ -341,13 +481,8 @@ public:
     CUTLASS_PRAGMA_UNROLL
     for (int p = 0; p < Base::kFragmentsPerIteration; ++p)
     {
-      // Load addend source fragment from global memory
-      source_iterator.load(source_fragment);
-      ++source_iterator;
-
-      typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
-
-      shared_load_iterator_.load(aligned_accum_fragment[0]);
+      typename SharedLoadIterator::Fragment aligned_accum_fragment;
+      shared_load_iterator_.load(aligned_accum_fragment);
 
       if (p < Base::kFragmentsPerIteration - 1)
       {
@@ -359,9 +494,10 @@ public:
 
       CUTLASS_PRAGMA_UNROLL
       for ( int i = 1; i < kPartitionsK; ++i) {
+        typename SharedLoadIterator::Fragment aligned_accum_fragment_addend;
         shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
-        shared_load_iterator_.load(aligned_accum_fragment[i]);
-        aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
+        shared_load_iterator_.load(aligned_accum_fragment_addend);
+        aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_accum_fragment_addend);
       }
 
       shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
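This hunk also reshapes the split-K reduction (kPartitionsK > 1): instead of materializing an array of kPartitionsK fragments, the code keeps one running fragment plus one addend reloaded per partition and folds them with add_fragments. A host-side sketch of the same running-sum pattern, with plain arrays as stand-in fragments and made-up values:

    #include <array>
    #include <cstddef>
    #include <cstdio>

    using Fragment = std::array<float, 4>;

    Fragment add_fragments(Fragment const &a, Fragment const &b) {
      Fragment c{};
      for (std::size_t i = 0; i < c.size(); ++i) c[i] = a[i] + b[i];
      return c;
    }

    int main() {
      constexpr int kPartitionsK = 3;
      // Partial accumulator tiles produced by each K-partition (made-up values).
      Fragment partials[kPartitionsK] = {{1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}};

      // Running-sum form: one live fragment plus one addend, instead of an
      // array of kPartitionsK fragments held live at once.
      Fragment accum = partials[0];
      for (int i = 1; i < kPartitionsK; ++i) {
        Fragment addend = partials[i];  // reloaded from shared memory in the real code
        accum = add_fragments(accum, addend);
      }
      std::printf("%g %g %g %g\n", accum[0], accum[1], accum[2], accum[3]);  // 6 6 6 6
      return 0;
    }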
@@ -372,7 +508,7 @@ public:
       //
 
       typename OutputTileIterator::Fragment output_fragment;
-      apply_output_operator(output_fragment, output_op, aligned_accum_fragment[0], source_fragment);
+      source.apply_output_operator(output_fragment, output_op, aligned_accum_fragment);
 
       //
       // Store the final result
@@ -388,37 +524,6 @@ public:
     }
   }
 
-private:
-
-  /// Helper to invoke the output functor over each vector of output
-  CUTLASS_DEVICE
-  void apply_output_operator(
-    typename OutputTileIterator::Fragment &output_fragment,
-    OutputOp const &output_op,                    ///< Output operator
-    typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
-    typename OutputTileIterator::Fragment const &source_fragment)
-  {
-
-    OutputAccessType *output_frag_ptr =
-      reinterpret_cast<OutputAccessType *>(&output_fragment);
-
-    AccumulatorAccessType const *compute_frag_ptr =
-      reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
-
-    OutputAccessType const *source_frag_ptr =
-      reinterpret_cast<OutputAccessType const *>(&source_fragment);
-
-    int const kOutputOpIterations =
-      OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
-
-    CUTLASS_PRAGMA_UNROLL
-    for (int i = 0; i < kOutputOpIterations; ++i)
-    {
-      // Call the output operator
-      output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
-    }
-  }
-
 };
 
 ////////////////////////////////////////////////////////////////////////////////
 
@@ -147,6 +147,101 @@ public:
       OutputTileIterator::kElementsPerAccess),
     "Divisibility");
 
+public:
+
+  /// Aspect for when epilogue source is not needed
+  struct SourceAspectNotNeeded
+  {
+    /// Constructor
+    CUTLASS_DEVICE
+    SourceAspectNotNeeded()
+    {}
+
+    /// Invoke the output functor over each vector of output
+    CUTLASS_DEVICE
+    void apply_output_operator(
+      typename OutputTileIterator::Fragment &output_fragment,
+      OutputOp const &output_op,
+      typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
+    {
+      OutputAccessType *output_frag_ptr =
+        reinterpret_cast<OutputAccessType *>(&output_fragment);
+
+      AccumulatorAccessType const *compute_frag_ptr =
+        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
+
+      int const kOutputOpIterations =
+        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kOutputOpIterations; ++i)
+      {
+        // Call the output operator
+        output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
+      }
+    }
+  };
+
+
+  /// Aspect for when epilogue source is needed
+  struct SourceAspectNeeded
+  {
+    OutputTileIterator source_iterator;
+
+    typename OutputTileIterator::Fragment source_fragment;
+
+    /// Invoke the output functor over each vector of output
+    CUTLASS_DEVICE
+    static void apply_output_operator(
+      typename OutputTileIterator::Fragment &output_fragment,
+      OutputOp const &output_op,
+      typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment,
+      typename OutputTileIterator::Fragment const &source_fragment)
+    {
+      OutputAccessType *output_frag_ptr =
+        reinterpret_cast<OutputAccessType *>(&output_fragment);
+
+      AccumulatorAccessType const *compute_frag_ptr =
+        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
+
+      OutputAccessType const *source_frag_ptr =
+        reinterpret_cast<OutputAccessType const *>(&source_fragment);
+
+      int const kOutputOpIterations =
+        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kOutputOpIterations; ++i)
+      {
+        // Call the output operator
+        output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
+      }
+    }
+
+    /// Constructor
+    CUTLASS_DEVICE
+    SourceAspectNeeded(OutputTileIterator source_iterator) :
+      source_iterator(source_iterator)
+    {
+      source_fragment.clear();
+    }
+
+    /// Invoke the output functor over each vector of output
+    CUTLASS_DEVICE
+    void apply_output_operator(
+      typename OutputTileIterator::Fragment &output_fragment,
+      OutputOp const &output_op,
+      typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
+    {
+      // Load addend source fragment from global memory
+      source_iterator.load(source_fragment);
+      ++source_iterator;
+
+      apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
+    }
+  };
+
+
   /// Shared storage allocation needed by the epilogue
   struct SharedStorage {};
 
@@ -196,7 +291,7 @@ public:
     typename OutputTileIterator::Fragment output_fragment;
 
     // Apply the output operator
-    apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);
+    SourceAspectNeeded::apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);
 
     // Store the final result
    destination_iterator += reduce_fragment_idx;
@@ -204,29 +299,65 @@ public:
   }
 
 
-  /// Streams the result to global memory
+  /// Perform the epilogue computations and stream the result to global memory.
   CUTLASS_DEVICE
   void operator()(
     OutputOp const &output_op,                    ///< Output operator
     OutputTileIterator destination_iterator,      ///< Tile iterator for destination
-    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
-    OutputTileIterator source_iterator) {         ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
-    if (!output_op.is_source_needed()) {
-      compute_source_not_needed_(output_op, destination_iterator, accumulators);
+    AccumulatorTile const &accumulators)          ///< Complete warp-level accumulator tile
+  {
+    operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
+  }
+
+
+  /// Perform the epilogue computations and stream the result to global memory. Implements
+  /// two alternative codepaths, depending on whether the output op requires addend data to be loaded.
+  CUTLASS_DEVICE
+  void operator()(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
+    OutputTileIterator source_iterator )          ///< Tile iterator for addend source
+  {
+    if (output_op.is_source_needed())
+    {
+      operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
     }
-    else {
-      compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
+    else
+    {
+      operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
     }
   }
 
-  /// Streams the result to global memory
-  CUTLASS_DEVICE
-  void compute_source_not_needed_(
-    OutputOp const &output_op,                    ///< Output operator
-    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
-    AccumulatorTile const &accumulators           ///< Complete warp-level accumulator tile
-  ) {
+
+  /// Perform the epilogue computations and stream the result to global memory. Implements a
+  /// single codepath, regardless of whether the output op requires addend data to be loaded
+  CUTLASS_DEVICE
+  void unified(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
+    OutputTileIterator source_iterator )          ///< Tile iterator for addend source
+  {
+    if (!output_op.is_source_needed())
+    {
+      source_iterator.clear_mask();
+      __syncthreads();  // Dummy (CUDA 11.0)
+    }
+
+    operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
+  }
+
+
+  /// Streams the result to global memory
+  template <typename SourceAspect>
+  CUTLASS_DEVICE
+  void operator()(
+    OutputOp const &output_op,                    ///< Output operator
+    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
+    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
+    SourceAspect source)
+  {
     //
     // Iterator over warp-level accumulator fragment
     //
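unified() above deliberately keeps one codepath: when the output op does not need the source, it clears the source iterator's predicate mask, so the loads are predicated off and the zero-initialized source fragment feeds the output operator instead. A simplified host-side sketch of that masking idea (the types here are stand-ins, not the CUTLASS iterators):

    #include <array>
    #include <cstddef>
    #include <cstdio>

    // Sketch of the unified() idea: always run the source-needed path, but
    // mask off the loads when the source is unused, so every thread block
    // executes the same instruction stream (useful for stream-K, where
    // blocks cooperate on a tile).
    using Fragment = std::array<float, 4>;

    struct SourceIterator {
      bool masked = false;
      void clear_mask() { masked = true; }
      void load(Fragment &frag) const {
        if (masked) return;       // predicated-off loads leave frag untouched
        frag = {10, 20, 30, 40};  // pretend global-memory read
      }
    };

    int main() {
      float alpha = 1.0f, beta = 0.0f;  // beta == 0: source not needed
      SourceIterator it;
      if (beta == 0.0f) it.clear_mask();

      Fragment source{};                // cleared; stays zero when masked
      it.load(source);

      Fragment accum{1, 2, 3, 4}, out{};
      for (std::size_t i = 0; i < out.size(); ++i)
        out[i] = alpha * accum[i] + beta * source[i];
      std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 1 2 3 4
      return 0;
    }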
@@ -254,7 +385,7 @@ public:
     //
 
     typename OutputTileIterator::Fragment output_fragment;
-    apply_output_operator_source_not_needed(output_fragment, output_op, accum_fragment);
+    source.apply_output_operator(output_fragment, output_op, accum_fragment);
 
     //
     // Store the final result
@@ -265,123 +396,6 @@ public:
       ++destination_iterator;
     }
   }
-
-  /// Streams the result to global memory
-  CUTLASS_DEVICE
-  void compute_source_needed_(
-    OutputOp const &output_op,                    ///< Output operator
-    OutputTileIterator destination_iterator,      ///< Tile iterator for destination
-    AccumulatorTile const &accumulators,          ///< Complete warp-level accumulator tile
-    OutputTileIterator source_iterator            ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
-  ) {
-
-    //
-    // Predicated tile iterators constructed from members
-    //
-
-    typename OutputTileIterator::Fragment source_fragment;
-
-    source_fragment.clear();
-
-    //
-    // Iterator over warp-level accumulator fragment
-    //
-
-    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
-
-    //
-    // Iterate over accumulator tile
-    //
-
-    CUTLASS_PRAGMA_UNROLL
-    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
-      //
-      // Load the source
-      //
-
-      source_iterator.set_iteration_index(iter);
-      source_iterator.load(source_fragment);
-      ++source_iterator;
-
-      //
-      // Convert fragment
-      //
-
-      typename AccumulatorFragmentIterator::Fragment accum_fragment;
-
-      accum_fragment_iterator.load(accum_fragment);
-      ++accum_fragment_iterator;
-
-      //
-      // Compute the output result
-      //
-
-      typename OutputTileIterator::Fragment output_fragment;
-      apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);
-
-      //
-      // Store the final result
-      //
-
-      destination_iterator.set_iteration_index(iter);
-      destination_iterator.store(output_fragment);
-      ++destination_iterator;
-    }
-  }
-
-protected:
-
-  /// Helper to invoke the output functor over each vector of output
-  CUTLASS_DEVICE
-  void apply_output_operator(
-    typename OutputTileIterator::Fragment &output_fragment,
-    OutputOp const &output_op,
-    typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment,
-    typename OutputTileIterator::Fragment const &source_fragment)
-  {
-    OutputAccessType *output_frag_ptr =
-      reinterpret_cast<OutputAccessType *>(&output_fragment);
-
-    AccumulatorAccessType const *compute_frag_ptr =
-      reinterpret_cast<AccumulatorAccessType const *>(
-        &aligned_accum_fragment);
-
-    OutputAccessType const *source_frag_ptr =
-      reinterpret_cast<OutputAccessType const *>(&source_fragment);
-
-    int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
-                                    OutputTileIterator::kElementsPerAccess;
-
-    CUTLASS_PRAGMA_UNROLL
-    for (int i = 0; i < kOutputOpIterations; ++i) {
-      // Call the output operator
-      output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
-    }
-  }
-
-  /// Helper to invoke the output functor over each vector of output
-  CUTLASS_DEVICE
-  void apply_output_operator_source_not_needed(
-    typename OutputTileIterator::Fragment &output_fragment,
-    OutputOp const &output_op,
-    typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
-  {
-    OutputAccessType *output_frag_ptr =
-      reinterpret_cast<OutputAccessType *>(&output_fragment);
-
-    AccumulatorAccessType const *compute_frag_ptr =
-      reinterpret_cast<AccumulatorAccessType const *>(
-        &aligned_accum_fragment);
-
-    int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
-                                    OutputTileIterator::kElementsPerAccess;
-
-    CUTLASS_PRAGMA_UNROLL
-    for (int i = 0; i < kOutputOpIterations; ++i) {
-      // Call the output operator
-      output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
-    }
-  }
-
 };
 
 ////////////////////////////////////////////////////////////////////////////////
 
@@ -551,7 +551,7 @@ public:
   static Status can_implement(
     cutlass::gemm::GemmCoord const & problem_size)
   {
-    CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");
+    CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()");
 
     static int const kAlignmentA = (platform::is_same<LayoutA,
                                                       layout::ColumnMajorInterleaved<32>>::value)
@@ -851,7 +851,7 @@ protected:
       threadblock_item_begin);
 
     // Execute the epilogue operator to update the destination tensor.
-    epilogue(
+    epilogue.unified(
       EpilogueOutputOp(params.output_op),
       iterator_D,
       accumulator_tile,
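Per the commit title, only the stream-K kernel switches to epilogue.unified(); every other kernel keeps the plain epilogue(...) call, which restores the old dual-codepath dispatch. A compilable toy sketch of the two entry points (the printf bodies merely label which path runs; the real epilogue streams fragments to global memory):

    #include <cstdio>

    // Minimal sketch of the two entry points the kernels use after this
    // change (hypothetical simplified epilogue, not the CUTLASS class).
    struct OutputOp {
      float beta;
      bool is_source_needed() const { return beta != 0.0f; }
    };

    struct Epilogue {
      // Dual codepath: pick the aspect per the output op (non-stream-K kernels).
      void operator()(OutputOp const &op) const {
        std::printf(op.is_source_needed() ? "source path\n" : "no-source path\n");
      }
      // Single codepath: always run the source path, with masked loads (stream-K).
      void unified(OutputOp const &op) const {
        bool mask_off_loads = !op.is_source_needed();
        std::printf("unified path (loads %s)\n", mask_off_loads ? "masked" : "live");
      }
    };

    int main() {
      Epilogue epilogue;
      epilogue(OutputOp{0.0f});          // regular kernel, beta == 0
      epilogue.unified(OutputOp{0.0f});  // stream-K kernel, beta == 0
      return 0;
    }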