Add acc2smem in epilogue/threadblock/epilogue.h (#806)
This commit is contained in:
parent
5921043981
commit
5ff5209ed5
@ -425,6 +425,47 @@ public:
|
||||
operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
|
||||
}
|
||||
|
||||
template<class Seq>
|
||||
struct acc2smem;
|
||||
|
||||
template <size_t... Seq>
|
||||
struct acc2smem<cutlass::index_sequence<Seq...>> {
|
||||
template<int Advance>
|
||||
CUTLASS_DEVICE
|
||||
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
|
||||
WarpTileIterator &warp_tile_iterator) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Advance; i++) {
|
||||
++accum_fragment_iterator;
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
|
||||
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
||||
|
||||
accum_fragment_iterator.load(accum_fragment);
|
||||
++accum_fragment_iterator;
|
||||
|
||||
warp_tile_iterator.store(accum_fragment);
|
||||
if (p < Base::kFragmentsPerIteration - 1) {
|
||||
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
|
||||
}
|
||||
}
|
||||
|
||||
if (Base::kFragmentsPerIteration > 1) {
|
||||
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
|
||||
(1 - Base::kFragmentsPerIteration));
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void push(size_t pos,
|
||||
AccumulatorFragmentIterator const &iterator_begin,
|
||||
WarpTileIterator &warp_tile_iterator) {
|
||||
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Streams the result to global memory
|
||||
template <typename SourceAspect>
|
||||
@ -452,25 +493,8 @@ public:
|
||||
|
||||
__syncthreads();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int p = 0; p < Base::kFragmentsPerIteration; ++p)
|
||||
{
|
||||
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
||||
|
||||
accum_fragment_iterator.load(accum_fragment);
|
||||
++accum_fragment_iterator;
|
||||
|
||||
this->warp_tile_iterator_.store(accum_fragment);
|
||||
|
||||
if (p < Base::kFragmentsPerIteration - 1) {
|
||||
this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset);
|
||||
}
|
||||
}
|
||||
|
||||
if (Base::kFragmentsPerIteration > 1) {
|
||||
this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
|
||||
}
|
||||
|
||||
acc2smem<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
|
||||
iter, accum_fragment_iterator, this->warp_tile_iterator_);
|
||||
|
||||
//
|
||||
// Load fragments from shared memory
|
||||
|
Loading…
Reference in New Issue
Block a user