restore the old epilogue for everything except streamk (#749)
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent 5989b7e1d7
commit ff6e733fe1
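For context: the diff below restores the pre-stream-K epilogue structure for the regular GEMM kernels. The old runtime `compute_source_needed_` / `compute_source_not_needed_` split returns as a pair of small policy structs, `SourceAspectNeeded` and `SourceAspectNotNeeded`, passed into a single templated epilogue body, while the stream-K kernel keeps a separate single-codepath `unified()` entry point. The following is a minimal, compilable host-side sketch of that aspect pattern; all types here are illustrative stand-ins, not the real CUTLASS types.

```cpp
#include <cstdio>

struct Fragment { float data[4]; };

// Stand-in output op supporting both call shapes, loosely modeled on
// cutlass::epilogue::thread::LinearCombination.
struct LinearCombination {
  float alpha = 1.0f, beta = 0.5f;
  float operator()(float accum) const               { return alpha * accum; }
  float operator()(float accum, float source) const { return alpha * accum + beta * source; }
};

// Aspect for when the source tensor is ignored (beta == 0).
struct SourceAspectNotNeeded {
  void apply_output_operator(Fragment &out, LinearCombination const &op, Fragment const &acc) const {
    for (int i = 0; i < 4; ++i) out.data[i] = op(acc.data[i]);
  }
};

// Aspect for when the source tensor is read (the real struct also owns the
// source iterator and loads a source fragment per iteration).
struct SourceAspectNeeded {
  Fragment source{};
  void apply_output_operator(Fragment &out, LinearCombination const &op, Fragment const &acc) const {
    for (int i = 0; i < 4; ++i) out.data[i] = op(acc.data[i], source.data[i]);
  }
};

// One templated epilogue body serves both cases; the aspect injects the
// per-vector behavior, so the hot loop is written only once.
template <typename SourceAspect>
void run_epilogue(Fragment &out, LinearCombination const &op, Fragment const &acc, SourceAspect aspect) {
  aspect.apply_output_operator(out, op, acc);
}

int main() {
  Fragment acc{{1, 2, 3, 4}}, out{};
  run_epilogue(out, LinearCombination{}, acc, SourceAspectNotNeeded{});
  std::printf("%f\n", out.data[2]); // 3.0
}
```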
@@ -173,8 +173,8 @@ public:
  static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles;
public:
public:
  static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
    "Mismatch between shared load iterator and output tile iterator.");
@@ -186,6 +186,102 @@ public:
  static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1.");

public:

  /// Aspect for when epilogue source is not needed
  struct SourceAspectNotNeeded
  {
    /// Constructor
    CUTLASS_DEVICE
    SourceAspectNotNeeded()
    {}

    /// Invoke the output functor over each vector of output
    CUTLASS_DEVICE
    void apply_output_operator(
      typename OutputTileIterator::Fragment &output_fragment,
      OutputOp const &output_op,
      typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
    {
      OutputAccessType *output_frag_ptr =
        reinterpret_cast<OutputAccessType *>(&output_fragment);

      AccumulatorAccessType const *compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);

      int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kOutputOpIterations; ++i)
      {
        // Call the output operator
        output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
      }
    }
  };

  /// Aspect for when epilogue source is needed
  struct SourceAspectNeeded
  {
    OutputTileIterator source_iterator;

    typename OutputTileIterator::Fragment source_fragment;

    /// Invoke the output functor over each vector of output
    CUTLASS_DEVICE
    static void apply_output_operator(
      typename OutputTileIterator::Fragment &output_fragment,
      OutputOp const &output_op,
      typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
      typename OutputTileIterator::Fragment const &source_fragment)
    {
      OutputAccessType *output_frag_ptr =
        reinterpret_cast<OutputAccessType *>(&output_fragment);

      AccumulatorAccessType const *compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);

      OutputAccessType const *source_frag_ptr =
        reinterpret_cast<OutputAccessType const *>(&source_fragment);

      int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kOutputOpIterations; ++i)
      {
        // Call the output operator
        output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
      }
    }

    /// Constructor
    CUTLASS_DEVICE
    SourceAspectNeeded(OutputTileIterator source_iterator) :
      source_iterator(source_iterator)
    {
      source_fragment.clear();
    }

    /// Invoke the output functor over each vector of output
    CUTLASS_DEVICE
    void apply_output_operator(
      typename OutputTileIterator::Fragment &output_fragment,
      OutputOp const &output_op,
      typename SharedLoadIterator::Fragment const &aligned_accum_fragment)
    {
      // Load addend source fragment from global memory
      source_iterator.load(source_fragment);
      ++source_iterator;

      apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
    }
  };
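In both aspects, `apply_output_operator` reinterprets the element-wise fragments as arrays of access-sized vectors (`OutputAccessType`, `AccumulatorAccessType`) so the output functor is invoked once per vector rather than once per element. Below is a simplified, compilable illustration of that `reinterpret_cast` pattern; the element counts are assumptions for the sketch, not the real tile parameters.

```cpp
#include <cstdio>

constexpr int kElements = 8;          // stand-in for Fragment::kElements
constexpr int kElementsPerAccess = 4; // stand-in for OutputTileIterator::kElementsPerAccess

struct Fragment   { float storage[kElements]; };
struct AccessType { float storage[kElementsPerAccess]; }; // one vectorized access

void apply(Fragment &output, Fragment const &accum) {
  // View the fragments as arrays of vector-sized accesses, as the epilogue does.
  AccessType *out_ptr = reinterpret_cast<AccessType *>(&output);
  AccessType const *acc_ptr = reinterpret_cast<AccessType const *>(&accum);

  // kOutputOpIterations in the real code: elements / elements-per-access.
  constexpr int kOutputOpIterations = kElements / kElementsPerAccess;
  for (int i = 0; i < kOutputOpIterations; ++i) {
    for (int j = 0; j < kElementsPerAccess; ++j) {
      out_ptr[i].storage[j] = 2.0f * acc_ptr[i].storage[j]; // stand-in output_op
    }
  }
}

int main() {
  Fragment acc{{0, 1, 2, 3, 4, 5, 6, 7}}, out{};
  apply(out, acc);
  std::printf("%f\n", out.storage[7]); // 14.0
}
```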

private:

  /// Loads fragment from shared memory aligned with output tensor

@@ -268,7 +364,11 @@ public:
  typename OutputTileIterator::Fragment output_fragment;

  // Apply the output operator
  apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
  SourceAspectNeeded::apply_output_operator(
    output_fragment,
    output_op,
    aligned_accum_fragment,
    source_fragment);

  // Store the final result
  destination_iterator += reduce_fragment_idx;
@@ -276,13 +376,45 @@ public:
  }

  /// Streams the result to global memory
  /// Perform the epilogue computations and stream the result to global memory.
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op, ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile
  {
    operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
  }

  /// Perform the epilogue computations and stream the result to global memory. Implements
  /// two alternative codepaths, depending on whether the output op requires addend data to be loaded.
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op, ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator ) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
    OutputTileIterator source_iterator ) ///< Tile iterator for addend source
  {
    if (output_op.is_source_needed())
    {
      operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
    }
    else
    {
      operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
    }
  }
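The overload above converts a runtime property of the output op into a compile-time one: each branch instantiates the templated `operator()` with a different aspect, so the source-not-needed path contains no source loads at all. A stand-in sketch of that dispatch, with illustrative types only:

```cpp
#include <cstdio>

struct SourceAspectNotNeeded {};
struct SourceAspectNeeded { int source_iterator; };

struct OutputOp {
  float beta;
  bool is_source_needed() const { return beta != 0.0f; }
};

// Two separate instantiations of one templated body; only the "Needed"
// instantiation would compile in the source-tensor loads.
template <typename SourceAspect>
void epilogue_body(SourceAspect) { std::puts("one instantiation per aspect"); }

void epilogue(OutputOp const &op, int source_iterator) {
  if (op.is_source_needed()) {
    epilogue_body(SourceAspectNeeded{source_iterator}); // source loads compiled in
  } else {
    epilogue_body(SourceAspectNotNeeded{});             // source loads compiled out
  }
}

int main() { epilogue(OutputOp{0.0f}, 0); }
```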

  /// Perform the epilogue computations and stream the result to global memory. Implements a
  /// single codepath, regardless of whether the output op requires addend data to be loaded
  CUTLASS_DEVICE
  void unified(
    OutputOp const &output_op, ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator ) ///< Tile iterator for addend source
  {
    if (!output_op.is_source_needed())
    {
@@ -290,11 +422,19 @@ public:
      __syncthreads(); // Dummy (CUDA 11.0)
    }

    // Source-fragment data (zero-initialized for scenarios where the
    // output operator allows us to skip loading it from global input)
    typename OutputTileIterator::Fragment source_fragment;
    source_fragment.clear();
    operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
  }
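`unified()` deliberately avoids the dual instantiation: it always runs the `SourceAspectNeeded` body, and when the source is not needed it disables the source iterator's predicate mask so its global loads are predicated off while the zero-cleared `source_fragment` feeds the functor (part of this body is elided by the diff view between hunks; the direct-store variant later in this diff shows the `clear_mask()` call explicitly). A tiny sketch of predicate-mask semantics, under an assumed iterator shape:

```cpp
#include <cstdio>

struct PredicatedIterator {
  float const *ptr;
  bool mask = true;
  void clear_mask() { mask = false; }               // disable all accesses
  float load() const { return mask ? *ptr : 0.0f; } // guarded global load
};

int main() {
  float c = 7.0f;
  PredicatedIterator it{&c};
  it.clear_mask();
  std::printf("%f\n", it.load()); // 0.0: masked-off loads never touch memory
}
```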

  /// Streams the result to global memory
  template <typename SourceAspect>
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op, ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
    SourceAspect source)
  {
    // Iterator over warp-level accumulator fragment
    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

@@ -341,13 +481,8 @@ public:
    CUTLASS_PRAGMA_UNROLL
    for (int p = 0; p < Base::kFragmentsPerIteration; ++p)
    {
      // Load addend source fragment from global memory
      source_iterator.load(source_fragment);
      ++source_iterator;

      typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

      shared_load_iterator_.load(aligned_accum_fragment[0]);
      typename SharedLoadIterator::Fragment aligned_accum_fragment;
      shared_load_iterator_.load(aligned_accum_fragment);

      if (p < Base::kFragmentsPerIteration - 1)
      {
@@ -359,9 +494,10 @@ public:

      CUTLASS_PRAGMA_UNROLL
      for ( int i = 1; i < kPartitionsK; ++i) {
        typename SharedLoadIterator::Fragment aligned_accum_fragment_addend;
        shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
        shared_load_iterator_.load(aligned_accum_fragment[i]);
        aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
        shared_load_iterator_.load(aligned_accum_fragment_addend);
        aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_accum_fragment_addend);
      }

      shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
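For split-K (`kPartitionsK > 1`), each partition's partial accumulators occupy a separate shared-memory tile `kSmemPointerOffset` elements apart. The loop above walks those tiles, summing them into a single fragment with `add_fragments`, then rewinds the pointer by `(1 - kPartitionsK) * kSmemPointerOffset` so the next iteration starts at the first tile again. A host-side sketch with assumed sizes:

```cpp
#include <cstdio>

constexpr int kPartitionsK = 3;       // assumed split-K partition count
constexpr int kSmemPointerOffset = 4; // assumed elements between partition tiles

int main() {
  // Partial results of three K-partitions, laid out tile after tile in "smem".
  float smem[kPartitionsK * kSmemPointerOffset] =
      {1, 1, 1, 1,  2, 2, 2, 2,  3, 3, 3, 3};

  float const *load_ptr = smem;
  float frag[kSmemPointerOffset];
  for (int j = 0; j < kSmemPointerOffset; ++j) frag[j] = load_ptr[j];

  for (int i = 1; i < kPartitionsK; ++i) {
    load_ptr += kSmemPointerOffset;  // add_pointer_offset(kSmemPointerOffset)
    for (int j = 0; j < kSmemPointerOffset; ++j)
      frag[j] += load_ptr[j];        // add_fragments(...)
  }

  // Rewind so the next fragment iteration starts at the first tile again.
  load_ptr += (1 - kPartitionsK) * kSmemPointerOffset;

  std::printf("%f\n", frag[0]); // 6.0 = 1 + 2 + 3 summed across partitions
}
```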
@@ -372,7 +508,7 @@ public:
      //

      typename OutputTileIterator::Fragment output_fragment;
      apply_output_operator(output_fragment, output_op, aligned_accum_fragment[0], source_fragment);
      source.apply_output_operator(output_fragment, output_op, aligned_accum_fragment);

      //
      // Store the final result
@@ -388,37 +524,6 @@ public:
    }
  }

private:

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator(
    typename OutputTileIterator::Fragment &output_fragment,
    OutputOp const &output_op, ///< Output operator
    typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
    typename OutputTileIterator::Fragment const &source_fragment)
  {
    OutputAccessType *output_frag_ptr =
      reinterpret_cast<OutputAccessType *>(&output_fragment);

    AccumulatorAccessType const *compute_frag_ptr =
      reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);

    OutputAccessType const *source_frag_ptr =
      reinterpret_cast<OutputAccessType const *>(&source_fragment);

    int const kOutputOpIterations =
      OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i)
    {
      // Call the output operator
      output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
    }
  }

};

////////////////////////////////////////////////////////////////////////////////

@@ -147,6 +147,101 @@ public:
      OutputTileIterator::kElementsPerAccess),
    "Divisibility");

public:

  /// Aspect for when epilogue source is not needed
  struct SourceAspectNotNeeded
  {
    /// Constructor
    CUTLASS_DEVICE
    SourceAspectNotNeeded()
    {}

    /// Invoke the output functor over each vector of output
    CUTLASS_DEVICE
    void apply_output_operator(
      typename OutputTileIterator::Fragment &output_fragment,
      OutputOp const &output_op,
      typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
    {
      OutputAccessType *output_frag_ptr =
        reinterpret_cast<OutputAccessType *>(&output_fragment);

      AccumulatorAccessType const *compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);

      int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kOutputOpIterations; ++i)
      {
        // Call the output operator
        output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
      }
    }
  };

  /// Aspect for when epilogue source is needed
  struct SourceAspectNeeded
  {
    OutputTileIterator source_iterator;

    typename OutputTileIterator::Fragment source_fragment;

    /// Invoke the output functor over each vector of output
    CUTLASS_DEVICE
    static void apply_output_operator(
      typename OutputTileIterator::Fragment &output_fragment,
      OutputOp const &output_op,
      typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment,
      typename OutputTileIterator::Fragment const &source_fragment)
    {
      OutputAccessType *output_frag_ptr =
        reinterpret_cast<OutputAccessType *>(&output_fragment);

      AccumulatorAccessType const *compute_frag_ptr =
        reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);

      OutputAccessType const *source_frag_ptr =
        reinterpret_cast<OutputAccessType const *>(&source_fragment);

      int const kOutputOpIterations =
        OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kOutputOpIterations; ++i)
      {
        // Call the output operator
        output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
      }
    }

    /// Constructor
    CUTLASS_DEVICE
    SourceAspectNeeded(OutputTileIterator source_iterator) :
      source_iterator(source_iterator)
    {
      source_fragment.clear();
    }

    /// Invoke the output functor over each vector of output
    CUTLASS_DEVICE
    void apply_output_operator(
      typename OutputTileIterator::Fragment &output_fragment,
      OutputOp const &output_op,
      typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
    {
      // Load addend source fragment from global memory
      source_iterator.load(source_fragment);
      ++source_iterator;

      apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment);
    }
  };

  /// Shared storage allocation needed by the epilogue
  struct SharedStorage {};
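Unlike the first epilogue in this diff, this direct-store variant never stages accumulators through shared memory: its aspects consume `AccumulatorFragmentIterator::Fragment` straight from the register-held accumulator tile, so `SharedStorage` is empty. A minimal sketch of the contrast, with invented sizes:

```cpp
#include <cstdio>

// Staged epilogue: registers -> shared memory -> registers -> global (enables
// layout swizzling and split-K reduction, at the cost of a smem allocation).
struct StagedSharedStorage { float tile[1024]; };

// Direct-store epilogue: registers -> global. Nothing to allocate.
struct SharedStorage {};

int main() {
  // An empty struct still reports size 1 in C++, but the kernel stages nothing.
  std::printf("staged: %zu bytes, direct: %zu byte(s)\n",
              sizeof(StagedSharedStorage), sizeof(SharedStorage)); // 4096 vs 1
}
```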

@@ -196,7 +291,7 @@ public:
  typename OutputTileIterator::Fragment output_fragment;

  // Apply the output operator
  apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);
  SourceAspectNeeded::apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);

  // Store the final result
  destination_iterator += reduce_fragment_idx;
@@ -204,29 +299,65 @@ public:
  }

  /// Streams the result to global memory
  /// Perform the epilogue computations and stream the result to global memory.
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op, ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
    if (!output_op.is_source_needed()) {
      compute_source_not_needed_(output_op, destination_iterator, accumulators);
    OutputOp const &output_op, ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators) ///< Complete warp-level accumulator tile
  {
    operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
  }

  /// Perform the epilogue computations and stream the result to global memory. Implements
  /// two alternative codepaths, depending on whether the output op requires addend data to be loaded.
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op, ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator ) ///< Tile iterator for addend source
  {
    if (output_op.is_source_needed())
    {
      operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
    }
    else {
      compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
    else
    {
      operator()(output_op, destination_iterator, accumulators, SourceAspectNotNeeded());
    }
  }

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_not_needed_(
    OutputOp const &output_op, ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators ///< Complete warp-level accumulator tile
  ) {

  /// Perform the epilogue computations and stream the result to global memory. Implements a
  /// single codepath, regardless of whether the output op requires addend data to be loaded
  CUTLASS_DEVICE
  void unified(
    OutputOp const &output_op, ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator ) ///< Tile iterator for addend source
  {
    if (!output_op.is_source_needed())
    {
      source_iterator.clear_mask();
      __syncthreads(); // Dummy (CUDA 11.0)
    }

    operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
  }

  /// Streams the result to global memory
  template <typename SourceAspect>
  CUTLASS_DEVICE
  void operator()(
    OutputOp const &output_op, ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
    SourceAspect source)
  {
    //
    // Iterator over warp-level accumulator fragment
    //
@@ -254,7 +385,7 @@ public:
    //

    typename OutputTileIterator::Fragment output_fragment;
    apply_output_operator_source_not_needed(output_fragment, output_op, accum_fragment);
    source.apply_output_operator(output_fragment, output_op, accum_fragment);

    //
    // Store the final result
@@ -265,123 +396,6 @@ public:
      ++destination_iterator;
    }
  }

  /// Streams the result to global memory
  CUTLASS_DEVICE
  void compute_source_needed_(
    OutputOp const &output_op, ///< Output operator
    OutputTileIterator destination_iterator, ///< Tile iterator for destination
    AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
    OutputTileIterator source_iterator ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
  ) {

    //
    // Predicated tile iterators constructed from members
    //

    typename OutputTileIterator::Fragment source_fragment;

    source_fragment.clear();

    //
    // Iterator over warp-level accumulator fragment
    //

    AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

    //
    // Iterate over accumulator tile
    //

    CUTLASS_PRAGMA_UNROLL
    for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
      //
      // Load the source
      //

      source_iterator.set_iteration_index(iter);
      source_iterator.load(source_fragment);
      ++source_iterator;

      //
      // Convert fragment
      //

      typename AccumulatorFragmentIterator::Fragment accum_fragment;

      accum_fragment_iterator.load(accum_fragment);
      ++accum_fragment_iterator;

      //
      // Compute the output result
      //

      typename OutputTileIterator::Fragment output_fragment;
      apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment);

      //
      // Store the final result
      //

      destination_iterator.set_iteration_index(iter);
      destination_iterator.store(output_fragment);
      ++destination_iterator;
    }
  }

protected:

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator(
    typename OutputTileIterator::Fragment &output_fragment,
    OutputOp const &output_op,
    typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment,
    typename OutputTileIterator::Fragment const &source_fragment)
  {
    OutputAccessType *output_frag_ptr =
      reinterpret_cast<OutputAccessType *>(&output_fragment);

    AccumulatorAccessType const *compute_frag_ptr =
      reinterpret_cast<AccumulatorAccessType const *>(
        &aligned_accum_fragment);

    OutputAccessType const *source_frag_ptr =
      reinterpret_cast<OutputAccessType const *>(&source_fragment);

    int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
      OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      // Call the output operator
      output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
    }
  }

  /// Helper to invoke the output functor over each vector of output
  CUTLASS_DEVICE
  void apply_output_operator_source_not_needed(
    typename OutputTileIterator::Fragment &output_fragment,
    OutputOp const &output_op,
    typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment)
  {
    OutputAccessType *output_frag_ptr =
      reinterpret_cast<OutputAccessType *>(&output_fragment);

    AccumulatorAccessType const *compute_frag_ptr =
      reinterpret_cast<AccumulatorAccessType const *>(
        &aligned_accum_fragment);

    int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
      OutputTileIterator::kElementsPerAccess;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kOutputOpIterations; ++i) {
      // Call the output operator
      output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////

@@ -551,7 +551,7 @@ public:
  static Status can_implement(
    cutlass::gemm::GemmCoord const & problem_size)
  {
    CUTLASS_TRACE_HOST("GemmUniversal::can_implement()");
    CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()");

    static int const kAlignmentA = (platform::is_same<LayoutA,
      layout::ColumnMajorInterleaved<32>>::value)
@@ -851,7 +851,7 @@ protected:
      threadblock_item_begin);

    // Execute the epilogue operator to update the destination tensor.
    epilogue(
    epilogue.unified(
      EpilogueOutputOp(params.output_op),
      iterator_D,
      accumulator_tile,
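The stream-K kernel now routes through the single-codepath `unified()` entry point rather than the dual-codepath `operator()`. The final argument of the call is truncated in this view; the sketch below is a hedged reconstruction in which `iterator_C` is an assumed name for the addend-source iterator, not something this diff confirms.

```cpp
// Hypothetical completion of the truncated call above; `iterator_C` is an
// illustrative stand-in for whatever source iterator the kernel passes.
epilogue.unified(
  EpilogueOutputOp(params.output_op),
  iterator_D,        // destination tile iterator
  accumulator_tile,  // accumulators for this threadblock tile
  iterator_C);       // addend source (the `unified` overload requires it)
```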