Add acc2smem in epilogue/threadblock/epilogue.h (#806)
This commit is contained in:
parent
5921043981
commit
5ff5209ed5
@ -425,6 +425,47 @@ public:
|
|||||||
operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
|
operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<class Seq>
|
||||||
|
struct acc2smem;
|
||||||
|
|
||||||
|
template <size_t... Seq>
|
||||||
|
struct acc2smem<cutlass::index_sequence<Seq...>> {
|
||||||
|
template<int Advance>
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
|
||||||
|
WarpTileIterator &warp_tile_iterator) {
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < Advance; i++) {
|
||||||
|
++accum_fragment_iterator;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
|
||||||
|
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
||||||
|
|
||||||
|
accum_fragment_iterator.load(accum_fragment);
|
||||||
|
++accum_fragment_iterator;
|
||||||
|
|
||||||
|
warp_tile_iterator.store(accum_fragment);
|
||||||
|
if (p < Base::kFragmentsPerIteration - 1) {
|
||||||
|
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Base::kFragmentsPerIteration > 1) {
|
||||||
|
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
|
||||||
|
(1 - Base::kFragmentsPerIteration));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
static void push(size_t pos,
|
||||||
|
AccumulatorFragmentIterator const &iterator_begin,
|
||||||
|
WarpTileIterator &warp_tile_iterator) {
|
||||||
|
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
/// Streams the result to global memory
|
/// Streams the result to global memory
|
||||||
template <typename SourceAspect>
|
template <typename SourceAspect>
|
||||||
@ -452,25 +493,8 @@ public:
|
|||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
acc2smem<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
|
||||||
for (int p = 0; p < Base::kFragmentsPerIteration; ++p)
|
iter, accum_fragment_iterator, this->warp_tile_iterator_);
|
||||||
{
|
|
||||||
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
|
||||||
|
|
||||||
accum_fragment_iterator.load(accum_fragment);
|
|
||||||
++accum_fragment_iterator;
|
|
||||||
|
|
||||||
this->warp_tile_iterator_.store(accum_fragment);
|
|
||||||
|
|
||||||
if (p < Base::kFragmentsPerIteration - 1) {
|
|
||||||
this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Base::kFragmentsPerIteration > 1) {
|
|
||||||
this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Load fragments from shared memory
|
// Load fragments from shared memory
|
||||||
|
Loading…
Reference in New Issue
Block a user