Updates for stream-k (#728)
Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
parent 1d7772f218
commit 38193d76e3
@@ -146,7 +146,6 @@ struct StripedAccessType<
 template <
   int BlockThreads,
   typename ArrayT,
-  typename T,
   typename AccessT = StripedAccessType<ArrayT> >
 struct BlockStriped
 {
@@ -156,7 +155,7 @@ struct BlockStriped
 
   /// Load
   CUTLASS_DEVICE
-  static void load(ArrayT &data, T *ptr, int thread_idx)
+  static void load(ArrayT &data, ArrayT *ptr, int thread_idx)
   {
     AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
     AccessT *access_data = reinterpret_cast<AccessT*>(&data);
@@ -169,7 +168,7 @@ struct BlockStriped
 
   /// Load & Add
   CUTLASS_DEVICE
-  static void load_add(ArrayT &data, T *ptr, int thread_idx)
+  static void load_add(ArrayT &data, ArrayT *ptr, int thread_idx)
   {
     AccessT *access_input = reinterpret_cast<AccessT*>(ptr);
     AccessT *access_data = reinterpret_cast<AccessT*>(&data);
@@ -185,7 +184,7 @@ struct BlockStriped
 
   /// Store
   CUTLASS_DEVICE
-  static void store(T *ptr, const ArrayT &data, int thread_idx)
+  static void store(ArrayT *ptr, const ArrayT &data, int thread_idx)
   {
     AccessT *access_output = reinterpret_cast<AccessT*>(ptr);
     const AccessT *access_data = reinterpret_cast<const AccessT*>(&data);
@@ -210,19 +209,24 @@ struct BlockStriped
 template <
   int BlockThreads,
   typename ArrayT,
-  typename T>
-struct BlockStripedReduce : BlockStriped<BlockThreads, ArrayT, T, T>
+  typename ElementT = typename StripedAccessType<ArrayT>::Element>
+struct BlockStripedReduce :
+  BlockStriped<
+    BlockThreads,
+    ArrayT,
+    ElementT>
 {
   /// Reduce
   CUTLASS_DEVICE
-  static void reduce(T *ptr, const ArrayT &data, int thread_idx)
+  static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
   {
-    cutlass::red<T> reduce;
-    const T *access_data = reinterpret_cast<const T*>(&data);
+    cutlass::red<ElementT> reduce;
+    ElementT *access_output = reinterpret_cast<ElementT*>(ptr);
+    const ElementT *access_data = reinterpret_cast<const ElementT*>(&data);
 
     CUTLASS_PRAGMA_UNROLL
     for (int i = 0; i < BlockStripedReduce::kStripes; ++i) {
-      reduce(ptr + (BlockThreads * i) + thread_idx, access_data[i]);
+      reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]);
     }
   }
 };
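
Note: both the old and new reduce() above rely on the block-striped layout, in which thread t's stripe i lives at workspace slot (BlockThreads * i) + t, so consecutive threads always touch consecutive slots and every transfer is coalesced. A minimal standalone CUDA sketch of that round trip (names and sizes here are illustrative, not code from this commit):

    #include <cuda_runtime.h>

    template <int BlockThreads, int Stripes>
    __global__ void striped_roundtrip(float *workspace, float *out)
    {
      int t = threadIdx.x;
      float frag[Stripes];

      // Fabricate this thread's per-thread fragment
      for (int i = 0; i < Stripes; ++i) {
        frag[i] = static_cast<float>(t * Stripes + i);
      }

      // Block-striped store: stripe i of all threads is contiguous in memory
      for (int i = 0; i < Stripes; ++i) {
        workspace[(BlockThreads * i) + t] = frag[i];
      }

      __syncthreads();

      // Block-striped load recovers exactly the fragment this thread stored
      float sum = 0.f;
      for (int i = 0; i < Stripes; ++i) {
        sum += workspace[(BlockThreads * i) + t];
      }
      out[t] = sum;
    }

    int main()
    {
      constexpr int kThreads = 128, kStripes = 4;
      float *workspace = nullptr, *out = nullptr;
      cudaMalloc(&workspace, sizeof(float) * kThreads * kStripes);
      cudaMalloc(&out, sizeof(float) * kThreads);
      striped_roundtrip<kThreads, kStripes><<<1, kThreads>>>(workspace, out);
      cudaDeviceSynchronize();
      cudaFree(workspace);
      cudaFree(out);
      return 0;
    }
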
@@ -234,13 +238,17 @@ struct BlockStripedReduce : BlockStriped<BlockThreads, ArrayT, T, T>
 template <
   int BlockThreads,
   typename ArrayT>
-struct BlockStripedReduce<BlockThreads, ArrayT, half_t> : BlockStriped<BlockThreads, ArrayT, half_t, half2>
+struct BlockStripedReduce<BlockThreads, ArrayT, half_t> :
+  BlockStriped<
+    BlockThreads,
+    ArrayT,
+    half2>
 {
   static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length");
 
   /// Reduce
   CUTLASS_DEVICE
-  static void reduce(half_t *ptr, const ArrayT &data, int thread_idx)
+  static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
   {
     cutlass::red<half2> reduce;
     half2 *access_output = reinterpret_cast<half2*>(ptr);
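
Note: the half_t specialization now derives from BlockStriped over half2 so that every workspace transaction moves a packed pair; this is also why the static_assert requires an even number of stripes. The primitive cutlass::red<half2> wraps is CUDA's packed atomicAdd on __half2 (available on compute capability 6.0 and newer), sketched here outside of CUTLASS (illustrative only):

    #include <cuda_fp16.h>

    // One packed atomic accumulates two half-precision lanes at a time,
    // halving the atomic count relative to scalar __half reduction.
    __global__ void packed_half_reduce(__half2 *workspace, const __half2 *partials, int n)
    {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        atomicAdd(workspace + i, partials[i]);
      }
    }
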
@@ -222,7 +222,7 @@ public:
     int peer_idx_begin,
     int peer_idx_end,
     int reduce_fragment_idx,
-    ElementAccumulator *element_workspace,
+    void *element_workspace,
     OutputOp const &output_op,                ///< Output operator
     OutputTileIterator destination_iterator,  ///< Tile iterator for destination
     OutputTileIterator source_iterator)       ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
@@ -53,19 +53,20 @@ template <
   typename Shape,                        ///< Shape of threadblock tile (concept: GemmShape)
   int PartitionsK,
   typename WarpMmaOperator,              ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp)
-  typename AccumulatorFragmentIterator>  ///< Fragment iterator selecting accumulators
+  typename AccumulatorFragmentIterator>  ///< Iterator for enumerating fragments within the per-thread tile of raw accumulators
 class EpilogueBaseStreamK
 {
 
 protected:
 
-  /// The complete warp-level accumulator tile
+  /// The per-thread tile of raw accumulators
   using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
 
   /// Number of warps
   using WarpCount = gemm::GemmShape<
     Shape::kM / WarpMmaOperator::Shape::kM,
-    Shape::kN / WarpMmaOperator::Shape::kN, PartitionsK>;
+    Shape::kN / WarpMmaOperator::Shape::kN,
+    PartitionsK>;
 
   /// Number of threads per block
   static int const kBlockThreads = 32 * WarpCount::kCount;
@@ -76,25 +77,26 @@ protected:
   /// Fragment type used by the accumulator tile's fragment iterator
   using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment;
 
-  /// Block-striped transfer utility for sharing AccumulatorFragment
-  using BlockStripedT = BlockStriped<kBlockThreads, AccumulatorFragment, ElementAccumulator>;
-
-  /// Number of elements per fragment
-  static int const kFragmentElements = sizeof(AccumulatorFragment) / sizeof(ElementAccumulator);
-
 public:
 
-  /// Number of fragments per accumulator tile
+  /// Number of AccumulatorTile fragments per thread
   static int const kAccumulatorFragments = AccumulatorFragmentIterator::Policy::kIterations;
 
-  /// Number of workspace accumulation elements shared per output tile
-  static int const kPeerAccumulators = WarpMmaOperator::Shape::kMN * WarpCount::kCount;
-
 protected:
 
-  /// ElementAccumulator stride in the shared workspace between different peer blocks (two: each peer block can share accumulators for up to two tiles)
-  static const int kPeerStride = kPeerAccumulators * 2;
+  /// Number of AccumulatorTile fragments per block output tile
+  static int const kOutputTileFragments = kBlockThreads * kAccumulatorFragments;
+
+  /// Block-striped transfer utility for sharing AccumulatorFragment
+  using BlockStripedT = BlockStriped<kBlockThreads, AccumulatorFragment>;
+
+  /// AccumulatorFragment stride in the shared workspace between different peer blocks (each thread block can share accumulators for up to two block output tiles)
+  static const int kPeerFragmentStride = kOutputTileFragments * 2;
+
+public:
+
+  /// Workspace bytes per thread block
+  static size_t const kWorkspaceBytesPerBlock = sizeof(AccumulatorFragment) * kPeerFragmentStride;
 
 public:
 
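
Note: the workspace is now metered in whole AccumulatorFragments instead of loose ElementAccumulator counts, which is what lets the epilogue and the universal kernel agree on a single kWorkspaceBytesPerBlock. Plugging hypothetical sizes into the constants above (all values below are assumptions for illustration, not any particular tile configuration):

    #include <cstddef>

    constexpr int kBlockThreads         = 256;  // 8 warps, assumed
    constexpr int kAccumulatorFragments = 8;    // fragment-iterator iterations, assumed
    constexpr int kFragmentBytes        = 32;   // sizeof(AccumulatorFragment), assumed

    // Fragments covering one block output tile
    constexpr int kOutputTileFragments = kBlockThreads * kAccumulatorFragments;  // 2048

    // Each peer owns two output-tile slots ("started" and "non-started")
    constexpr int kPeerFragmentStride = kOutputTileFragments * 2;                // 4096

    // Partials workspace per SK block
    constexpr std::size_t kWorkspaceBytesPerBlock =
        std::size_t(kFragmentBytes) * kPeerFragmentStride;                       // 128 KiB

    static_assert(kWorkspaceBytesPerBlock == 128 * 1024, "arithmetic check");
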
@@ -119,28 +121,29 @@ public:
     int peer_idx_begin,
     int peer_idx_end,
     int reduce_fragment_idx,
-    ElementAccumulator *element_workspace)
+    void *workspace_ptr)
   {
     plus<AccumulatorFragment> add_fragments;
 
-    int accum_set_offset =
-      (peer_idx_begin * kPeerStride) +
-      (reduce_fragment_idx * kBlockThreads * kFragmentElements);
+    AccumulatorFragment *fragment_workspace = reinterpret_cast<AccumulatorFragment *>(workspace_ptr);
+
+    int fragment_offset = (peer_idx_begin * kPeerFragmentStride) + (reduce_fragment_idx * kBlockThreads);
 
     // Load first peer fragment
-    BlockStripedT::load(accum_fragment, element_workspace + accum_set_offset, this->thread_idx);
+    BlockStripedT::load(accum_fragment, fragment_workspace + fragment_offset, this->thread_idx);
 
-    accum_set_offset += kPeerStride;         // Move to next peer
-    accum_set_offset += kPeerAccumulators;   // Move to non-starting accumulator set for peer
+    fragment_offset += kPeerFragmentStride;  // Move to next peer
+    fragment_offset += kOutputTileFragments; // Move to the set of fragments for this peer's "non-started" output tile
 
-    // Reduce additional peer fragments
+    // Reduce fragments from additional peers
     #pragma unroll 2
-    while (accum_set_offset < peer_idx_end * kPeerStride)
+    for (; fragment_offset < peer_idx_end * kPeerFragmentStride; fragment_offset += kPeerFragmentStride)
     {
       // Load peer fragment
       AccumulatorFragment addend_fragment;
-      BlockStripedT::load(addend_fragment, element_workspace + accum_set_offset, this->thread_idx);
-      accum_set_offset += kPeerStride;
+      BlockStripedT::load(addend_fragment, fragment_workspace + fragment_offset, this->thread_idx);
 
       // Add peer fragment
       accum_fragment = add_fragments(accum_fragment, addend_fragment);
     }
   }
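
Note: after this change a single fragment index addresses the whole workspace. The offsets used by reduce() and share() can be restated as one pure function (an illustrative summary, not code from the commit):

    // Index of the AccumulatorFragment a given thread touches in the workspace
    constexpr int fragment_slot(
        int peer_idx,               // which peer block's region
        bool started_tile,          // which of its two output-tile slots
        int fragment_idx,           // which fragment of the output tile
        int thread_idx,             // block-striped lane
        int peer_fragment_stride,   // kPeerFragmentStride
        int output_tile_fragments,  // kOutputTileFragments
        int block_threads)          // kBlockThreads
    {
      return peer_idx * peer_fragment_stride
           + (started_tile ? 0 : output_tile_fragments)
           + fragment_idx * block_threads
           + thread_idx;
    }
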
@@ -150,19 +153,22 @@ public:
   CUTLASS_DEVICE
   void share(
     int peer_idx,
-    ElementAccumulator *element_workspace,  ///< Output pointer for writing this block's accumulator set to
-    AccumulatorTile const &accumulators,    ///< Complete warp-level accumulator tile
-    bool started_tile)
+    void *workspace_ptr,
+    AccumulatorTile const &accumulators,
+    bool started_tile)                      ///< Whether this thread block computed the first work volume for the current output tile
   {
-    int accum_set_offset = peer_idx * kPeerStride;
+    AccumulatorFragment *fragment_workspace = reinterpret_cast<AccumulatorFragment *>(workspace_ptr);
+
+    int fragment_offset = peer_idx * kPeerFragmentStride;
 
     if (!started_tile) {
-      // Move to non-starting accumulator set
-      accum_set_offset += kPeerAccumulators;
+      // Move to the set of fragments for the "non-started" output tile
+      fragment_offset += kOutputTileFragments;
     }
 
     AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
 
     // Convert raw accumulator tile to fragments and store
     CUTLASS_PRAGMA_UNROLL
     for (int iter = 0; iter < kAccumulatorFragments; ++iter)
     {
@@ -172,9 +178,9 @@ public:
       ++accum_fragment_iterator;
 
       // Store accumulator fragment
-      BlockStripedT::store(element_workspace + accum_set_offset, accum_fragment, this->thread_idx);
+      BlockStripedT::store(fragment_workspace + fragment_offset, accum_fragment, this->thread_idx);
 
-      accum_set_offset += (kFragmentElements * kBlockThreads);
+      fragment_offset += kBlockThreads;
     }
   }
 
@@ -172,7 +172,7 @@ public:
     int peer_idx_begin,
     int peer_idx_end,
     int reduce_fragment_idx,
-    ElementAccumulator *element_workspace,
+    void *element_workspace,
     OutputOp const &output_op,                ///< Output operator
     OutputTileIterator destination_iterator,  ///< Tile iterator for destination
     OutputTileIterator source_iterator)       ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
@@ -77,7 +77,9 @@ public:
   using LayoutB = typename Mma::IteratorB::Layout;
   using ElementC = typename Epilogue::OutputTileIterator::Element;
   using LayoutC = typename Epilogue::OutputTileIterator::Layout;
   using ElementAccumulator = typename Mma::ElementC;
 
+  /// The per-thread tile of raw accumulators
+  using AccumulatorTile = typename Mma::FragmentC;
+
   static ComplexTransform const kTransformA = Mma::kTransformA;
   static ComplexTransform const kTransformB = Mma::kTransformB;
@@ -94,21 +96,18 @@ public:
   static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
   static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
 
-  /// Number of workspace accumulation elements shared per block
-  static int const kPeerAccumulators = Epilogue::kPeerAccumulators;
-
   /// Number of fragments per (thread) accumulator tile
   static int const kAccumulatorFragments = Epilogue::kAccumulatorFragments;
 
-  /// Number of numeric accumulation elements per fragment
-  static int const kAccumTileElements = sizeof(typename Mma::FragmentC) / sizeof(ElementAccumulator);
-
   /// Warp count (concept: GemmShape)
   using WarpCount = typename Mma::WarpCount;
   static int const kThreadCount = 32 * WarpCount::kCount;
 
+  /// Workspace bytes per thread block
+  static size_t const kWorkspaceBytesPerBlock =
+    __NV_STD_MAX(
+      kThreadCount * sizeof(AccumulatorTile),
+      Epilogue::kWorkspaceBytesPerBlock);
+
   /// Block-striped reduction utility
-  using BlockStripedReduceT = BlockStripedReduce<kThreadCount, typename Mma::FragmentC, ElementAccumulator>;
+  using BlockStripedReduceT = BlockStripedReduce<kThreadCount, AccumulatorTile>;
 
 
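
Note: kWorkspaceBytesPerBlock reconciles the two consumers of the same per-block partials region: the atomic-reduction path stores one AccumulatorTile per thread, block-striped, while the parallel-reduction path needs Epilogue::kWorkspaceBytesPerBlock. Taking the maximum lets one allocation serve whichever mode runs. A host-side restatement (std::max standing in for __NV_STD_MAX; illustrative):

    #include <algorithm>
    #include <cstddef>

    // Bytes the atomic path writes: one accumulator tile per thread
    constexpr std::size_t atomic_path_bytes(std::size_t thread_count,
                                            std::size_t accum_tile_bytes)
    {
      return thread_count * accum_tile_bytes;
    }

    // One allocation sized for whichever reduction mode is selected
    constexpr std::size_t workspace_bytes_per_block(std::size_t thread_count,
                                                    std::size_t accum_tile_bytes,
                                                    std::size_t epilogue_bytes)
    {
      return std::max(atomic_path_bytes(thread_count, accum_tile_bytes),
                      epilogue_bytes);
    }
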
@@ -276,7 +275,7 @@ public:
     int64_t batch_stride_D;
 
     void *barrier_workspace;
-    ElementAccumulator *partials_workspace;
+    void *partials_workspace;
 
   protected:
 
@@ -305,11 +304,8 @@ public:
     /// Get the workspace size needed for intermediate partial sums
     size_t get_partials_workspace_size() const
     {
-      // For atomic reduction, each SK-block can share one accumulator tile. For parallel reduction,
-      // each SK-block can share up to two accumulator tiles.
-      size_t tile_bytes_accumulators = sizeof(ElementAccumulator) * kPeerAccumulators * 2;
       int sk_blocks = block_mapping.sk_regions * block_mapping.sk_blocks_per_region;
-      return cacheline_align_up(tile_bytes_accumulators * sk_blocks);
+      return cacheline_align_up(kWorkspaceBytesPerBlock * sk_blocks);
     }
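
Note: total partials storage is now just the per-block maximum times the number of SK blocks, rounded up to cache-line granularity, rather than a separately derived element count. A sketch of the same computation (cacheline_align_up_sketch and its 128-byte granularity are assumptions standing in for the kernel's helper):

    #include <cstddef>

    constexpr std::size_t cacheline_align_up_sketch(std::size_t bytes)
    {
      return (bytes + 127) / 128 * 128;  // assumed 128-byte granularity
    }

    constexpr std::size_t partials_workspace_size(std::size_t bytes_per_block,
                                                  int sk_regions,
                                                  int sk_blocks_per_region)
    {
      return cacheline_align_up_sketch(
          bytes_per_block * std::size_t(sk_regions) * sk_blocks_per_region);
    }
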
@@ -388,7 +384,7 @@ public:
       if (!workspace) {
         return Status::kErrorWorkspaceNull;
       }
-      partials_workspace = reinterpret_cast<ElementAccumulator*>(ptr);
+      partials_workspace = ptr;
       ptr += partials_workspace_bytes;
     }
 
@@ -400,7 +396,7 @@ public:
       if (!workspace) {
         return Status::kErrorWorkspaceNull;
       }
-      barrier_workspace = reinterpret_cast<ElementAccumulator*>(ptr);
+      barrier_workspace = ptr;
       ptr += barrier_workspace_bytes;
     }
 
@@ -630,14 +626,12 @@ public:
     return Status::kSuccess;
   }
 
-
   /// Determines whether the GEMM problem satisfies this kernel's
   /// alignment requirements
   static Status can_implement(Arguments const &args) {
     return can_implement(args.problem_size);
   }
 
-
 protected:
 
   //
@@ -762,14 +756,16 @@ protected:
 
   /// Share accumulators with peers
   CUTLASS_DEVICE
-  void share_accumulators(typename Mma::FragmentC const &accumulator_tile, int first_block_idx)
+  void share_accumulators(AccumulatorTile const &accumulator_tile, int first_block_idx)
   {
-    int block_tile_offset = first_block_idx * kPeerAccumulators;
+    AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
+
+    int accum_tile_offset = first_block_idx * kThreadCount;
 
     if (block_idx == first_block_idx)
     {
       // First peer initializes the workspace partials
-      BlockStripedReduceT::store(params.partials_workspace + block_tile_offset, accumulator_tile, thread_idx);
+      BlockStripedReduceT::store(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx);
     }
     else
     {
@@ -787,7 +783,7 @@ protected:
     }
 
     // Perform reduction in workspace
-    BlockStripedReduceT::reduce(params.partials_workspace + block_tile_offset, accumulator_tile, thread_idx);
+    BlockStripedReduceT::reduce(accum_tile_workspace + accum_tile_offset, accumulator_tile, thread_idx);
   }
 
   // Signal our arrival
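
Note: the protocol in these two hunks is that the first peer to own a region initializes it with a plain block-striped store and every later peer folds its partial in with atomic adds, with the Barrier calls ordering the store ahead of the adds. Reduced to scalar form (illustrative only; the inter-block ordering the kernel gets from Barrier is deliberately omitted):

    #include <cuda_runtime.h>

    __global__ void peer_reduce_sketch(float *workspace,
                                       const float *partials,
                                       int first_block_idx)
    {
      int i = threadIdx.x;
      float my_partial = partials[blockIdx.x * blockDim.x + i];

      if (blockIdx.x == first_block_idx) {
        workspace[i] = my_partial;             // analogue of BlockStripedReduceT::store
      } else {
        atomicAdd(workspace + i, my_partial);  // analogue of BlockStripedReduceT::reduce
      }
      // The real kernel serializes the store before any reduce via
      // Barrier::wait_eq_reset / arrive_inc on barrier_workspace.
    }
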
@@ -797,17 +793,19 @@ protected:
 
   /// Acquire accumulators from peers
   CUTLASS_DEVICE
-  void acquire_accumulators_atomic(
-    typename Mma::FragmentC &accumulator_tile,
+  void acquire_accumulators(
+    AccumulatorTile &accumulator_tile,
     int first_block_idx)
   {
+    AccumulatorTile *accum_tile_workspace = reinterpret_cast<AccumulatorTile *>(params.partials_workspace);
+
     // Wait for arrival
     int num_carry_in = block_idx - first_block_idx;
     Barrier::wait_eq_reset(params.barrier_workspace, thread_idx, first_block_idx, num_carry_in);
 
     // Load and add peer-partials accumulator tile to local accumulator tile
-    int block_tile_offset = first_block_idx * kPeerAccumulators;
-    BlockStripedReduceT::load_add(accumulator_tile, params.partials_workspace + block_tile_offset, thread_idx);
+    int accum_tile_offset = first_block_idx * kThreadCount;
+    BlockStripedReduceT::load_add(accumulator_tile, accum_tile_workspace + accum_tile_offset, thread_idx);
   }
 
 
@@ -815,7 +813,7 @@ protected:
   CUTLASS_DEVICE
   void do_epilogue(
     TileWorkDesc &tile_work,
-    typename Mma::FragmentC &accumulator_tile)
+    AccumulatorTile &accumulator_tile)
   {
     ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
     ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
@@ -867,8 +865,8 @@ protected:
     int peer_idx_begin, peer_idx_last, reduce_tile_idx, reduce_fragment_idx;
 
     // Reduce by sk-tile (every tile contributed to by one or more blocks)
-    reduce_tile_idx = reduce_idx / kAccumulatorFragments;
-    reduce_fragment_idx = reduce_idx % kAccumulatorFragments;
+    reduce_tile_idx = reduce_idx / Epilogue::kAccumulatorFragments;
+    reduce_fragment_idx = reduce_idx % Epilogue::kAccumulatorFragments;
 
     int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile;
     int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile - 1;
@@ -882,7 +880,7 @@ protected:
     Barrier::wait_eq_reset(
       params.barrier_workspace,
       thread_idx,
-      (reduce_tile_idx * kAccumulatorFragments) + reduce_fragment_idx,
+      (reduce_tile_idx * Epilogue::kAccumulatorFragments) + reduce_fragment_idx,
       num_peers);
 
     /// The location of this tile (in threadblock-tile coordinates) in the output matrix
@@ -946,7 +944,7 @@ protected:
     typename Mma::IteratorB iterator_B = init_iterator_B(tile_work);
 
     // Initialize accumulators
-    typename Mma::FragmentC accumulator_tile;
+    AccumulatorTile accumulator_tile;
     accumulator_tile.clear();
 
     // Perform this tile's range of multiply-accumulate (MAC) iterations
@@ -978,7 +976,7 @@ protected:
       if (!tile_work.tile_started())
       {
         // A "finishing" SK block must first aggregate its accumulator partial sums with those shared by peer threadblocks
-        acquire_accumulators_atomic(accumulator_tile, first_block_idx);
+        acquire_accumulators(accumulator_tile, first_block_idx);
       }
 
       do_epilogue(tile_work, accumulator_tile);
@@ -997,8 +995,8 @@ protected:
       Barrier::arrive_range_inc(
         params.barrier_workspace,
         thread_idx,
-        tile_work.tile_idx * kAccumulatorFragments,
-        kAccumulatorFragments);
+        tile_work.tile_idx * Epilogue::kAccumulatorFragments,
+        Epilogue::kAccumulatorFragments);
     }
   }
 
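
Note: the barrier bookkeeping gives each output tile Epilogue::kAccumulatorFragments flags: every producing block bumps the whole range for its tile, and each reducing block waits on exactly one flag. The index arithmetic, restated (illustrative):

    // Flag a reducing block waits on for one fragment of one tile
    constexpr int reduce_wait_flag(int tile_idx, int fragment_idx,
                                   int accumulator_fragments)
    {
      return tile_idx * accumulator_fragments + fragment_idx;
    }

    // First of the accumulator_fragments flags a producing block signals
    constexpr int producer_first_flag(int tile_idx, int accumulator_fragments)
    {
      return tile_idx * accumulator_fragments;
    }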