From 91b8de8d3273f73c1de97312b639084baf3269b0 Mon Sep 17 00:00:00 2001
From: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Date: Mon, 20 Feb 2023 11:03:16 -0500
Subject: [PATCH] streamk fix (#830)

Co-authored-by: Haicheng Wu
---
 .../gemm/kernel/gemm_universal_streamk.h      | 197 ++++++------------
 .../threadblock/threadblock_swizzle_streamk.h |  11 +-
 2 files changed, 66 insertions(+), 142 deletions(-)

diff --git a/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_streamk.h
index 27da66f3..7a722cd6 100644
--- a/include/cutlass/gemm/kernel/gemm_universal_streamk.h
+++ b/include/cutlass/gemm/kernel/gemm_universal_streamk.h
@@ -270,8 +270,6 @@ public:

   ThreadblockSwizzle block_mapping;

-  bool quick_dp;
-
   void *barrier_workspace;

   void *partials_workspace;
@@ -367,13 +365,6 @@ public:
       sm_occupancy,
       device_sms,
       avail_sms);
-
-    quick_dp =
-      (block_mapping.sk_waves == 0) &&
-      (mode == GemmUniversalMode::kGemm) &&
-      !block_mapping.cohort_raster &&
-      !EpilogueOutputOp(output_op).is_source_needed();
-
   }

@@ -874,7 +865,7 @@ protected:
       threadblock_item_begin);

     // Execute the epilogue operator to update the destination tensor.
-    epilogue.unified(
+    epilogue(
       EpilogueOutputOp(params.output_op),
       iterator_D,
       accumulator_tile,
@@ -961,13 +952,14 @@ protected:
     AccumulatorTile accumulator_tile;
     accumulator_tile.clear();

-    // Perform this tile's range of multiply-accumulate (MAC) iterations
+    // Initialize MMA abstraction
     Mma mma(
       shared_storage.main_loop,
       thread_idx,
       warp_idx,
       lane_idx);

+    // Perform this tile's range of multiply-accumulate (MAC) iterations
     mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);

     if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) ||
@@ -1020,29 +1012,27 @@ protected:
   void gemm()
   {
     // Initialize block's iteration range
-    int tile_idx, block_iter_begin, block_iters_remaining;
+    int tile_idx = 0;
+    int block_iter_begin = 0;
+    int block_iters_remaining = 0;
+
+    int block_idx = params.block_mapping.get_block_idx();

     int sk_padding_start_block_idx = params.block_mapping.sk_regions() * params.block_mapping.sk_blocks_per_region();
     int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms;
     int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks;
     int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks;

-    int block_idx = params.block_mapping.get_block_idx();
-
-    if (block_idx < sk_padding_start_block_idx)
-    {
-      // This is a SK block
-      int block_iter_end;
-      params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end);
-      block_iters_remaining = block_iter_end - block_iter_begin;
+    // Initialize tile work descriptor
+    TileWorkDesc tile_work;

-      tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
-    }
-    else if (block_idx < dp_start_block_idx)
-    {
-      // This is a filler block
-      return;
-    }
-    else if (block_idx < reduce_start_block_idx)
+    bool dp_block = (block_idx >= dp_start_block_idx) && (block_idx < reduce_start_block_idx);
+    bool sk_block = (block_idx < sk_padding_start_block_idx);
+    bool reduce_block = (block_idx >= reduce_start_block_idx) &&
+                        (block_idx < grid_padding_start_block_idx) &&
+                        (ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed);
+
+    if (dp_block)
     {
       // This is a DP block
       int dp_block_idx = block_idx - dp_start_block_idx;
@@ -1058,57 +1048,54 @@ protected:
         tile_idx += (params.block_mapping.dp_first_wave_tiles - 1) * params.block_mapping.avail_sms;
       }

-      block_iter_begin = 0;
       block_iters_remaining = params.block_mapping.iters_per_tile() * tile_allottment;
-    }
-    else if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed) &&
-             (block_idx < grid_padding_start_block_idx))
+
+      init_dp_tile_work(tile_work, tile_idx);
+
+      // DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1)
+      if ((tile_idx < params.block_mapping.sk_tiles) ||
+          (tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) ||
+          (tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n()))
+      {
+        return;
+      }
+    }
+    else if (sk_block)
     {
-      // This is a reduction threadblock
-      int reduce_block_idx = block_idx - reduce_start_block_idx;
-      separate_reduction(reduce_block_idx);
-      return;
+      // This is a SK block
+      int block_iter_end;
+      params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end);
+      block_iters_remaining = block_iter_end - block_iter_begin;
+
+      tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
+
+      init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
     }
     else
     {
-      // This is a filler block
+      if (reduce_block)
+      {
+        // This is a reduction threadblock
+        int reduce_block_idx = block_idx - reduce_start_block_idx;
+        separate_reduction(reduce_block_idx);
+      }
+
       return;
     }

+    // Perform this block's share of work for this tile
+    process_tile(
+      tile_work,
+      block_idx,
+      dp_start_block_idx,
+      block_iter_begin);
+
+    block_iters_remaining -= tile_work.k_iters_remaining;
+
     // Iteration-processing loop body
     CUTLASS_PRAGMA_NO_UNROLL
