Updates for stream-k (#728)

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
ANIKET SHIVAM 2022-12-08 20:48:10 -08:00 committed by GitHub
parent 1d7772f218
commit 38193d76e3
5 changed files with 96 additions and 84 deletions

View File

@ -146,7 +146,6 @@ struct StripedAccessType<
template <
int BlockThreads,
typename ArrayT,
typename T,
typename AccessT = StripedAccessType<ArrayT> >
struct BlockStriped
{
@ -156,7 +155,7 @@ struct BlockStriped
/// Load
CUTLASS_DEVICE
static void load(ArrayT &data, T *ptr, int thread_idx)
static void load(ArrayT &data, ArrayT *ptr, int thread_idx)
{
AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
AccessT *access_data = reinterpret_cast<AccessT*>(&data);
@ -169,7 +168,7 @@ struct BlockStriped
/// Load & Add
CUTLASS_DEVICE
static void load_add(ArrayT &data, T *ptr, int thread_idx)
static void load_add(ArrayT &data, ArrayT *ptr, int thread_idx)
{
AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
AccessT *access_data = reinterpret_cast<AccessT*>(&data);
@ -185,7 +184,7 @@ struct BlockStriped
/// Store
CUTLASS_DEVICE
static void store(T *ptr, const ArrayT &data, int thread_idx)
static void store(ArrayT *ptr, const ArrayT &data, int thread_idx)
{
AccessT *access_output = reinterpret_cast<AccessT*>(ptr);
const AccessT *access_data = reinterpret_cast<const AccessT*>(&data);
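
Note: the three overloads above only retype the workspace pointer from a raw element pointer (T *) to ArrayT *; the striped indexing itself is unchanged, and stripe s of a thread's fragment still maps to workspace access (BlockThreads * s) + thread_idx. Below is a minimal host-side sketch of that indexing, assuming one element per access and illustrative sizes (the real BlockStriped packs elements into vectorized AccessT accesses):

// Sketch (not part of the diff): host model of BlockStriped load/store indexing.
#include <array>
#include <cstdio>
#include <vector>

constexpr int kBlockThreads = 4;                          // cooperating threads (example)
using ArrayT = std::array<float, 8>;                      // per-thread fragment (example)
constexpr int kStripes = std::tuple_size<ArrayT>::value;  // one element per access here

// Stripe s of a thread's fragment maps to workspace slot (kBlockThreads * s) + thread_idx.
void striped_store(float *ptr, const ArrayT &data, int thread_idx) {
  for (int s = 0; s < kStripes; ++s) {
    ptr[kBlockThreads * s + thread_idx] = data[s];
  }
}

void striped_load(ArrayT &data, const float *ptr, int thread_idx) {
  for (int s = 0; s < kStripes; ++s) {
    data[s] = ptr[kBlockThreads * s + thread_idx];
  }
}

int main() {
  std::vector<float> workspace(kBlockThreads * kStripes, 0.0f);
  for (int t = 0; t < kBlockThreads; ++t) {   // emulate the thread block serially
    ArrayT frag;
    frag.fill(float(t));
    striped_store(workspace.data(), frag, t);
  }
  ArrayT check;
  striped_load(check, workspace.data(), 0);
  std::printf("thread 0 reads back %g ... %g\n", check.front(), check.back());
  return 0;
}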
@ -210,19 +209,24 @@ struct BlockStriped
template <
int BlockThreads,
typename ArrayT,
typename T>
struct BlockStripedReduce : BlockStriped<BlockThreads, ArrayT, T, T>
typename ElementT = typename StripedAccessType<ArrayT>::Element>
struct BlockStripedReduce :
BlockStriped<
BlockThreads,
ArrayT,
ElementT>
{
/// Reduce
CUTLASS_DEVICE
static void reduce(T *ptr, const ArrayT &data, int thread_idx)
static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
{
cutlass::red<T> reduce;
const T *access_data = reinterpret_cast<const T*>(&data);
cutlass::red<ElementT> reduce;
ElementT *access_output = reinterpret_cast<ElementT*>(ptr);
const ElementT *access_data = reinterpret_cast<const ElementT*>(&data);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < BlockStripedReduce::kStripes; ++i) {
reduce(ptr + (BlockThreads * i) + thread_idx, access_data[i]);
reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]);
}
}
};
@ -234,13 +238,17 @@ struct BlockStripedReduce : BlockStriped<BlockThreads, ArrayT, T, T>
template <
int BlockThreads,
typename ArrayT>
struct BlockStripedReduce<BlockThreads, ArrayT, half_t> : BlockStriped<BlockThreads, ArrayT, half_t, half2>
struct BlockStripedReduce<BlockThreads, ArrayT, half_t> :
BlockStriped<
BlockThreads,
ArrayT,
half2>
{
static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length");
/// Reduce
CUTLASS_DEVICE
static void reduce(half_t *ptr, const ArrayT &data, int thread_idx)
static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
{
cutlass::red<half2> reduce;
half2 *access_output = reinterpret_cast<half2*>(ptr);
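
Note: the reduce path uses the same striping, but folds each stripe into the workspace with cutlass::red (a device-side reduction, e.g. an atomic add); the half_t specialization does the same two elements at a time through half2, which is why the fragment length must be even. A host-side sketch of the access pattern follows, with a plain += standing in for the device reduction and illustrative sizes:

// Sketch (not part of the diff): host model of BlockStripedReduce::reduce.
#include <array>
#include <cstdio>
#include <vector>

constexpr int kBlockThreads = 4;                          // example
using ArrayT = std::array<float, 8>;                      // example fragment
constexpr int kStripes = std::tuple_size<ArrayT>::value;

// Each thread folds stripe s of its fragment into workspace slot
// (kBlockThreads * s) + thread_idx.  On the device this goes through
// cutlass::red so peers can reduce concurrently; a plain += stands in here.
void striped_reduce(float *ptr, const ArrayT &data, int thread_idx) {
  for (int s = 0; s < kStripes; ++s) {
    ptr[kBlockThreads * s + thread_idx] += data[s];
  }
}

int main() {
  std::vector<float> workspace(kBlockThreads * kStripes, 0.0f);
  for (int peer = 0; peer < 3; ++peer) {      // three peer "blocks" reduce in turn
    for (int t = 0; t < kBlockThreads; ++t) {
      ArrayT frag;
      frag.fill(1.0f);
      striped_reduce(workspace.data(), frag, t);
    }
  }
  std::printf("workspace[0] = %g (expect 3)\n", workspace[0]);
  return 0;
}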

View File

@ -222,7 +222,7 @@ public:
int peer_idx_begin,
int peer_idx_end,
int reduce_fragment_idx,
ElementAccumulator *element_workspace,
void *element_workspace,
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
OutputTileIterator source_iterator) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
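
Note: the epilogue's reduce() entry point now takes the partials workspace as a type-erased void * and reinterprets it internally, so offsets can be expressed in fragment units rather than ElementAccumulator units (the hunks below do exactly that). A minimal sketch of the pattern, with hypothetical names:

// Sketch (not part of the diff): passing a type-erased workspace pointer.
#include <cstdio>

struct Fragment { float data[4]; };   // stand-in for AccumulatorFragment

// The callee decides the fragment type it strides the workspace by.
void consume_fragment(void *workspace_ptr, int fragment_offset) {
  Fragment *fragments = reinterpret_cast<Fragment *>(workspace_ptr);
  std::printf("first element of fragment %d: %g\n",
              fragment_offset, fragments[fragment_offset].data[0]);
}

int main() {
  Fragment workspace[8] = {};
  workspace[2].data[0] = 42.0f;
  consume_fragment(workspace, 2);     // the caller only hands over raw storage
  return 0;
}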

View File

@ -53,19 +53,20 @@ template <
typename Shape, ///< Shape of threadblock tile (concept: GemmShape)
int PartitionsK,
typename WarpMmaOperator, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
typename AccumulatorFragmentIterator> ///< Fragment iterator selecting accumulators
typename AccumulatorFragmentIterator> ///< Iterator for enumerating fragments within the per-thread tile of raw accumulators
class EpilogueBaseStreamK
{
protected:
/// The complete warp-level accumulator tile
/// The per-thread tile of raw accumulators
using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
/// Number of warps
using WarpCount = gemm::GemmShape<
Shape::kM / WarpMmaOperator::Shape::kM,
Shape::kN / WarpMmaOperator::Shape::kN, PartitionsK>;
Shape::kN / WarpMmaOperator::Shape::kN,
PartitionsK>;
/// Number of threads per block
static int const kBlockThreads = 32 * WarpCount::kCount;
@ -76,25 +77,26 @@ protected:
/// Fragment type used by the accumulator tile's fragment iterator
using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;
/// Block-striped transfer utility for sharing AccumulatorFragment
using BlockStripedT = BlockStriped<kBlockThreads, AccumulatorFragment, ElementAccumulator>;
/// Number of elements per fragment
static int const kFragmentElements = sizeof(AccumulatorFragment) / sizeof(ElementAccumulator);
public:
/// Number of fragments per accumulator tile
/// Number of AccumulatorTile fragments per thread
static int const kAccumulatorFragments = AccumulatorFragmentIterator::Policy::kIterations;
/// Number of workspace accumulation elements shared per output tile
static int const kPeerAccumulators = WarpMmaOperator::Shape::kMN * WarpCount::kCount;
protected:
/// ElementAccumulator stride in the shared workspace between different peer blocks (two: each peer block can share accumulators for up to two tiles)
static const int kPeerStride = kPeerAccumulators * 2;
/// Number of AccumulatorTile fragments per block output tile
static int const kOutputTileFragments = kBlockThreads * kAccumulatorFragments;
/// Block-striped transfer utility for sharing AccumulatorFragment
using BlockStripedT = BlockStriped<kBlockThreads, AccumulatorFragment>;
/// AccumulatorFragment stride in the shared workspace between different peer blocks (each thread block can share accumulators for up to two block output tiles)
static const int kPeerFragmentStride = kOutputTileFragments * 2;
public:
/// Workspace bytes per thread block
static size_t const kWorkspaceBytesPerBlock = sizeof(AccumulatorFragment) * kPeerFragmentStride;
public:
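
Note: the new constants describe the workspace layout in AccumulatorFragment units: one block output tile needs kBlockThreads fragments per fragment-iterator iteration, and each peer block reserves room for two output tiles. A small sketch of the arithmetic with illustrative values:

// Sketch (not part of the diff): workspace-layout arithmetic with example values.
#include <cstddef>
#include <cstdio>

int main() {
  const int kBlockThreads         = 256;   // 32 * WarpCount::kCount (example)
  const int kAccumulatorFragments = 8;     // AccumulatorFragmentIterator::Policy::kIterations (example)
  const int kFragmentBytes        = 128;   // sizeof(AccumulatorFragment) (example)

  // Fragments needed to hold one block output tile, striped across the block:
  const int kOutputTileFragments  = kBlockThreads * kAccumulatorFragments;

  // Each peer block can publish accumulators for up to two output tiles
  // (a "started" and a "non-started" one), so its stride is twice that:
  const int kPeerFragmentStride   = kOutputTileFragments * 2;

  const std::size_t kWorkspaceBytesPerBlock =
      std::size_t(kFragmentBytes) * kPeerFragmentStride;

  std::printf("fragments per output tile : %d\n", kOutputTileFragments);
  std::printf("fragment stride per peer  : %d\n", kPeerFragmentStride);
  std::printf("workspace bytes per block : %zu\n", kWorkspaceBytesPerBlock);
  return 0;
}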
@ -119,28 +121,29 @@ public:
int peer_idx_begin,
int peer_idx_end,
int reduce_fragment_idx,
ElementAccumulator *element_workspace)
void *workspace_ptr)
{
plus<AccumulatorFragment> add_fragments;
int accum_set_offset =
(peer_idx_begin * kPeerStride) +
(reduce_fragment_idx * kBlockThreads * kFragmentElements);
AccumulatorFragment *fragment_workspace = reinterpret_cast<AccumulatorFragment *>(workspace_ptr);
int fragment_offset = (peer_idx_begin * kPeerFragmentStride) + (reduce_fragment_idx * kBlockThreads);
// Load first peer fragment
BlockStripedT::load(accum_fragment, element_workspace + accum_set_offset, this->thread_idx);
BlockStripedT::load(accum_fragment, fragment_workspace + fragment_offset, this->thread_idx);
accum_set_offset += kPeerStride; // Move to next peer
accum_set_offset += kPeerAccumulators; // Move to non-starting accumulator set for peer
fragment_offset += kPeerFragmentStride; // Move to next peer
fragment_offset += kOutputTileFragments; // Move to the set of fragments for this peer's "non-started" output tile
// Reduce additional peer fragments
// Reduce fragments from additional peers
#pragma unroll 2
while (accum_set_offset < peer_idx_end * kPeerStride)
for (; fragment_offset < peer_idx_end * kPeerFragmentStride; fragment_offset += kPeerFragmentStride)
{
// Load peer fragment
AccumulatorFragment addend_fragment;
BlockStripedT::load(addend_fragment, element_workspace + accum_set_offset, this->thread_idx);
accum_set_offset += kPeerStride;
BlockStripedT::load(addend_fragment, fragment_workspace + fragment_offset, this->thread_idx);
// Add peer fragment
accum_fragment = add_fragments(accum_fragment, addend_fragment);
}
}
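
Note: with the workspace typed as AccumulatorFragment, reduce() walks its peers in fragment units: it loads the first peer's fragment for reduce_fragment_idx, then advances to the next peer's "non-started" tile slot and strides by kPeerFragmentStride per additional peer. A host sketch that simply enumerates the offsets visited, with illustrative values:

// Sketch (not part of the diff): offsets visited by EpilogueBaseStreamK::reduce.
#include <cstdio>

int main() {
  const int kBlockThreads         = 256;                       // example
  const int kAccumulatorFragments = 8;                         // example
  const int kOutputTileFragments  = kBlockThreads * kAccumulatorFragments;
  const int kPeerFragmentStride   = kOutputTileFragments * 2;

  const int peer_idx_begin = 3, peer_idx_end = 6;              // example peer range
  const int reduce_fragment_idx = 2;                           // example fragment index

  // First peer: the fragment for this reduce_fragment_idx within its slot.
  int fragment_offset = (peer_idx_begin * kPeerFragmentStride)
                      + (reduce_fragment_idx * kBlockThreads);
  std::printf("load first peer at fragment %d\n", fragment_offset);

  // Subsequent peers contribute via their "non-started" output tile slot.
  fragment_offset += kPeerFragmentStride;     // next peer
  fragment_offset += kOutputTileFragments;    // its non-started tile

  for (; fragment_offset < peer_idx_end * kPeerFragmentStride;
       fragment_offset += kPeerFragmentStride) {
    std::printf("load+add peer fragment %d\n", fragment_offset);
  }
  return 0;
}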
@ -150,19 +153,22 @@ public:
CUTLASS_DEVICE
void share(
int peer_idx,
ElementAccumulator *element_workspace, ///< Output pointer for writing this block's accumulator set to
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
bool started_tile)
void *workspace_ptr,
AccumulatorTile const &accumulators,
bool started_tile) ///< Whether this thread block computed the first work volume for the current output tile
{
int accum_set_offset = peer_idx * kPeerStride;
AccumulatorFragment *fragment_workspace = reinterpret_cast<AccumulatorFragment *>(workspace_ptr);
int fragment_offset = peer_idx * kPeerFragmentStride;
if (!started_tile) {
// Move to non-starting accumulator set
accum_set_offset += kPeerAccumulators;
// Move to the set of fragments for the "non-started" output tile
fragment_offset += kOutputTileFragments;
}
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
// Convert raw accumulator tile to fragments and store
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < kAccumulatorFragments; ++iter)
{
@ -172,9 +178,9 @@ public:
++accum_fragment_iterator;
// Store accumulator fragment
BlockStripedT::store(element_workspace + accum_set_offset, accum_fragment, this->thread_idx);
BlockStripedT::store(fragment_workspace + fragment_offset, accum_fragment, this->thread_idx);
accum_set_offset += (kFragmentElements * kBlockThreads);
fragment_offset += kBlockThreads;
}
}
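
Note: share() mirrors that layout on the write side: the starting offset selects the peer's slot (plus one output tile's worth of fragments when the block did not start the tile), and each fragment-iterator iteration is stored kBlockThreads fragments further on. A sketch of the offsets written, with illustrative values:

// Sketch (not part of the diff): offsets written by EpilogueBaseStreamK::share.
#include <cstdio>

int main() {
  const int kBlockThreads         = 256;   // example
  const int kAccumulatorFragments = 8;     // example
  const int kOutputTileFragments  = kBlockThreads * kAccumulatorFragments;
  const int kPeerFragmentStride   = kOutputTileFragments * 2;

  const int  peer_idx     = 3;             // example
  const bool started_tile = false;         // this block did not start its tile

  int fragment_offset = peer_idx * kPeerFragmentStride;
  if (!started_tile) {
    fragment_offset += kOutputTileFragments;   // second ("non-started") tile slot
  }

  // One striped store per fragment-iterator iteration, kBlockThreads apart.
  for (int iter = 0; iter < kAccumulatorFragments; ++iter) {
    std::printf("store fragment iteration %d at offset %d\n", iter, fragment_offset);
    fragment_offset += kBlockThreads;
  }
  return 0;
}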

View File

@ -172,7 +172,7 @@ public:
int peer_idx_begin,
int peer_idx_end,
int reduce_fragment_idx,
ElementAccumulator *element_workspace,
void *element_workspace,
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
OutputTileIterator source_iterator) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)

View File

@ -77,7 +77,9 @@ public:
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
using ElementAccumulator = typename Mma::ElementC;
/// The per-thread tile of raw accumulators
using AccumulatorTile = typename Mma::FragmentC;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
@ -94,21 +96,18 @@ public:
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Number of workspace accumulation elements shared per block
static int const kPeerAccumulators = Epilogue::kPeerAccumulators;
/// Number of fragments per (thread) accumulator tile
static int const kAccumulatorFragments = Epilogue::kAccumulatorFragments;
/// Number of numeric accumulation elements per fragment
static int const kAccumTileElements = sizeof(typename Mma::FragmentC) / sizeof(ElementAccumulator);
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Workspace bytes per thread block
static size_t const kWorkspaceBytesPerBlock =
__NV_STD_MAX(
kThreadCount * sizeof(AccumulatorTile),
Epilogue::kWorkspaceBytesPerBlock);
/// Block-striped reduction utility
using BlockStripedReduceT = BlockStripedReduce<kThreadCount, typename Mma::FragmentC, ElementAccumulator>;
using BlockStripedReduceT = BlockStripedReduce<kThreadCount, AccumulatorTile>;
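
Note: the kernel-level workspace slice per block must cover both sharing paths: atomic reduction shares one AccumulatorTile striped across kThreadCount threads, while parallel reduction uses the epilogue's two-tile fragment layout, hence the max. A numeric sketch with illustrative sizes:

// Sketch (not part of the diff): sizing the per-block partials workspace.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t kThreadCount             = 256;      // 32 * WarpCount::kCount (example)
  const std::size_t accumulator_tile_bytes   = 2048;     // sizeof(AccumulatorTile) (example)
  const std::size_t epilogue_bytes_per_block = 1048576;  // Epilogue::kWorkspaceBytesPerBlock (example)

  // Atomic reduction shares whole accumulator tiles; parallel reduction shares
  // the epilogue's two-tile fragment layout.  Reserve the larger of the two.
  const std::size_t kWorkspaceBytesPerBlock =
      std::max(kThreadCount * accumulator_tile_bytes, epilogue_bytes_per_block);

  std::printf("workspace bytes per block = %zu\n", kWorkspaceBytesPerBlock);
  return 0;
}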
@ -276,7 +275,7 @@ public:
int64_t batch_stride_D;
void *barrier_workspace;
ElementAccumulator *partials_workspace;
void *partials_workspace;
protected:
@ -305,11 +304,8 @@ public:
/// Get the workspace size needed for intermediate partial sums
size_t get_partials_workspace_size() const
{
// For atomic reduction, each SK-block can share one accumulator tile. For parallel reduction,
// each SK-block can share up to two accumulator tiles.
size_t tile_bytes_accumulators = sizeof(ElementAccumulator) * kPeerAccumulators * 2;
int sk_blocks = block_mapping.sk_regions * block_mapping.sk_blocks_per_region;
return cacheline_align_up(tile_bytes_accumulators * sk_blocks);
return cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks);
}
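
Note: the total partials workspace is then just the per-block byte count times the number of SK blocks, rounded up to a cache line. A sketch with illustrative counts and a stand-in for the cache-line rounding helper:

// Sketch (not part of the diff): total partials-workspace size.
#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for the cache-line rounding helper used above.
std::size_t cacheline_align_up(std::size_t bytes) {
  const std::size_t kCacheline = 128;
  return (bytes + kCacheline - 1) / kCacheline * kCacheline;
}

int main() {
  const std::size_t kWorkspaceBytesPerBlock = 1048576;  // example
  const int sk_regions           = 1;                   // block_mapping.sk_regions (example)
  const int sk_blocks_per_region = 4;                   // block_mapping.sk_blocks_per_region (example)

  const int sk_blocks = sk_regions * sk_blocks_per_region;
  std::printf("partials workspace = %zu bytes\n",
              cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks));
  return 0;
}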
@ -388,7 +384,7 @@ public:
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
partials_workspace = reinterpret_cast<ElementAccumulator*>(ptr);
partials_workspace = ptr;
ptr += partials_workspace_bytes;
}
@ -400,7 +396,7 @@ public:
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
barrier_workspace = reinterpret_cast<ElementAccumulator*>(ptr);
barrier_workspace = ptr;
ptr += barrier_workspace_bytes;
}
@ -630,14 +626,12 @@ public:
return Status::kSuccess;
}
/// Determines whether the GEMM problem satisfies this kernel's
/// alignment requirements
static Status can_implement(Arguments const &args) {
return can_implement(args.problem_size);
}
protected:
//
@ -762,14 +756,16 @@ protected:
/// Share accumulators with peers
CUTLASS_DEVICE
void share_accumulators(typename Mma::FragmentC const &accumulator_tile, int first_block_idx)
void share_accumulators(AccumulatorTile const &accumulator_tile, int first_block_idx)
{
int block_tile_offset = first_block_idx * kPeerAccumulators;
AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
int accum_tile_offset = first_block_idx * kThreadCount;
if (block_idx == first_block_idx)
{
// First peer initializes the workspace partials
BlockStripedReduceT::store(params.partials_workspace + block_tile_offset, accumulator_tile, thread_idx);
BlockStripedReduceT::store(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx);
}
else
{
@ -787,7 +783,7 @@ protected:
}
// Perform reduction in workspace
BlockStripedReduceT::reduce(params.partials_workspace + block_tile_offset, accumulator_tile, thread_idx);
BlockStripedReduceT::reduce(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx);
}
// Signal our arrival
@ -797,17 +793,19 @@ protected:
/// Acquire accumulators from peers
CUTLASS_DEVICE
void acquire_accumulators_atomic(
typename Mma::FragmentC &accumulator_tile,
void acquire_accumulators(
AccumulatorTile &accumulator_tile,
int first_block_idx)
{
AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
// Wait for arrival
int num_carry_in = block_idx - first_block_idx;
Barrier::wait_eq_reset(params.barrier_workspace, thread_idx, first_block_idx, num_carry_in);
// Load and add peer-partials accumulator tile to local accumulator tile
int block_tile_offset = first_block_idx * kPeerAccumulators;
BlockStripedReduceT::load_add(accumulator_tile, params.partials_workspace + block_tile_offset, thread_idx);
int accum_tile_offset = first_block_idx * kThreadCount;
BlockStripedReduceT::load_add(accumulator_tile, accum_tile_workspace + accum_tile_offset, thread_idx);
}
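
Note: in both share_accumulators and acquire_accumulators the partials workspace is reinterpreted as an array of AccumulatorTile, so a peer group's slot is simply first_block_idx * kThreadCount tiles in (one tile per thread), replacing the old element-count arithmetic. A minimal sketch of that addressing with a stand-in tile type:

// Sketch (not part of the diff): addressing the partials workspace by tile.
#include <cstdio>
#include <vector>

struct AccumulatorTile { float frag[4]; };   // stand-in for Mma::FragmentC

int main() {
  const int kThreadCount = 8;                // example
  const int num_blocks   = 4;                // example

  // One AccumulatorTile per thread per block.
  std::vector<AccumulatorTile> partials_workspace(kThreadCount * num_blocks);
  void *workspace_ptr = partials_workspace.data();   // the kernel params hold a void*

  int first_block_idx = 2;                   // first peer sharing this output tile
  AccumulatorTile *accum_tile_workspace =
      reinterpret_cast<AccumulatorTile *>(workspace_ptr);
  int accum_tile_offset = first_block_idx * kThreadCount;

  // The striped store/reduce/load_add all operate on this typed sub-array.
  accum_tile_workspace[accum_tile_offset].frag[0] = 1.0f;
  std::printf("peer slot begins at tile %d\n", accum_tile_offset);
  return 0;
}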
@ -815,7 +813,7 @@ protected:
CUTLASS_DEVICE
void do_epilogue(
TileWorkDesc &tile_work,
typename Mma::FragmentC &accumulator_tile)
AccumulatorTile &accumulator_tile)
{
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
@ -867,8 +865,8 @@ protected:
int peer_idx_begin, peer_idx_last, reduce_tile_idx, reduce_fragment_idx;
// Reduce by sk-tile (every tile contributed to by one or more blocks)
reduce_tile_idx = reduce_idx / kAccumulatorFragments;
reduce_fragment_idx = reduce_idx % kAccumulatorFragments;
reduce_tile_idx = reduce_idx / Epilogue::kAccumulatorFragments;
reduce_fragment_idx = reduce_idx % Epilogue::kAccumulatorFragments;
int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile;
int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile - 1;
@ -882,7 +880,7 @@ protected:
Barrier::wait_eq_reset(
params.barrier_workspace,
thread_idx,
(reduce_tile_idx * kAccumulatorFragments) + reduce_fragment_idx,
(reduce_tile_idx * Epilogue::kAccumulatorFragments) + reduce_fragment_idx,
num_peers);
/// The location of this tile (in threadblock-tile coordinates) in the output matrix
@ -946,7 +944,7 @@ protected:
typename Mma::IteratorB iterator_B = init_iterator_B(tile_work);
// Initialize accumulators
typename Mma::FragmentC accumulator_tile;
AccumulatorTile accumulator_tile;
accumulator_tile.clear();
// Perform this tile's range of multiply-accumulate (MAC) iterations
@ -978,7 +976,7 @@ protected:
if (!tile_work.tile_started())
{
// A "finishing" SK block must first aggregate its accumulator partial sums with those shared by peer threadblocks
acquire_accumulators_atomic(accumulator_tile, first_block_idx);
acquire_accumulators(accumulator_tile, first_block_idx);
}
do_epilogue(tile_work, accumulator_tile);
@ -997,8 +995,8 @@ protected:
Barrier::arrive_range_inc(
params.barrier_workspace,
thread_idx,
tile_work.tile_idx * kAccumulatorFragments,
kAccumulatorFragments);
tile_work.tile_idx * Epilogue::kAccumulatorFragments,
Epilogue::kAccumulatorFragments);
}
}
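
Note: for separate reduction, the barrier workspace is indexed per (output tile, accumulator fragment): a finishing block arrives on its tile's kAccumulatorFragments flags, and each reduction work item waits on flag reduce_tile_idx * Epilogue::kAccumulatorFragments + reduce_fragment_idx. A small sketch of that flag arithmetic with illustrative values:

// Sketch (not part of the diff): barrier-flag indexing for separate reduction.
#include <cstdio>

int main() {
  const int kAccumulatorFragments = 8;       // Epilogue::kAccumulatorFragments (example)

  // A finishing SK block signals all fragments of its output tile ...
  int tile_idx   = 5;                        // example
  int first_flag = tile_idx * kAccumulatorFragments;
  std::printf("arrive on flags [%d, %d)\n", first_flag, first_flag + kAccumulatorFragments);

  // ... and each reduction work item maps back to one (tile, fragment) pair.
  int reduce_idx          = 42;              // example
  int reduce_tile_idx     = reduce_idx / kAccumulatorFragments;
  int reduce_fragment_idx = reduce_idx % kAccumulatorFragments;
  std::printf("reduce_idx %d -> tile %d, fragment %d, flag %d\n",
              reduce_idx, reduce_tile_idx, reduce_fragment_idx,
              reduce_tile_idx * kAccumulatorFragments + reduce_fragment_idx);
  return 0;
}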