/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements the Stream-K threadblock mapping of blockIdx to GEMM problems.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/conv3d_problem_size.h"
#include "cutlass/gemm/threadblock/index_remat.h"
#include <iostream>
#include "cutlass/core_io.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Threadblock mapping control for GEMMs
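///
/// Output tiles are partitioned into a data-parallel (DP) portion, in which each CTA
/// produces whole tiles, and a Stream-K (SK) portion, in which a fixed population of
/// CTAs cooperatively covers the remaining tiles' MAC-loop iterations and combines
/// partial results according to the selected reduction strategy.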
struct ThreadblockSwizzleStreamK {
/// Advertise StreamkFeature
using StreamkFeature = void;
/// Kernel traits
template <typename GemmKernel>
struct KernelTraits {};
/// Reduction strategy
enum ReductionStrategy
{
kNone, // Data-parallel strategy (no seams, fixup, etc.)
kAtomic, // Non-deterministic reduction of SK-block partials using atomic aggregation in L2
kMixed, // Deterministic reduction of SK-block partials employing either:
// (a) A separate wave of reduction thread blocks (for scenarios with lots of
// SK-blocks per SK-tile)
// (b) Turnstile-ordered atomic aggregation in L2 (for scenarios with few
// SK-blocks per SK-tile)
};
static ReductionStrategy const kReductionStrategy = kMixed;
//
// Heuristics
//
/// Data-parallel wave-quantization efficiency threshold (above which we go data-parallel)
static float constexpr kDpEfficiencyThreshold = 0.92f;
/// Minimum number of MAC-iterations per streamk block
static int const kMinItersPerSkBlock = 2;
/// Height in CTAs of a grid rasterization cohort
static int const kCohortCtasM = 8;
/// Width in CTAs of a grid rasterization cohort
static int const kCohortCtasN = 4;
/// Number of CTAs per cohort
static int const kCtasPerCohort = kCohortCtasN * kCohortCtasM;
/// Cost-equivalent number of SM-iterations for fixup I/O
static int const kFixupStartupIterEquiv = 10;
static int const kFixupPeerIterEquiv = 3;
//
// Member state
//
/// The 3D value-extents of the GEMM computation volume (m,n,k)
GemmCoord problem_size;
/// The 2D tile-extents of the output matrix (m,n)
GemmCoord tiled_shape;
/// Number of iterations per output tile
int iters_per_tile;
/// Number of reduction blocks in the grid
int reduction_blocks;
int dp_blocks; /// Number of data-parallel thread blocks in the grid
int dp_first_wave_tiles; /// Number of output tiles each CTA in the first DP wave will produce
int sk_tiles; /// Number of output tiles covered by SK blocks
int sk_regions; /// Number of regions the SK iteration space is partitioned into (sk_tiles under a Split-K-style decomposition, otherwise 1)
int sk_blocks_per_region; /// Number of SK blocks per region
int sk_big_blocks_per_region; /// Number of "big" SK blocks (one extra iteration) per region
int sk_iters_per_region; /// Number of MAC-loop iterations per region
int sk_iters_per_normal_block; /// Number of iterations for normal SK-blocks
int sk_waves; /// Number of SK waves in the grid
/// CTA occupancy per SM
int sm_occupancy;
/// Number of SMs for dispatch heuristics to load-balance using Stream-K CTAs (wave size)
int avail_sms;
/// Whether to perform cohort CTA rasterization
bool cohort_raster;
/// Div/mod accelerators
struct
{
FastDivmod tiled_shape_m;
FastDivmod tiled_shape_n;
FastDivmod tiled_cohort_shape_n;
FastDivmod iters_per_tile;
FastDivmod sk_iters_per_normal_block;
FastDivmod sk_iters_per_big_block;
FastDivmod sk_iters_per_region;
FastDivmod sk_blocks_per_region;
} div_mod;
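// (FastDivmod precomputes multiplier/shift constants here on the host so that the
// device-side index mapping can replace integer division/modulo with multiply-and-shift.)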
//
// Host+device interface
//
/// Constructor
CUTLASS_HOST_DEVICE
ThreadblockSwizzleStreamK() {}
//
// Host-side interface
//
/// Debug print
void Print()
{
#ifndef __CUDA_ARCH__
int tiles = tiled_shape.m() * tiled_shape.n();
std::cout <<
"problem_size: (" << problem_size.m() << "," << problem_size.n() << ")" <<
", reduction_blocks: " << reduction_blocks <<
", dp_blocks: " << dp_blocks <<
", sk_blocks_per_region: " << sk_blocks_per_region <<
", sk_regions: " << sk_regions <<
", sk_waves: " << sk_waves <<
", sk_iters_per_normal_block: " << sk_iters_per_normal_block <<
", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
", dp_first_wave_tiles: " << dp_first_wave_tiles <<
", tiled_shape: (" << tiled_shape.m() << "," << tiled_shape.n() << ")" <<
", tiles: " << tiles <<
", iters_per_tile: " << iters_per_tile <<
", dp_tiles: " << tiles - sk_tiles <<
", sk_tiles: " << sk_tiles <<
", avail_sms: " << avail_sms <<
", sm_occupancy: " << sm_occupancy <<
", cohort_raster: " << cohort_raster <<
", num_blocks: " << get_num_blocks() <<
"\n\n";
#endif
}
// Compute sk_blocks to dispatch for a given number of sk_tiles
static void get_sk_blocks(
int &sk_blocks, /// [out]
int &savings_iters, /// [out]
int sk_tiles,
int iters_per_tile,
int avail_sms,
int max_sk_occupancy,
bool allow_partial_wave)
{
savings_iters = INT_MIN;
sk_blocks = 0;
if (sk_tiles == 0) {
return;
}
int sk_iters = sk_tiles * iters_per_tile;
int dp_equiv_waves = (sk_tiles + avail_sms - 1) / avail_sms;
int dp_equiv_iters = iters_per_tile * dp_equiv_waves;
int min_sk_blocks = (allow_partial_wave) ? fast_min(avail_sms, sk_tiles + 1) : avail_sms;
int max_sk_blocks = fast_min(avail_sms * max_sk_occupancy, sk_iters / kMinItersPerSkBlock);
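// Evaluate each candidate SK-block count by comparing its cost (the longest per-block
// iteration count times the number of SK waves, plus an estimated fixup cost that grows
// with the number of peer blocks sharing a tile) against the iteration cost of covering
// the same tiles data-parallel; keep the candidate with the largest savings.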
for (int trial_sk_blocks = min_sk_blocks; trial_sk_blocks <= max_sk_blocks; ++trial_sk_blocks)
{
int sk_waves = (trial_sk_blocks + avail_sms - 1) / avail_sms;
int max_sk_iters_per_block = (sk_iters + trial_sk_blocks - 1) / trial_sk_blocks;
int sk_iter_equiv = max_sk_iters_per_block * sk_waves;
int num_peers = ((trial_sk_blocks + sk_tiles - 1) / sk_tiles) + 1; // add one for alignment skew
float iter_cost = 0.02f * float(num_peers) * float(sk_iter_equiv);
if (trial_sk_blocks % sk_tiles == 0)
{
// aligned
num_peers = (trial_sk_blocks / sk_tiles);
iter_cost = 0.0f;
}
float peer_cost = 2.0f * float(num_peers);
float base_cost = 2.0f * float(sk_waves);
int fixup_iter_equiv = int(base_cost + iter_cost + peer_cost);
int trial_savings_iters = dp_equiv_iters - sk_iter_equiv - fixup_iter_equiv;
if (trial_savings_iters >= savings_iters) {
savings_iters = trial_savings_iters;
sk_blocks = trial_sk_blocks;
}
}
}
/// Determine the populations of DP and SK blocks to invoke for the given number of output tiles
static void get_blocks(
int &dp_tiles, /// [out]
int &sk_blocks, /// [out]
int output_tiles,
int iters_per_tile,
int avail_sms,
int sm_occupancy)
{
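// Heuristic outline: if the output tiles quantize perfectly into full waves, stay
// data-parallel; otherwise try to cover the quantization remainder (and, if needed,
// the last full wave as well) with a cooperating SK wave. For example, with
// output_tiles = 1000 and avail_sms = 108 there are 9 full waves (972 tiles) and a
// 28-tile partial wave, and that partial wave is the first candidate for SK coverage.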
int full_waves = output_tiles / avail_sms;
int full_wave_tiles = full_waves * avail_sms;
int partial_wave_tiles = output_tiles - full_wave_tiles;
int score = -1;
dp_tiles = output_tiles;
sk_blocks = 0;
if (partial_wave_tiles == 0)
{
// Perfect quantization
return;
}
if (full_waves < sm_occupancy)
{
// We're less than full GPU occupancy
// Form the SK wave from the partial wave to get us up to full GPU occupancy
int max_sk_occupancy = sm_occupancy - full_waves;
dp_tiles = full_wave_tiles;
get_sk_blocks(
sk_blocks,
score,
partial_wave_tiles,
iters_per_tile,
avail_sms,
max_sk_occupancy,
true); // we can run with less than a full wave of SK-blocks
if (score < 0) {
// not profitable
sk_blocks = 0;
dp_tiles = output_tiles;
}
return;
}
// We're at (or greater) than GPU occupancy
if ((sm_occupancy > 1 ) && (full_waves % sm_occupancy == sm_occupancy - 1))
{
// If occupancy is more than one CTA per SM, form the SK wave from the partial
// wave to get us to full GPU occupancy
int max_sk_occupancy = 1;
dp_tiles = full_wave_tiles;
get_sk_blocks(
sk_blocks,
score,
partial_wave_tiles,
iters_per_tile,
avail_sms,
max_sk_occupancy,
true); // we can run with less than a full wave of SK-blocks
if (score >= 0) {
return;
}
}
// We're at (or above) full GPU occupancy: form the SK wave by combining
// the last full wave and the partial wave
dp_tiles = full_wave_tiles - avail_sms;
int max_sk_occupancy = sm_occupancy - ((full_waves - 1) % sm_occupancy);
get_sk_blocks(
sk_blocks,
score,
partial_wave_tiles + avail_sms,
iters_per_tile,
avail_sms,
max_sk_occupancy,
false); // we cannot run with less than a full wave of SK-blocks
if (score < 0) {
// not profitable
sk_blocks = 0;
dp_tiles = output_tiles;
}
}
/// Constructor: *Gemm* problem size (m, n, k)
template <typename GemmKernel>
ThreadblockSwizzleStreamK(
KernelTraits<GemmKernel> const kernel_traits_,
GemmUniversalMode const mode_,
GemmCoord const problem_size_,
GemmCoord const tile_size_,
int const batch_count_, /// Batch count (when mode_ == GemmUniversalMode::kBatched) or split-K-override splitting factor (when mode_ == GemmUniversalMode::kGemm)
int const sm_occupancy_,
int const avail_sms_, /// Number of SMs that dispatch heuristics will attempt to load-balance across
int const dp_tiles_ = -1, /// Dispatch override: number of output tiles to assign to independent, data-parallel CTAs
int const sk_blocks_ = -1) /// Dispatch override: number of Stream-K CTAs for cooperatively processing the remaining output tiles
:
problem_size(problem_size_),
tiled_shape(
(problem_size.m() + tile_size_.m() - 1) / tile_size_.m(),
(problem_size.n() + tile_size_.n() - 1) / tile_size_.n(),
(mode_ == GemmUniversalMode::kBatched) ? batch_count_ : 1),
iters_per_tile((problem_size.k() + tile_size_.k() - 1) / tile_size_.k()),
reduction_blocks(0),
dp_blocks(0),
dp_first_wave_tiles(1), // Default: one tile per DP-block in the first wave of DP blocks
sk_tiles(0),
sk_regions(1), // Default: a single region of iteration space (across all SK tiles)
sk_blocks_per_region(0),
sk_big_blocks_per_region(0),
sk_iters_per_region(0),
sk_iters_per_normal_block(0),
sk_waves(0),
sm_occupancy(sm_occupancy_),
avail_sms(fast_max(1, avail_sms_)),
cohort_raster(false)
{
size_t problem_bytes =
(sizeof(typename GemmKernel::ElementC) * problem_size.m() * problem_size.n()) +
(sizeof(typename GemmKernel::ElementA) * problem_size.m() * problem_size.k()) +
(sizeof(typename GemmKernel::ElementB) * problem_size.k() * problem_size.n());
size_t problem_flops = size_t(problem_size.m()) * size_t(problem_size.n()) * size_t(problem_size.k()) * 2;
float flops_per_byte = float(problem_flops) / float(problem_bytes);
int gpu_occupancy = avail_sms * sm_occupancy;
int output_tiles = tiled_shape.m() * tiled_shape.n();
int waves = (output_tiles + avail_sms - 1) / avail_sms;
float dp_efficiency = float(output_tiles) / float(waves * avail_sms);
//
// Determine dispatch composition of DP-tiles and SK-blocks
//
// Start with a DP-only configuration
int dp_tiles = output_tiles; // Number of data-parallel tiles
int sk_blocks = 0; // Number of thread blocks to produce the remaining SK tiles
if ((dp_tiles_ >= 0) && (sk_blocks_ >= 0))
{
// Dispatch composition explicitly overridden by the caller
dp_tiles = dp_tiles_;
sk_blocks = sk_blocks_;
}
// kGemm mode allows for SK load balancing
else if (mode_ == GemmUniversalMode::kGemm)
{
if (batch_count_ > 1)
{
// Split-K override
dp_tiles = 0;
sk_blocks = output_tiles * batch_count_;
}
else if ((kReductionStrategy != kNone) && // Load-balancing strategy statically enabled
(avail_sms > 1)) // Plurality of SMs to load balance across
{
// Use heuristics
get_blocks(
dp_tiles, /// [out]
sk_blocks, /// [out]
output_tiles,
iters_per_tile,
avail_sms,
sm_occupancy);
}
}
sk_tiles = output_tiles - dp_tiles;
// Compute SK block iteration details
if (sk_blocks > 0)
{
sk_waves = (sk_blocks + avail_sms - 1) / avail_sms;
int sk_iters = sk_tiles * iters_per_tile;
sk_blocks = fast_min(sk_blocks, sk_iters);
sk_iters_per_normal_block = sk_iters / sk_blocks;
int extra_sk_iters = sk_iters - (sk_iters_per_normal_block * sk_blocks);
int sk_big_blocks = extra_sk_iters;
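// These remainder iterations will be absorbed by "big" SK blocks that each run one
// extra MAC-loop iteration (the first sk_big_blocks_per_region blocks of each region;
// see get_iter_extents below).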
if ((sk_blocks > sk_tiles) && (sk_blocks % sk_tiles == 0))
{
// Split-K decomposition
sk_regions = sk_tiles;
}
sk_blocks_per_region = sk_blocks / sk_regions;
sk_big_blocks_per_region = sk_big_blocks / sk_regions;
sk_iters_per_region = sk_iters / sk_regions;
div_mod.sk_iters_per_normal_block = FastDivmod(sk_iters_per_normal_block);
div_mod.sk_iters_per_big_block = FastDivmod(sk_iters_per_normal_block + 1);
div_mod.sk_iters_per_region = FastDivmod(sk_iters_per_region);
div_mod.sk_blocks_per_region = FastDivmod(sk_blocks_per_region);
// Separate reduction heuristic
if ((kReductionStrategy == kMixed) &&
(sk_blocks > 2 * sk_tiles)) // Use a separate reduction wave whenever we would have more than three
// peers working on an SK tile. (This occurs when the ratio of SK-blocks
// to SK-tiles > 2, as a single tile may be covered by four SK-blocks,
// e.g.: [partial-block | block | block | partial-block]). With three or
// fewer peers, the two non-finishing SK-blocks are not expected to contend.
{
// Launch a reduction block every accumulator fragment in each SK-tile
static const int kAccumulatorFragments = GemmKernel::Epilogue::kAccumulatorFragments;
reduction_blocks = sk_tiles * kAccumulatorFragments;
}
}
//
// Compute DP blocks
//
dp_blocks = dp_tiles;
cutlass::gemm::GemmCoord tiled_cohort_shape(
(tiled_shape.m() + kCohortCtasM - 1) / kCohortCtasM,
(tiled_shape.n() + kCohortCtasN - 1) / kCohortCtasN,
batch_count_);
int cohort_blocks = (tiled_cohort_shape.m() * tiled_cohort_shape.n()) * kCtasPerCohort;
float cohort_efficiency = float(dp_blocks) / float(cohort_blocks);
// Check if the SK tiles would be in cohorts that are in-bounds
bool sk_in_range = true;
if (sk_tiles > 0)
{
int last_sk_tile = sk_tiles - 1;
int cohort_tile_idx = last_sk_tile / kCtasPerCohort;
int cohort_grid_m = cohort_tile_idx / tiled_cohort_shape.n();
int cohort_grid_n = (cohort_grid_m > 0) ?
tiled_cohort_shape.n() - 1 :
cohort_tile_idx % tiled_cohort_shape.n();
if ((((cohort_grid_m + 1) * kCohortCtasM) >= tiled_shape.m()) ||
(((cohort_grid_n + 1) * kCohortCtasN) >= tiled_shape.n()))
{
sk_in_range = false;
}
}
// Decide if we're going to be doing cohort raster
if (sk_in_range &&
(dp_blocks >= gpu_occupancy) &&
(cohort_efficiency > 0.85f))
{
cohort_raster = true;
dp_blocks = cohort_blocks;
}
else if (sk_waves > 0)
{
// Update semi-persistence of first DP wave to ensure full grid wavesets
// (Only applies when there's an SK component and we're not doing blocked cohort rasterization)
int dp_tile_waves = (dp_tiles + avail_sms - 1) / avail_sms;
int full_dp_tile_waves = dp_tiles / avail_sms;
int waveset_excess = (sk_waves + dp_tile_waves) % sm_occupancy;
if (dp_first_wave_tiles + waveset_excess <= full_dp_tile_waves)
{
dp_first_wave_tiles += waveset_excess;
dp_blocks -= (waveset_excess * avail_sms);
}
}
// Setup fast-div/mod for device-side usage
div_mod.tiled_shape_m = FastDivmod(tiled_shape.m());
div_mod.tiled_shape_n = FastDivmod(tiled_shape.n());
div_mod.tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
div_mod.iters_per_tile = FastDivmod(iters_per_tile);
}
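// A minimal host-side usage sketch (illustrative only; 'MyGemmKernel' is a placeholder
// for a concrete Stream-K-enabled kernel type supplying ElementA/B/C and
// Epilogue::kAccumulatorFragments, and m, n, k, sm_occupancy, device_sm_count are
// caller-provided values):
//
//   ThreadblockSwizzleStreamK swizzle(
//     KernelTraits<MyGemmKernel>{},
//     GemmUniversalMode::kGemm,
//     {m, n, k},                       // GEMM problem size
//     {128, 128, 32},                  // threadblock tile size
//     1,                               // batch count / split-K factor
//     sm_occupancy,                    // CTAs per SM for this kernel
//     device_sm_count);                // SMs to load-balance across
//   dim3 grid = swizzle.get_grid_dims();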
/// Constructor: *ImplicitGemm* Conv2d problem size: conv_operator(NPQK, NHWC, KRSC)
template <typename GemmKernel>
ThreadblockSwizzleStreamK(
KernelTraits<GemmKernel> kernel_traits_,
GemmUniversalMode mode_,
cutlass::conv::Operator conv_operator,
cutlass::conv::Conv2dProblemSize const &problem_size_,
GemmCoord tile_size_,
int batch_count_,
int sm_occupancy_,
int avail_sms_, /// When the below are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance
int dp_tiles_ = -1, /// Dispatch override: number of output tiles to assign to independent, data-parallel CTAs
int sk_blocks_ = -1) /// Dispatch override: number of Stream-K CTAs for cooperatively processing the remaining output tiles
:
ThreadblockSwizzleStreamK(
kernel_traits_,
mode_,
cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size_),
tile_size_,
batch_count_,
sm_occupancy_,
avail_sms_,
dp_tiles_,
sk_blocks_)
{}
/// Constructor: *ImplicitGemm* Conv3d problem size: conv_operator(NZPQK, NDHWC, KTRSC)
template <typename GemmKernel>
ThreadblockSwizzleStreamK(
KernelTraits<GemmKernel> kernel_traits_,
GemmUniversalMode mode_,
cutlass::conv::Operator conv_operator,
cutlass::conv::Conv3dProblemSize const &problem_size_,
GemmCoord tile_size_,
int batch_count_,
int sm_occupancy_,
int avail_sms_, /// When the below are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance
int dp_tiles_ = -1, /// Dispatch override: number of output tiles to assign to independent, data-parallel CTAs
int sk_blocks_ = -1) /// Dispatch override: number of Stream-K CTAs for cooperatively processing the remaining output tiles
:
ThreadblockSwizzleStreamK(
kernel_traits_,
mode_,
cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size_),
tile_size_,
batch_count_,
sm_occupancy_,
avail_sms_,
dp_tiles_,
sk_blocks_)
{}
/// Obtains number of threadblocks per GEMM
int get_num_blocks() const
{
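// The returned grid is either exactly the number of work-bearing blocks (when at most
// two waves) or padded up to at least four full waves, so that the first-two-wave
// block-index remapping in get_block_idx() below operates on complete waves.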
int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;
if (work_blocks <= avail_sms * 2)
{
return work_blocks;
}
return fast_max(work_blocks, avail_sms * 4);
}
/// Obtains grid extents in CTAs
dim3 get_grid_dims() const
{
return dim3(get_num_blocks(), 1, tiled_shape.k());
}
//
// Device-side interface
//
/// Obtains number of threadblocks per GEMM
CUTLASS_DEVICE
int device_num_blocks() const
{
return gridDim.x;
}
/// Obtains tile index for the given sk iteration
CUTLASS_DEVICE
int get_sk_tile_idx(int iter) const
{
return div_mod.iters_per_tile.div(iter);
}
/// Obtains the calling threadblock's tiled coordinates for the given tile index
CUTLASS_DEVICE
GemmCoord get_tile_offset(int tile_idx) const
{
int m, n;
if (cohort_raster)
{
// tiled cohort raster
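// Tiles are grouped into kCohortCtasM x kCohortCtasN cohorts; consecutive block indices
// within a cohort raster row-major across that footprint, which is intended to keep
// concurrently-scheduled CTAs on nearby output tiles (better A/B tile reuse).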
int cohort_tile_idx = tile_idx / kCtasPerCohort;
int cohort_grid_m, cohort_grid_n;
div_mod.tiled_cohort_shape_n(cohort_grid_m, cohort_grid_n, cohort_tile_idx);
int block_idx_cohort = tile_idx % kCtasPerCohort;
int block_cohort_m = block_idx_cohort / kCohortCtasN;
int block_cohort_n = block_idx_cohort % kCohortCtasN;
m = (cohort_grid_m * kCohortCtasM) + block_cohort_m;
n = (cohort_grid_n * kCohortCtasN) + block_cohort_n;
}
else if (tiled_shape.m() < tiled_shape.n())
{
// column-major raster
div_mod.tiled_shape_m(n, m, tile_idx);
}
else
{
// row-major raster
div_mod.tiled_shape_n(m, n, tile_idx);
}
int block_idx_k = RematerializeBlockIdxZ();
return GemmCoord{m, n, block_idx_k};
}
/// Obtains the calling threadblock's linear block index (possibly remapped for the first two waves)
CUTLASS_DEVICE
int get_block_idx() const
{
// Remap the block indices for the first two waves of thread blocks if
// we have multi-occupancy and the grid constitutes four or more waves
int block_idx = RematerializeBlockIdxX();
int num_blocks = device_num_blocks();
int dest_sm = block_idx / 2;
int dest_wave = block_idx % 2;
int remapped_block_idx = dest_sm + (dest_wave * avail_sms);
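// e.g., with avail_sms == 4, launched blocks 0..7 are remapped to block indices
// 0,4,1,5,2,6,3,7: each adjacent pair of launched blocks covers the same slot in
// the first and second waves of work.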
if ((sm_occupancy > 1) &&
(num_blocks >= avail_sms * 4) &&
(block_idx < avail_sms * 2))
{
block_idx = remapped_block_idx;
}
// Block-index is blockIdx.x for DP blocks
return block_idx;
}
/// Obtains the linear SK block index of the block covering the given global SK iteration
CUTLASS_DEVICE
int get_sk_block_idx(int iter) const
{
int region_idx;
int iter_in_region;
div_mod.sk_iters_per_region(region_idx, iter_in_region, iter);
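// Within a region, iterations [0, big_block_iters) belong to the "big" blocks (one extra
// iteration each) and the remainder to normal blocks; compute a candidate block index
// under each assumption and select based on which range the iteration falls in.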
int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block) + sk_big_blocks_per_region; // number of iterations in the region's big blocks
int normal_block_iters = iter_in_region - big_block_iters; // number of iterations in the region's normal blocks
int big_block_idx_in_region = div_mod.sk_iters_per_big_block.div(iter_in_region);
int normal_block_idx_in_region = sk_big_blocks_per_region + div_mod.sk_iters_per_normal_block.div(normal_block_iters);
int block_idx_in_region = (big_block_idx_in_region < sk_big_blocks_per_region) ?
big_block_idx_in_region :
normal_block_idx_in_region;
return (sk_blocks_per_region * region_idx) + block_idx_in_region;
}
/// Obtains iteration extents for the given SK block index
CUTLASS_DEVICE
void get_iter_extents(
int sk_block_idx,
int &block_iter_begin,
int &block_iter_end) const
{
int region_idx;
int block_idx_in_region;
div_mod.sk_blocks_per_region(region_idx, block_idx_in_region, sk_block_idx);
block_iter_begin = (region_idx * sk_iters_per_region) + (block_idx_in_region * sk_iters_per_normal_block);
// Adjust extents for the first "num_big_blocks" blocks that get one extra iteration
int block_iters = sk_iters_per_normal_block;
if (block_idx_in_region < sk_big_blocks_per_region) {
// This is a +1 iteration block
block_iter_begin += block_idx_in_region;
block_iters++;
} else {
// This is a regular block
block_iter_begin += sk_big_blocks_per_region;
}
block_iter_end = block_iter_begin + block_iters;
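// Example: with sk_iters_per_normal_block == 5, sk_big_blocks_per_region == 2, and
// sk_blocks_per_region == 4 (so sk_iters_per_region == 22), blocks 0..3 of region 0
// cover iterations [0,6), [6,12), [12,17), and [17,22) respectively.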
}
/// Obtains the linear threadblock index of the first block to work on the given tile
CUTLASS_DEVICE
int get_first_block_idx(int tile_idx, int block_idx) const
{
if (tile_idx >= sk_tiles) {
// DP tile
return block_idx;
}
int iter = tile_idx * iters_per_tile;
return get_sk_block_idx(iter);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass