Add acc2smem in epilogue/threadblock/epilogue.h (#806)

This commit is contained in:
Jack Kosaian 2023-02-06 22:04:16 -05:00 committed by GitHub
parent 5921043981
commit 5ff5209ed5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -425,6 +425,47 @@ public:
operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
}
/// Helper that copies accumulator fragments through a warp tile iterator,
/// turning a run-time position into a compile-time advance count.
///
/// Only the specialization on cutlass::index_sequence is defined; the primary
/// template is declared so that the sequence can be pattern-matched.
template<class Seq>
struct acc2smem;

template <size_t... Seq>
struct acc2smem<cutlass::index_sequence<Seq...>> {

  /// Advances a copy of the accumulator fragment iterator `Advance` times, then
  /// loads Base::kFragmentsPerIteration fragments from it and stores each one
  /// through `warp_tile_iterator`.
  ///
  /// `accum_fragment_iterator` is taken by value, so the caller's iterator is
  /// left untouched. `warp_tile_iterator` is taken by reference, but the pointer
  /// offsets applied between stores are undone before returning (net offset 0).
  ///
  /// NOTE(review): `Advance` counts single fragment increments, while each call
  /// consumes Base::kFragmentsPerIteration fragments. Confirm that callers pass
  /// a suitable position when kFragmentsPerIteration > 1.
  template<int Advance>
  CUTLASS_DEVICE
  static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                     WarpTileIterator &warp_tile_iterator) {

    // Skip ahead to the first fragment of interest. The trip count is a
    // template constant, so the loop can be fully unrolled.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < Advance; i++) {
      ++accum_fragment_iterator;
    }

    CUTLASS_PRAGMA_UNROLL
    for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
      typename AccumulatorFragmentIterator::Fragment accum_fragment;

      accum_fragment_iterator.load(accum_fragment);
      ++accum_fragment_iterator;

      warp_tile_iterator.store(accum_fragment);

      // Move to the next slot for every fragment except the last one.
      if (p < Base::kFragmentsPerIteration - 1) {
        warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
      }
    }

    // Rewind the (kFragmentsPerIteration - 1) offsets applied above so the
    // iterator leaves this function where it entered.
    if (Base::kFragmentsPerIteration > 1) {
      warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
                                            (1 - Base::kFragmentsPerIteration));
    }
  }

  /// Invokes helper<S> for the element S of `Seq...` that equals `pos`.
  /// If `pos` matches no element, nothing is executed.
  ///
  /// \param pos                 run-time position to dispatch on
  /// \param iterator_begin      accumulator fragment iterator; copied by helper
  /// \param warp_tile_iterator  destination iterator the fragments are stored through
  CUTLASS_DEVICE
  static void push(size_t pos,
                   AccumulatorFragmentIterator const &iterator_begin,
                   WarpTileIterator &warp_tile_iterator) {
    // Pack-expansion dispatch: `&&` short-circuits, so helper<Seq> only runs
    // for the matching element. The leading 0 keeps the array non-empty even
    // when the sequence is empty.
    int dummy[] = {
        0,
        ((pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0))...};

    // The array exists only to host the expansion; discard it explicitly to
    // avoid unused-variable diagnostics.
    static_cast<void>(dummy);
  }
};
/// Streams the result to global memory
template <typename SourceAspect>
@ -452,25 +493,8 @@ public:
__syncthreads();
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p)
{
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
this->warp_tile_iterator_.store(accum_fragment);
if (p < Base::kFragmentsPerIteration - 1) {
this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset);
}
}
if (Base::kFragmentsPerIteration > 1) {
this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
}
acc2smem<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
iter, accum_fragment_iterator, this->warp_tile_iterator_);
//
// Load fragments from shared memory