streamk fix (#830)

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Haicheng Wu 2023-02-20 11:03:16 -05:00 committed by GitHub
parent d8359c804b
commit 91b8de8d32
2 changed files with 66 additions and 142 deletions

View File

@@ -270,8 +270,6 @@ public:
ThreadblockSwizzle block_mapping;
bool quick_dp;
void *barrier_workspace;
void *partials_workspace;
@@ -367,13 +365,6 @@ public:
sm_occupancy,
device_sms,
avail_sms);
quick_dp =
(block_mapping.sk_waves == 0) &&
(mode == GemmUniversalMode::kGemm) &&
!block_mapping.cohort_raster &&
!EpilogueOutputOp(output_op).is_source_needed();
}
@@ -874,7 +865,7 @@ protected:
threadblock_item_begin);
// Execute the epilogue operator to update the destination tensor.
epilogue.unified(
epilogue(
EpilogueOutputOp(params.output_op),
iterator_D,
accumulator_tile,
@@ -961,13 +952,14 @@ protected:
AccumulatorTile accumulator_tile;
accumulator_tile.clear();
// Perform this tile's range of multiply-accumulate (MAC) iterations
// Initialize MMA abstraction
Mma mma(
shared_storage.main_loop,
thread_idx,
warp_idx,
lane_idx);
// Perform this tile's range of multiply-accumulate (MAC) iterations
mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);
if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) ||
@@ -1020,29 +1012,27 @@ protected:
void gemm()
{
// Initialize block's iteration range
int tile_idx, block_iter_begin, block_iters_remaining;
int tile_idx = 0;
int block_iter_begin = 0;
int block_iters_remaining = 0;
int block_idx = params.block_mapping.get_block_idx();
int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region();
int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms;
int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks;
int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks;
int block_idx = params.block_mapping.get_block_idx();
if (block_idx < sk_padding_start_block_idx)
{
// This is a SK block
int block_iter_end;
params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end);
block_iters_remaining = block_iter_end - block_iter_begin;
// Initialize tile work descriptor
TileWorkDesc tile_work;
tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
}
else if (block_idx < dp_start_block_idx)
{
// This is a filler block
return;
}
else if (block_idx < reduce_start_block_idx)
bool dp_block = (block_idx >= dp_start_block_idx) && (block_idx < reduce_start_block_idx);
bool sk_block = (block_idx < sk_padding_start_block_idx);
bool reduce_block = (block_idx >= reduce_start_block_idx) &&
(block_idx < grid_padding_start_block_idx) &&
(ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed);
if (dp_block)
{
// This is a DP block
int dp_block_idx = block_idx - dp_start_block_idx;
@@ -1058,32 +1048,8 @@ protected:
tile_idx += (params.block_mapping.dp_first_wave_tiles - 1) * params.block_mapping.avail_sms;
}
block_iter_begin = 0;
block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment;
}
else if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed) &&
(block_idx < grid_padding_start_block_idx))
{
// This is a reduction threadblock
int reduce_block_idx = block_idx - reduce_start_block_idx;
separate_reduction(reduce_block_idx);
return;
}
else
{
// This is a filler block
return;
}
// Iteration-processing loop body
CUTLASS_PRAGMA_NO_UNROLL
while (true)
{
// Initialize tile work descriptor
TileWorkDesc tile_work;
if (block_idx >= dp_start_block_idx)
{
init_dp_tile_work(tile_work, tile_idx);
// DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1)
@@ -1091,24 +1057,45 @@ protected:
(tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) ||
(tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n()))
{
break;
return;
}
}
else if (sk_block)
{
// This is a SK block
int block_iter_end;
params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end);
block_iters_remaining = block_iter_end - block_iter_begin;
tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
}
else
{
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
if (reduce_block)
{
// This is a reduction threadblock
int reduce_block_idx = block_idx - reduce_start_block_idx;
separate_reduction(reduce_block_idx);
}
return;
}
// Perform this block's share of work for this tile
process_tile(tile_work, block_idx, dp_start_block_idx, block_iter_begin);
process_tile(
tile_work,
block_idx,
dp_start_block_idx,
block_iter_begin);
// Update remaining work for this block
block_iters_remaining -= tile_work.k_iters_remaining;
if (block_iters_remaining == 0) {
// Done
break;
}
// Iteration-processing loop body
CUTLASS_PRAGMA_NO_UNROLL
while (block_iters_remaining != 0)
{
// Continue to next tile
__syncthreads();
@@ -1116,74 +1103,28 @@ protected:
{
// DP blocks consume their tiles at stride
tile_idx += params.block_mapping.avail_sms;
init_dp_tile_work(tile_work, tile_idx);
}
else
{
// SK blocks consume their tiles in backwards order
tile_idx--;
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
}
// Perform this block's share of work for this tile
process_tile(
tile_work,
block_idx,
dp_start_block_idx,
block_iter_begin);
block_iters_remaining -= tile_work.k_iters_remaining;
}
}
/// Executes one DP-only GEMM
CUTLASS_DEVICE
void gemm_dp()
{
int block_idx = blockIdx.x;
int tile_idx = block_idx;
TileWorkDesc tile_work;
tile_work.tile_idx = tile_idx;
tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile();
tile_work.k_iters_remaining = params.block_mapping.iters_per_tile();
tile_work.k_begin = 0;
tile_work.k_end = params.block_mapping.problem_size.k();
tile_work.tiled_coord = params.block_mapping.get_tile_offset_row_major(tile_work.tile_idx);
// Initialize input iterators
typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode);
typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode);
// Initialize accumulators
AccumulatorTile accumulator_tile;
accumulator_tile.clear();
// Perform this tile's range of multiply-accumulate (MAC) iterations
Mma mma(
shared_storage.main_loop,
thread_idx,
warp_idx,
lane_idx);
mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
// Location of this tile in item-coords
MatrixCoord threadblock_item_begin(
tile_work.tiled_coord.m() * Mma::Shape::kM,
tile_work.tiled_coord.n() * Mma::Shape::kN
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
ptr_D,
params.block_mapping.problem_size.mn(),
thread_idx,
threadblock_item_begin);
// Execute the epilogue operator to update the destination tensor.
epilogue(
EpilogueOutputOp(params.output_op),
iterator_D,
accumulator_tile);
}
public:
//
@@ -1224,16 +1165,6 @@ public:
CUTLASS_DEVICE
void operator()()
{
#if (__CUDACC_VER_MAJOR__ > 10)
if (params.quick_dp)
{
// Simple (low-bootstrap latency) GEMM code path for data-parallel only. (kBatched and kArray
// modes will only be launched using a data-parallel configuration)
gemm_dp();
return;
}
#endif
// Generic SK code path
gemm();
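
A note on the reworked gemm() in the file above: it appears to classify each threadblock up front via the dp_block / sk_block / reduce_block predicates instead of a chain of early-return branches. The standalone sketch below only illustrates that classification; the boundary indices are made-up values and the simplified reduce_block test drops the ThreadblockSwizzle::kReductionStrategy check, so nothing here is CUTLASS API.

    #include <cstdio>

    // Standalone illustration of the block-role classification used by the
    // reworked gemm(). The boundary indices below are hypothetical values,
    // not anything computed by CUTLASS.
    int main() {
      int sk_padding_start_block_idx   = 8;   // end of the stream-K blocks
      int dp_start_block_idx           = 12;  // first data-parallel block
      int reduce_start_block_idx       = 20;  // first reduction block
      int grid_padding_start_block_idx = 22;  // first grid-padding block

      for (int block_idx = 0; block_idx < 24; ++block_idx) {
        bool dp_block     = (block_idx >= dp_start_block_idx) && (block_idx < reduce_start_block_idx);
        bool sk_block     = (block_idx < sk_padding_start_block_idx);
        bool reduce_block = (block_idx >= reduce_start_block_idx) && (block_idx < grid_padding_start_block_idx);

        // Blocks that match none of the predicates (SK padding, grid padding) simply exit.
        const char *role = dp_block ? "DP" : sk_block ? "SK" : reduce_block ? "reduce" : "filler";
        std::printf("block %2d -> %s\n", block_idx, role);
      }
      return 0;
    }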

View File

@@ -637,13 +637,6 @@ struct ThreadblockSwizzleStreamK {
// Device-side interface
//
/// Proves to the compiler that val is warp-uniform
CUTLASS_DEVICE
int uniform(int val) const
{
return __shfl_sync(0xffffffff, val, 0);
}
/// Obtains number of threadblocks per GEMM
CUTLASS_DEVICE
int device_num_blocks() const
@@ -656,7 +649,7 @@ struct ThreadblockSwizzleStreamK {
int get_sk_tile_idx(int iter) const
{
int tile_idx = div_mod_iters_per_tile.div(iter);
return uniform(tile_idx);
return tile_idx;
}
/// Obtains the batch index
@@ -734,7 +727,7 @@ struct ThreadblockSwizzleStreamK {
block_idx = (region * sk_blocks_per_region()) + block_in_region;
}
return uniform(block_idx);
return block_idx;
}
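
For context on the change in the second file: the deleted uniform() helper broadcast lane 0's value across the warp with __shfl_sync so the returned index could be treated as warp-uniform, and this commit drops those broadcasts from get_sk_tile_idx() and the block-index computation above. Below is a minimal, hypothetical CUDA sketch of that broadcast pattern, not CUTLASS code; warp_uniform and demo_kernel are names invented for the example.

    #include <cstdio>
    #include <cuda_runtime.h>

    // Broadcast lane 0's value to every lane of a full warp, so the result is
    // provably warp-uniform. This mirrors the helper removed in the diff above.
    __device__ int warp_uniform(int val) {
      return __shfl_sync(0xffffffff, val, 0);
    }

    __global__ void demo_kernel(int *out) {
      int lane = threadIdx.x % 32;
      int per_lane_value = lane * 10;                  // deliberately non-uniform input
      out[threadIdx.x] = warp_uniform(per_lane_value); // every lane now sees lane 0's value (0)
    }

    int main() {
      int *d_out = nullptr;
      cudaMalloc(&d_out, 32 * sizeof(int));
      demo_kernel<<<1, 32>>>(d_out);
      int h_out[32];
      cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
      std::printf("lane 5 sees %d (expected 0)\n", h_out[5]);
      cudaFree(d_out);
      return 0;
    }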