parent d8359c804b
commit 91b8de8d32
@@ -270,8 +270,6 @@ public:
 
     ThreadblockSwizzle block_mapping;
 
-    bool quick_dp;
-
     void *barrier_workspace;
     void *partials_workspace;
 
@@ -367,13 +365,6 @@ public:
         sm_occupancy,
         device_sms,
         avail_sms);
 
-      quick_dp =
-        (block_mapping.sk_waves == 0) &&
-        (mode == GemmUniversalMode::kGemm) &&
-        !block_mapping.cohort_raster &&
-        !EpilogueOutputOp(output_op).is_source_needed();
-
-
     }
 
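Annotation (not part of the diff): these two hunks delete the cached quick_dp flag and the initialization-time test that set it. Restated as a standalone predicate, under the assumption that the parameter names below are purely illustrative (the trailing comments show the expressions used in the removed lines):

    // Sketch only: parameter names are hypothetical.
    inline bool quick_dp_eligible(
        int  sk_waves,          // block_mapping.sk_waves
        bool mode_is_kGemm,     // mode == GemmUniversalMode::kGemm
        bool cohort_raster,     // block_mapping.cohort_raster
        bool source_needed)     // EpilogueOutputOp(output_op).is_source_needed()
    {
      return (sk_waves == 0)    // purely data-parallel schedule, no stream-K waves
          && mode_is_kGemm      // plain GEMM, not batched/array
          && !cohort_raster     // no cohort rasterization
          && !source_needed;    // epilogue never reads the source tensor C
    }

A launch passing all four tests used to take the low-latency gemm_dp() path that is removed further down.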
@@ -874,7 +865,7 @@ protected:
         threadblock_item_begin);
 
       // Execute the epilogue operator to update the destination tensor.
-      epilogue.unified(
+      epilogue(
         EpilogueOutputOp(params.output_op),
         iterator_D,
         accumulator_tile,
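Annotation (not part of the diff): unified() in CUTLASS's threadblock epilogue appears to be the single-code-path entry point that runs the same instruction stream whether or not the output op reads the source tensor, while operator() can skip source loading when is_source_needed() is false. Switching this call back to operator() is consistent with dropping the !is_source_needed() fast-path test above; that reading is inferred from these hunks, not stated in the commit.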
@@ -961,13 +952,14 @@ protected:
     AccumulatorTile accumulator_tile;
     accumulator_tile.clear();
 
-    // Perform this tile's range of multiply-accumulate (MAC) iterations
+    // Initialize MMA abstraction
     Mma mma(
       shared_storage.main_loop,
       thread_idx,
       warp_idx,
       lane_idx);
 
+    // Perform this tile's range of multiply-accumulate (MAC) iterations
     mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);
 
     if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) ||
@@ -1020,29 +1012,27 @@ protected:
   void gemm()
   {
     // Initialize block's iteration range
-    int tile_idx, block_iter_begin, block_iters_remaining;
-    int block_idx = params.block_mapping.get_block_idx();
+    int tile_idx = 0;
+    int block_iter_begin = 0;
+    int block_iters_remaining = 0;
 
     int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region();
     int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms;
     int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks;
     int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks;
 
-    if (block_idx < sk_padding_start_block_idx)
-    {
-      // This is a SK block
-      int block_iter_end;
-      params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end);
-      block_iters_remaining = block_iter_end - block_iter_begin;
-
-      tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
-    }
-    else if (block_idx < dp_start_block_idx)
-    {
-      // This is a filler block
-      return;
-    }
-    else if (block_idx < reduce_start_block_idx)
+    int block_idx = params.block_mapping.get_block_idx();
+
+    // Initialize tile work descriptor
+    TileWorkDesc tile_work;
+
+    bool dp_block = (block_idx >= dp_start_block_idx) && (block_idx < reduce_start_block_idx);
+    bool sk_block = (block_idx < sk_padding_start_block_idx);
+    bool reduce_block = (block_idx >= reduce_start_block_idx) &&
+            (block_idx < grid_padding_start_block_idx) &&
+            (ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed);
+
+    if (dp_block)
     {
       // This is a DP block
       int dp_block_idx = block_idx - dp_start_block_idx;
@@ -1058,57 +1048,54 @@ protected:
         tile_idx += (params.block_mapping.dp_first_wave_tiles - 1) * params.block_mapping.avail_sms;
       }
 
       block_iter_begin = 0;
       block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment;
-    }
-    else if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed) &&
-      (block_idx < grid_padding_start_block_idx))
+
+      init_dp_tile_work(tile_work, tile_idx);
+
+      // DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1)
+      if ((tile_idx < params.block_mapping.sk_tiles) ||
+          (tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) ||
+          (tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n()))
+      {
+        return;
+      }
+    }
+    else if (sk_block)
     {
-      // This is a reduction threadblock
-      int reduce_block_idx = block_idx - reduce_start_block_idx;
-      separate_reduction(reduce_block_idx);
-      return;
+      // This is a SK block
+      int block_iter_end;
+      params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end);
+      block_iters_remaining = block_iter_end - block_iter_begin;
+
+      tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
+
+      init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
     }
     else
     {
-      // This is a filler block
+      if (reduce_block)
+      {
+        // This is a reduction threadblock
+        int reduce_block_idx = block_idx - reduce_start_block_idx;
+        separate_reduction(reduce_block_idx);
+      }
+
       return;
     }
 
+    // Perform this block's share of work for this tile
+    process_tile(
+      tile_work,
+      block_idx,
+      dp_start_block_idx,
+      block_iter_begin);
+
+    block_iters_remaining -= tile_work.k_iters_remaining;
+
     // Iteration-processing loop body
     CUTLASS_PRAGMA_NO_UNROLL
-    while (true)
+    while (block_iters_remaining != 0)
     {
-      // Initialize tile work descriptor
-      TileWorkDesc tile_work;
-      if (block_idx >= dp_start_block_idx)
-      {
-        init_dp_tile_work(tile_work, tile_idx);
-
-        // DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1)
-        if ((tile_idx < params.block_mapping.sk_tiles) ||
-            (tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) ||
-            (tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n()))
-        {
-          break;
-        }
-      }
-      else
-      {
-        init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
-      }
-
-      // Perform this block's share of work for this tile
-      process_tile(tile_work, block_idx, dp_start_block_idx, block_iter_begin);
-
-      // Update remaining work for this block
-      block_iters_remaining -= tile_work.k_iters_remaining;
-      if (block_iters_remaining == 0) {
-        // Done
-        break;
-      }
-
       // Continue to next tile
       __syncthreads();
 
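Annotation (not part of the diff): the dispatch above partitions the linear block index into role ranges using the four *_start_block_idx boundaries. The mapping below is inferred from the comparisons in gemm(); BlockRole and classify_block are illustrative names, not part of CUTLASS.

    enum class BlockRole { kSK, kFiller, kDP, kReduce, kGridPadding };

    inline BlockRole classify_block(
        int block_idx,
        int sk_padding_start_block_idx,    // sk_regions() * sk_blocks_per_region()
        int dp_start_block_idx,            // sk_waves * avail_sms
        int reduce_start_block_idx,        // dp_start_block_idx + dp_blocks
        int grid_padding_start_block_idx,  // reduce_start_block_idx + reduction_blocks
        bool mixed_reduction)              // kReductionStrategy == kMixed
    {
      if (block_idx < sk_padding_start_block_idx) return BlockRole::kSK;      // stream-K work
      if (block_idx < dp_start_block_idx)         return BlockRole::kFiller;  // exits immediately
      if (block_idx < reduce_start_block_idx)     return BlockRole::kDP;      // data-parallel tiles
      if (mixed_reduction && block_idx < grid_padding_start_block_idx)
                                                  return BlockRole::kReduce;  // separate reduction
      return BlockRole::kGridPadding;                                         // exits immediately
    }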
@@ -1116,74 +1103,28 @@ protected:
       {
         // DP block consume their tiles at stride
         tile_idx += params.block_mapping.avail_sms;
+        init_dp_tile_work(tile_work, tile_idx);
       }
       else
       {
         // SK blocks consume their tiles in backwards order
         tile_idx--;
+        init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
       }
+
+      // Perform this block's share of work for this tile
+      process_tile(
+        tile_work,
+        block_idx,
+        dp_start_block_idx,
+        block_iter_begin);
+
+      block_iters_remaining -= tile_work.k_iters_remaining;
     }
   }
 
-
-  /// Executes one DP-only GEMM
-  CUTLASS_DEVICE
-  void gemm_dp()
-  {
-    int block_idx = blockIdx.x;
-    int tile_idx = block_idx;
-
-    TileWorkDesc tile_work;
-    tile_work.tile_idx = tile_idx;
-    tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile();
-    tile_work.k_iters_remaining = params.block_mapping.iters_per_tile();
-    tile_work.k_begin = 0;
-    tile_work.k_end = params.block_mapping.problem_size.k();
-    tile_work.tiled_coord = params.block_mapping.get_tile_offset_row_major(tile_work.tile_idx);
-
-    // Initialize input iterators
-    typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode);
-    typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode);
-
-    // Initialize accumulators
-    AccumulatorTile accumulator_tile;
-    accumulator_tile.clear();
-
-    // Perform this tile's range of multiply-accumulate (MAC) iterations
-    Mma mma(
-      shared_storage.main_loop,
-      thread_idx,
-      warp_idx,
-      lane_idx);
-
-    mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);
-
-    ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
-
-    // Location of this tile in item-coords
-    MatrixCoord threadblock_item_begin(
-      tile_work.tiled_coord.m() * Mma::Shape::kM,
-      tile_work.tiled_coord.n() * Mma::Shape::kN
-    );
-
-    // Tile iterator writing to destination tensor.
-    typename Epilogue::OutputTileIterator iterator_D(
-      params.params_D,
-      ptr_D,
-      params.block_mapping.problem_size.mn(),
-      thread_idx,
-      threadblock_item_begin);
-
-    // Execute the epilogue operator to update the destination tensor.
-    epilogue(
-      EpilogueOutputOp(params.output_op),
-      iterator_D,
-      accumulator_tile);
-  }
-
 
 public:
 
   //
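Annotation (not part of the diff): taken together, the gemm() hunks rotate the main loop. The old shape declared a fresh TileWorkDesc at the top of a while (true) body and exited via break; the new shape hoists the descriptor out of the loop, initializes and processes the first tile before entering it, and loops on while (block_iters_remaining != 0), where each pass only advances the tile index (DP blocks stride forward by avail_sms, SK blocks walk backward), re-initializes the shared descriptor, and processes it. With the quick_dp dispatch gone (next hunk), the standalone gemm_dp() path is deleted and DP blocks run through this same loop. A schematic of the new control flow, with the real bookkeeping stubbed out:

    // Sketch only: init_next_tile / process_tile stand in for the real
    // init_dp_tile_work / init_sk_tile_work / process_tile calls, and
    // process_tile here returns the k-iterations it consumed.
    template <typename InitNextFn, typename ProcessFn>
    __device__ void rotated_loop(
        int block_iters_remaining, InitNextFn init_next_tile, ProcessFn process_tile)
    {
      // First tile is set up by the dispatch code and processed before the loop.
      block_iters_remaining -= process_tile();

      while (block_iters_remaining != 0)
      {
        __syncthreads();    // continue to next tile (barrier before reusing shared state)
        init_next_tile();   // DP: tile_idx += avail_sms;  SK: tile_idx--
        block_iters_remaining -= process_tile();
      }
    }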
@@ -1224,16 +1165,6 @@ public:
   CUTLASS_DEVICE
   void operator()()
   {
-#if (__CUDACC_VER_MAJOR__ > 10)
-    if (params.quick_dp)
-    {
-      // Simple (low-bootstrap latency) GEMM code path for data-parallel only. (kBatched and kArray
-      // modes will only be launched using a data-parallel configurations)
-      gemm_dp();
-      return;
-    }
-#endif
-
     // Generic SK code path
     gemm();
 
@@ -637,13 +637,6 @@ struct ThreadblockSwizzleStreamK {
   // Device-side interface
   //
 
-  /// Proves to the compiler that val is warp-uniform
-  CUTLASS_DEVICE
-  int uniform(int val) const
-  {
-    return __shfl_sync(0xffffffff, val, 0);
-  }
-
   /// Obtains number of threadblocks per GEMM
   CUTLASS_DEVICE
   int device_num_blocks() const
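Annotation (not part of the diff): these last three hunks sit in ThreadblockSwizzleStreamK rather than the kernel class, per their @@ context. The deleted uniform() helper used the standard warp-shuffle broadcast idiom: reading every lane's copy of a value from lane 0 makes the result bitwise identical across the warp, which lets the compiler treat downstream uses as warp-uniform. A minimal standalone sketch of the idiom (broadcast_from_lane0 is an illustrative name):

    // Broadcast lane 0's copy of val to all lanes of a full warp.
    // 0xffffffff is the participation mask; the final 0 selects the source lane.
    __device__ int broadcast_from_lane0(int val)
    {
      return __shfl_sync(0xffffffff, val, 0);
    }

The commit removes the helper, and the two call sites below (get_sk_tile_idx() and, apparently, the block-index getter) now return their raw values instead.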
@@ -656,7 +649,7 @@ struct ThreadblockSwizzleStreamK {
   int get_sk_tile_idx(int iter) const
   {
     int tile_idx = div_mod_iters_per_tile.div(iter);
-    return uniform(tile_idx);
+    return tile_idx;
   }
 
   /// Obtains the batch index
@@ -734,7 +727,7 @@ struct ThreadblockSwizzleStreamK {
 
       block_idx = (region * sk_blocks_per_region()) + block_in_region;
     }
 
-    return uniform(block_idx);
+    return block_idx;
   }