parent
d8359c804b
commit
91b8de8d32
@ -270,8 +270,6 @@ public:
|
|||||||
|
|
||||||
ThreadblockSwizzle block_mapping;
|
ThreadblockSwizzle block_mapping;
|
||||||
|
|
||||||
bool quick_dp;
|
|
||||||
|
|
||||||
void *barrier_workspace;
|
void *barrier_workspace;
|
||||||
void *partials_workspace;
|
void *partials_workspace;
|
||||||
|
|
||||||
@ -367,13 +365,6 @@ public:
|
|||||||
sm_occupancy,
|
sm_occupancy,
|
||||||
device_sms,
|
device_sms,
|
||||||
avail_sms);
|
avail_sms);
|
||||||
|
|
||||||
quick_dp =
|
|
||||||
(block_mapping.sk_waves == 0) &&
|
|
||||||
(mode == GemmUniversalMode::kGemm) &&
|
|
||||||
!block_mapping.cohort_raster &&
|
|
||||||
!EpilogueOutputOp(output_op).is_source_needed();
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -874,7 +865,7 @@ protected:
|
|||||||
threadblock_item_begin);
|
threadblock_item_begin);
|
||||||
|
|
||||||
// Execute the epilogue operator to update the destination tensor.
|
// Execute the epilogue operator to update the destination tensor.
|
||||||
epilogue.unified(
|
epilogue(
|
||||||
EpilogueOutputOp(params.output_op),
|
EpilogueOutputOp(params.output_op),
|
||||||
iterator_D,
|
iterator_D,
|
||||||
accumulator_tile,
|
accumulator_tile,
|
||||||
@ -961,13 +952,14 @@ protected:
|
|||||||
AccumulatorTile accumulator_tile;
|
AccumulatorTile accumulator_tile;
|
||||||
accumulator_tile.clear();
|
accumulator_tile.clear();
|
||||||
|
|
||||||
// Perform this tile's range of multiply-accumulate (MAC) iterations
|
// Initialize MMA abstraction
|
||||||
Mma mma(
|
Mma mma(
|
||||||
shared_storage.main_loop,
|
shared_storage.main_loop,
|
||||||
thread_idx,
|
thread_idx,
|
||||||
warp_idx,
|
warp_idx,
|
||||||
lane_idx);
|
lane_idx);
|
||||||
|
|
||||||
|
// Perform this tile's range of multiply-accumulate (MAC) iterations
|
||||||
mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);
|
mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);
|
||||||
|
|
||||||
if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) ||
|
if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) ||
|
||||||
@ -1020,29 +1012,27 @@ protected:
|
|||||||
void gemm()
|
void gemm()
|
||||||
{
|
{
|
||||||
// Initialize block's iteration range
|
// Initialize block's iteration range
|
||||||
int tile_idx, block_iter_begin, block_iters_remaining;
|
int tile_idx = 0;
|
||||||
|
int block_iter_begin = 0;
|
||||||
|
int block_iters_remaining = 0;
|
||||||
|
|
||||||
|
int block_idx = params.block_mapping.get_block_idx();
|
||||||
|
|
||||||
int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region();
|
int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region();
|
||||||
int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms;
|
int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms;
|
||||||
int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks;
|
int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks;
|
||||||
int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks;
|
int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks;
|
||||||
|
|
||||||
int block_idx = params.block_mapping.get_block_idx();
|
// Initialize tile work descriptor
|
||||||
if (block_idx < sk_padding_start_block_idx)
|
TileWorkDesc tile_work;
|
||||||
{
|
|
||||||
// This is a SK block
|
|
||||||
int block_iter_end;
|
|
||||||
params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end);
|
|
||||||
block_iters_remaining = block_iter_end - block_iter_begin;
|
|
||||||
|
|
||||||
tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
|
bool dp_block = (block_idx >= dp_start_block_idx) && (block_idx < reduce_start_block_idx);
|
||||||
}
|
bool sk_block = (block_idx < sk_padding_start_block_idx);
|
||||||
else if (block_idx < dp_start_block_idx)
|
bool reduce_block = (block_idx >= reduce_start_block_idx) &&
|
||||||
{
|
(block_idx < grid_padding_start_block_idx) &&
|
||||||
// This is a filler block
|
(ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed);
|
||||||
return;
|
|
||||||
}
|
if (dp_block)
|
||||||
else if (block_idx < reduce_start_block_idx)
|
|
||||||
{
|
{
|
||||||
// This is a DP block
|
// This is a DP block
|
||||||
int dp_block_idx = block_idx - dp_start_block_idx;
|
int dp_block_idx = block_idx - dp_start_block_idx;
|
||||||
@ -1058,32 +1048,8 @@ protected:
|
|||||||
tile_idx += (params.block_mapping.dp_first_wave_tiles - 1) * params.block_mapping.avail_sms;
|
tile_idx += (params.block_mapping.dp_first_wave_tiles - 1) * params.block_mapping.avail_sms;
|
||||||
}
|
}
|
||||||
|
|
||||||
block_iter_begin = 0;
|
|
||||||
block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment;
|
block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment;
|
||||||
}
|
|
||||||
|
|
||||||
else if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed) &&
|
|
||||||
(block_idx < grid_padding_start_block_idx))
|
|
||||||
{
|
|
||||||
// This is a reduction threadblock
|
|
||||||
int reduce_block_idx = block_idx - reduce_start_block_idx;
|
|
||||||
separate_reduction(reduce_block_idx);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// This is a filler block
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iteration-processing loop body
|
|
||||||
CUTLASS_PRAGMA_NO_UNROLL
|
|
||||||
while (true)
|
|
||||||
{
|
|
||||||
// Initialize tile work descriptor
|
|
||||||
TileWorkDesc tile_work;
|
|
||||||
if (block_idx >= dp_start_block_idx)
|
|
||||||
{
|
|
||||||
init_dp_tile_work(tile_work, tile_idx);
|
init_dp_tile_work(tile_work, tile_idx);
|
||||||
|
|
||||||
// DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1)
|
// DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1)
|
||||||
@ -1091,24 +1057,45 @@ protected:
|
|||||||
(tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) ||
|
(tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) ||
|
||||||
(tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n()))
|
(tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n()))
|
||||||
{
|
{
|
||||||
break;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else if (sk_block)
|
||||||
|
{
|
||||||
|
// This is a SK block
|
||||||
|
int block_iter_end;
|
||||||
|
params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end);
|
||||||
|
block_iters_remaining = block_iter_end - block_iter_begin;
|
||||||
|
|
||||||
|
tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
|
||||||
|
|
||||||
|
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
|
||||||
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
|
if (reduce_block)
|
||||||
|
{
|
||||||
|
// This is a reduction threadblock
|
||||||
|
int reduce_block_idx = block_idx - reduce_start_block_idx;
|
||||||
|
separate_reduction(reduce_block_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform this block's share of work for this tile
|
// Perform this block's share of work for this tile
|
||||||
process_tile(tile_work, block_idx, dp_start_block_idx, block_iter_begin);
|
process_tile(
|
||||||
|
tile_work,
|
||||||
|
block_idx,
|
||||||
|
dp_start_block_idx,
|
||||||
|
block_iter_begin);
|
||||||
|
|
||||||
// Update remaining work for this block
|
|
||||||
block_iters_remaining -= tile_work.k_iters_remaining;
|
block_iters_remaining -= tile_work.k_iters_remaining;
|
||||||
if (block_iters_remaining == 0) {
|
|
||||||
// Done
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Iteration-processing loop body
|
||||||
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
|
while (block_iters_remaining != 0)
|
||||||
|
{
|
||||||
// Continue to next tile
|
// Continue to next tile
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
@ -1116,74 +1103,28 @@ protected:
|
|||||||
{
|
{
|
||||||
// DP block consume their tiles at stride
|
// DP block consume their tiles at stride
|
||||||
tile_idx += params.block_mapping.avail_sms;
|
tile_idx += params.block_mapping.avail_sms;
|
||||||
|
init_dp_tile_work(tile_work, tile_idx);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
// SK blocks consume their tiles in backwards order
|
// SK blocks consume their tiles in backwards order
|
||||||
tile_idx--;
|
tile_idx--;
|
||||||
|
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Perform this block's share of work for this tile
|
||||||
|
process_tile(
|
||||||
|
tile_work,
|
||||||
|
block_idx,
|
||||||
|
dp_start_block_idx,
|
||||||
|
block_iter_begin);
|
||||||
|
|
||||||
|
block_iters_remaining -= tile_work.k_iters_remaining;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Executes one DP-only GEMM
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
void gemm_dp()
|
|
||||||
{
|
|
||||||
int block_idx = blockIdx.x;
|
|
||||||
int tile_idx = block_idx;
|
|
||||||
|
|
||||||
TileWorkDesc tile_work;
|
|
||||||
tile_work.tile_idx = tile_idx;
|
|
||||||
tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile();
|
|
||||||
tile_work.k_iters_remaining = params.block_mapping.iters_per_tile();
|
|
||||||
tile_work.k_begin = 0;
|
|
||||||
tile_work.k_end = params.block_mapping.problem_size.k();
|
|
||||||
tile_work.tiled_coord = params.block_mapping.get_tile_offset_row_major(tile_work.tile_idx);
|
|
||||||
|
|
||||||
// Initialize input iterators
|
|
||||||
typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode);
|
|
||||||
typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode);
|
|
||||||
|
|
||||||
// Initialize accumulators
|
|
||||||
AccumulatorTile accumulator_tile;
|
|
||||||
accumulator_tile.clear();
|
|
||||||
|
|
||||||
// Perform this tile's range of multiply-accumulate (MAC) iterations
|
|
||||||
Mma mma(
|
|
||||||
shared_storage.main_loop,
|
|
||||||
thread_idx,
|
|
||||||
warp_idx,
|
|
||||||
lane_idx);
|
|
||||||
|
|
||||||
mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);
|
|
||||||
|
|
||||||
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
|
|
||||||
|
|
||||||
// Location of this tile in item-coords
|
|
||||||
MatrixCoord threadblock_item_begin(
|
|
||||||
tile_work.tiled_coord.m() * Mma::Shape::kM,
|
|
||||||
tile_work.tiled_coord.n() * Mma::Shape::kN
|
|
||||||
);
|
|
||||||
|
|
||||||
// Tile iterator writing to destination tensor.
|
|
||||||
typename Epilogue::OutputTileIterator iterator_D(
|
|
||||||
params.params_D,
|
|
||||||
ptr_D,
|
|
||||||
params.block_mapping.problem_size.mn(),
|
|
||||||
thread_idx,
|
|
||||||
threadblock_item_begin);
|
|
||||||
|
|
||||||
// Execute the epilogue operator to update the destination tensor.
|
|
||||||
epilogue(
|
|
||||||
EpilogueOutputOp(params.output_op),
|
|
||||||
iterator_D,
|
|
||||||
accumulator_tile);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -1224,16 +1165,6 @@ public:
|
|||||||
CUTLASS_DEVICE
|
CUTLASS_DEVICE
|
||||||
void operator()()
|
void operator()()
|
||||||
{
|
{
|
||||||
#if (__CUDACC_VER_MAJOR__ > 10)
|
|
||||||
if (params.quick_dp)
|
|
||||||
{
|
|
||||||
// Simple (low-bootstrap latency) GEMM code path for data-parallel only. (kBatched and kArray
|
|
||||||
// modes will only be launched using a data-parallel configurations)
|
|
||||||
gemm_dp();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Generic SK code path
|
// Generic SK code path
|
||||||
gemm();
|
gemm();
|
||||||
|
|
||||||
|
|||||||
@ -637,13 +637,6 @@ struct ThreadblockSwizzleStreamK {
|
|||||||
// Device-side interface
|
// Device-side interface
|
||||||
//
|
//
|
||||||
|
|
||||||
/// Proves to the compiler that val is warp-uniform
|
|
||||||
CUTLASS_DEVICE
|
|
||||||
int uniform(int val) const
|
|
||||||
{
|
|
||||||
return __shfl_sync(0xffffffff, val, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Obtains number of threadblocks per GEMM
|
/// Obtains number of threadblocks per GEMM
|
||||||
CUTLASS_DEVICE
|
CUTLASS_DEVICE
|
||||||
int device_num_blocks() const
|
int device_num_blocks() const
|
||||||
@ -656,7 +649,7 @@ struct ThreadblockSwizzleStreamK {
|
|||||||
int get_sk_tile_idx(int iter) const
|
int get_sk_tile_idx(int iter) const
|
||||||
{
|
{
|
||||||
int tile_idx = div_mod_iters_per_tile.div(iter);
|
int tile_idx = div_mod_iters_per_tile.div(iter);
|
||||||
return uniform(tile_idx);
|
return tile_idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Obtains the batch index
|
/// Obtains the batch index
|
||||||
@ -734,7 +727,7 @@ struct ThreadblockSwizzleStreamK {
|
|||||||
block_idx = (region * sk_blocks_per_region()) + block_in_region;
|
block_idx = (region * sk_blocks_per_region()) + block_in_region;
|
||||||
}
|
}
|
||||||
|
|
||||||
return uniform(block_idx);
|
return block_idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user