-    while (true)
+    while (block_iters_remaining != 0)
     {
-      // Initialize tile work descriptor
-      TileWorkDesc tile_work;
-
-      if (block_idx >= dp_start_block_idx)
-      {
-        init_dp_tile_work(tile_work, tile_idx);
-
-        // DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1)
-        if ((tile_idx < params.block_mapping.sk_tiles) ||
-            (tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape().m()) ||
-            (tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape().n()))
-        {
-          break;
-        }
-      }
-      else
-      {
-        init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
-      }
-
-      // Perform this block's share of work for this tile
-      process_tile(tile_work, block_idx, dp_start_block_idx, block_iter_begin);
-
-      // Update remaining work for this block
-      block_iters_remaining -= tile_work.k_iters_remaining;
-      if (block_iters_remaining == 0) {
-        // Done
-        break;
-      }
-
       // Continue to next tile
       __syncthreads();
@@ -1116,74 +1103,28 @@
       {
         // DP block consume their tiles at stride
         tile_idx += params.block_mapping.avail_sms;
+        init_dp_tile_work(tile_work, tile_idx);
       }
       else
       {
         // SK blocks consume their tiles in backwards order
         tile_idx--;
+        init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
       }
+
+      // Perform this block's share of work for this tile
+      process_tile(
+        tile_work,
+        block_idx,
+        dp_start_block_idx,
+        block_iter_begin);
+
+      block_iters_remaining -= tile_work.k_iters_remaining;
     }
   }

-
-  /// Executes one DP-only GEMM
-  CUTLASS_DEVICE
-  void gemm_dp()
-  {
-    int block_idx = blockIdx.x;
-    int tile_idx = block_idx;
-
-    TileWorkDesc tile_work;
-    tile_work.tile_idx = tile_idx;
-    tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile();
-    tile_work.k_iters_remaining = params.block_mapping.iters_per_tile();
-    tile_work.k_begin = 0;
-    tile_work.k_end = params.block_mapping.problem_size.k();
-    tile_work.tiled_coord = params.block_mapping.get_tile_offset_row_major(tile_work.tile_idx);
-
-    // Initialize input iterators
-    typename Mma::IteratorA iterator_A = init_iterator_A(tile_work, params.mode);
-    typename Mma::IteratorB iterator_B = init_iterator_B(tile_work, params.mode);
-
-    // Initialize accumulators
-    AccumulatorTile accumulator_tile;
-    accumulator_tile.clear();
-
-    // Perform this tile's range of multiply-accumulate (MAC) iterations
-    Mma mma(
-      shared_storage.main_loop,
-      thread_idx,
-      warp_idx,
-      lane_idx);
-
-    mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile);
-
-    ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
-
-    // Location of this tile in item-coords
-    MatrixCoord threadblock_item_begin(
-      tile_work.tiled_coord.m() * Mma::Shape::kM,
-      tile_work.tiled_coord.n() * Mma::Shape::kN
-    );
-
-    // Tile iterator writing to destination tensor.
-    typename Epilogue::OutputTileIterator iterator_D(
-      params.params_D,
-      ptr_D,
-      params.block_mapping.problem_size.mn(),
-      thread_idx,
-      threadblock_item_begin);
-
-    // Execute the epilogue operator to update the destination tensor.
-    epilogue(
-      EpilogueOutputOp(params.output_op),
-      iterator_D,
-      accumulator_tile);
-  }
-
-
 public:

   //
@@ -1224,16 +1165,6 @@
   CUTLASS_DEVICE
   void operator()()
   {
-#if (__CUDACC_VER_MAJOR__ > 10)
-    if (params.quick_dp)
-    {
-      // Simple (low-bootstrap latency) GEMM code path for data-parallel only. (kBatched and kArray
-      // modes will only be launched using a data-parallel configurations)
-      gemm_dp();
-      return;
-    }
-#endif
-
     // Generic SK code path
     gemm();

diff --git a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h
index b91046e5..239ced7a 100644
--- a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h
+++ b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h
@@ -637,13 +637,6 @@ struct ThreadblockSwizzleStreamK {
   // Device-side interface
   //

-  /// Proves to the compiler that val is warp-uniform
-  CUTLASS_DEVICE
-  int uniform(int val) const
-  {
-    return __shfl_sync(0xffffffff, val, 0);
-  }
-
   /// Obtains number of threadblocks per GEMM
   CUTLASS_DEVICE
   int device_num_blocks() const
@@ -656,7 +649,7 @@ struct ThreadblockSwizzleStreamK {
   int get_sk_tile_idx(int iter) const
   {
     int tile_idx = div_mod_iters_per_tile.div(iter);
-    return uniform(tile_idx);
+    return tile_idx;
   }

   /// Obtains the batch index
@@ -734,7 +727,7 @@ struct ThreadblockSwizzleStreamK {
       block_idx = (region * sk_blocks_per_region()) + block_in_region;
     }

-    return uniform(block_idx);
+    return block_idx;
   }