/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/*! \file
    \brief Parameter structures for persistent tile schedulers
*/
/*
Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain
existing integrations of CUTLASS require C++11 host compilers.
Until this requirement can be lifted, certain headers with this annotation are required
  to remain consistent with C++11 syntax.
C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`.
*/
#include "cutlass/coord.h"
#include "cutlass/kernel_hardware_info.h"
#include "cutlass/workspace.h"
#include "cutlass/platform/platform.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm_coord.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
namespace detail {
////////////////////////////////////////////////////////////////////////////////
//
// Parameters for SM90 tile schedulers
//
// Parameters for SM90 persistent tile scheduler
struct PersistentTileSchedulerSm90Params {
enum class RasterOrder {
AlongM,
AlongN
};
enum class RasterOrderOptions {
Heuristic,
AlongM,
AlongN
};
FastDivmodU64Pow2 divmod_cluster_shape_major_{};
FastDivmodU64Pow2 divmod_cluster_shape_minor_{};
FastDivmodU64 divmod_batch_{};
FastDivmodU64 divmod_cluster_blk_major_{};
uint64_t blocks_per_problem_ = 0;
int32_t log_swizzle_size_ = 0;
RasterOrder raster_order_ = RasterOrder::AlongN;
uint32_t problem_tiles_m_ = 0;
uint32_t problem_tiles_n_ = 0;
uint32_t problem_tiles_l_ = 0;
uint32_t cluster_shape_m_ = 0;
uint32_t cluster_shape_n_ = 0;
// Initializes members. This variant of the method should only be used when
// problem_shape and tile_shape contain modes of only rank 1.
void
initialize(
BatchedGemmCoord problem_shape,
GemmCoord tile_shape,
GemmCoord cluster_shape,
KernelHardwareInfo const& hw_info,
int max_swizzle_size,
RasterOrderOptions raster_order_option
) {
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape);
return initialize(
problem_blocks,
cluster_shape,
hw_info,
max_swizzle_size,
raster_order_option
);
}
// Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions.
// This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1,
// for which using CuTe algebra for calculating tile shapes is easiest.
void
initialize(
dim3 problem_blocks,
GemmCoord cluster_shape,
KernelHardwareInfo const& hw_info,
int max_swizzle_size,
RasterOrderOptions raster_order_option
) {
CUTLASS_UNUSED(hw_info);
// Round up to nearest multiple of swizzle_size along each mode
auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size);
auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m());
auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n());
problem_tiles_m_ = problem_blocks_m / cluster_shape.m();
problem_tiles_n_ = problem_blocks_n / cluster_shape.n();
problem_tiles_l_ = problem_blocks.z;
cluster_shape_m_ = cluster_shape.m();
cluster_shape_n_ = cluster_shape.n();
RasterOrder raster_order = get_rasterization_order(
problem_blocks_m,
problem_blocks_n,
raster_order_option
);
//
// Set members
//
blocks_per_problem_ = problem_blocks_m * problem_blocks_n * problem_blocks.z;
log_swizzle_size_ = log_swizzle_size;
raster_order_ = raster_order;
divmod_batch_ = FastDivmodU64(problem_blocks_m * problem_blocks_n);
if (raster_order == RasterOrder::AlongN) {
divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.n());
divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.m());
divmod_cluster_blk_major_ = FastDivmodU64(problem_blocks_n / cluster_shape.n());
}
else {
divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.m());
divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.n());
divmod_cluster_blk_major_ = FastDivmodU64(problem_blocks_m / cluster_shape.m());
}
}
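  // Illustrative example (assumed inputs): with problem_blocks = {8, 8, 1},
  // cluster_shape = 2x1x1, max_swizzle_size = 1, and the Heuristic raster option,
  // initialize() computes log_swizzle_size_ = 0, keeps the rounded block counts at (8, 8),
  // and selects RasterOrder::AlongN (tiles_n is not greater than tiles_m). The resulting
  // members are problem_tiles_m_ = 4, problem_tiles_n_ = 8, problem_tiles_l_ = 1, and
  // blocks_per_problem_ = 64, with the major and minor cluster divmods dividing by 1 and 2,
  // respectively, and divmod_cluster_blk_major_ dividing by 8.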
// Given the inputs, computes the physical grid we should launch.
// This variant of the method should only be used when
// problem_shape and tile_shape contain modes of only rank 1.
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(
BatchedGemmCoord problem_shape,
GemmCoord cta_shape,
GemmCoord cluster_shape,
KernelHardwareInfo hw_info,
int max_swizzle_size,
RasterOrderOptions raster_order_option,
bool truncate_by_problem_size=true) {
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape);
return get_grid_shape(
problem_blocks,
cluster_shape,
hw_info,
max_swizzle_size,
raster_order_option,
truncate_by_problem_size
);
}
// Version of get_grid_shape that takes in as input the number of CTAs in the M and N and L dimensions.
// This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1,
// for which using CuTe algebra for calculating tile shapes is easiest.
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(
dim3 problem_blocks,
GemmCoord cluster_shape,
KernelHardwareInfo hw_info,
int max_swizzle_size,
RasterOrderOptions raster_order_option,
bool truncate_by_problem_size=true) {
int const sm_count = hw_info.sm_count;
// Round up to nearest multiple of swizzle_size along each mode
auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size);
auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m());
auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n());
int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks.z;
RasterOrder raster_order = get_rasterization_order(
problem_blocks_m,
problem_blocks_n,
raster_order_option
);
dim3 launch_grid;
if (raster_order == RasterOrder::AlongN) {
launch_grid = dim3(cluster_shape.m(), 1, 1);
}
else {
launch_grid = dim3(1, cluster_shape.n(), 1);
}
auto possibly_truncate = [&](int x, int y) {
if (truncate_by_problem_size) {
return platform::min(x, y);
}
else {
return x;
}
};
    // The else path is generic; however, we can avoid some divs if we know the cluster size is 1
auto cluster_size = cluster_shape.m() * cluster_shape.n();
if (cluster_size == 1) {
if (raster_order == RasterOrder::AlongN) {
launch_grid.y = possibly_truncate(sm_count, problem_blocks_total);
}
else {
launch_grid.x = possibly_truncate(sm_count, problem_blocks_total);
}
}
else {
/*
* Optimal grid size calculation is based on
* GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU
* Hence, maximum SMs per GPC = 18
*/
constexpr int max_sm_per_gpc = 18;
// Provided SM count could possibly be less than the assumed maximum SMs per GPC
auto cluster_size = cluster_shape.m() * cluster_shape.n();
int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc;
int const max_cta_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % cluster_size);
int cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc;
// The calculation below allows for larger grid size launch for different GPUs.
int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc;
int const max_cta_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % cluster_size);
cta_per_device += max_cta_occupancy_per_residual_gpc;
cta_per_device = sm_count < cta_per_device ? sm_count : cta_per_device;
if (raster_order == RasterOrder::AlongN) {
launch_grid.y = possibly_truncate(
cta_per_device / cluster_shape.m(),
problem_blocks_total / cluster_shape.m());
}
else {
launch_grid.x = possibly_truncate(
cta_per_device / cluster_shape.n(),
problem_blocks_total / cluster_shape.n());
}
}
return launch_grid;
}
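  // Illustrative walkthrough of get_grid_shape (assumed values: sm_count = 132, a 2x1 cluster,
  // AlongN rasterization, and a problem with many more blocks than SMs): min_num_gpc = 132 / 18 = 7
  // and max_cta_occupancy_per_gpc = 18 - (18 % 2) = 18, so cta_per_device starts at 126; the
  // residual 132 % 18 = 6 SMs contribute 6 more, giving cta_per_device = 132. The launch grid
  // is then (cluster_shape.m(), cta_per_device / cluster_shape.m(), 1) = (2, 66, 1), i.e.,
  // 132 CTAs in a single persistent wave.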
CUTLASS_HOST_DEVICE
static int32_t
get_log_swizzle_size(int problem_ctas_m, int problem_ctas_n, int max_swizzle_size) {
int min_cta_dim = platform::min(problem_ctas_m, problem_ctas_n);
if (max_swizzle_size >= 8 && min_cta_dim >= 6) {
return 3;
}
else if (max_swizzle_size >= 4 && min_cta_dim >= 3) {
return 2;
}
else if (max_swizzle_size >= 2 && min_cta_dim >= 2) {
return 1;
}
else {
return 0;
}
}
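  // Illustrative examples (assumed values): get_log_swizzle_size(10, 7, 8) returns 3
  // (swizzle size 8, since min(10, 7) >= 6); get_log_swizzle_size(10, 7, 4) returns 2;
  // and get_log_swizzle_size(10, 2, 8) returns 1, because the smaller CTA extent limits
  // the usable swizzle even when a larger maximum is requested.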
CUTLASS_HOST_DEVICE
static RasterOrder
get_rasterization_order(
uint32_t tiles_m,
uint32_t tiles_n,
RasterOrderOptions raster_order_option
) {
if (raster_order_option == RasterOrderOptions::Heuristic) {
if (tiles_n > tiles_m) {
return RasterOrder::AlongM;
}
else {
return RasterOrder::AlongN;
}
}
else {
switch (raster_order_option) {
case RasterOrderOptions::AlongN:
return RasterOrder::AlongN;
break;
default:
return RasterOrder::AlongM;
}
}
}
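  // Illustrative examples (assumed values): under RasterOrderOptions::Heuristic,
  // (tiles_m, tiles_n) = (4, 16) yields RasterOrder::AlongM, (16, 4) yields AlongN,
  // and ties fall through to AlongN. The explicit AlongM/AlongN options bypass the heuristic.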
// Get the number of CTA tiles in this problem. This variant of the method should only be used when
// problem_shape and tile_shape contain modes of only rank 1.
CUTLASS_HOST_DEVICE
static dim3
get_tiled_cta_shape_mnl(BatchedGemmCoord problem_shape, GemmCoord cta_shape, GemmCoord cluster_shape) {
auto cta_m = (problem_shape.m() + cta_shape.m() - 1) / cta_shape.m();
auto cta_n = (problem_shape.n() + cta_shape.n() - 1) / cta_shape.n();
return get_tiled_cta_shape_mnl(problem_shape, cluster_shape, cta_m, cta_n);
}
// Version of get_tiled_cta_shape_mnl that takes in as input the number of CTAs in the M and N dimensions.
// This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1,
// for which using CuTe algebra for calculating tile shapes is easiest.
CUTLASS_HOST_DEVICE
static dim3
get_tiled_cta_shape_mnl(BatchedGemmCoord problem_shape, GemmCoord cluster_shape, uint32_t cta_m, uint32_t cta_n) {
// Round up to nearest multiple of cluster dim along each mode
auto problem_blocks_m = ((cta_m + cluster_shape.m() - 1) / cluster_shape.m()) * cluster_shape.m();
auto problem_blocks_n = ((cta_n + cluster_shape.n() - 1) / cluster_shape.n()) * cluster_shape.n();
return {
static_cast<uint32_t>(problem_blocks_m),
static_cast<uint32_t>(problem_blocks_n),
static_cast<uint32_t>(problem_shape.batch())
};
}
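  // Illustrative example (assumed values): for a 300x300 problem with batch = 2, a 128x128 CTA
  // tile, and a 2x2 cluster, the rank-1 overload of get_tiled_cta_shape_mnl computes
  // cta_m = cta_n = 3, which are rounded up to cluster multiples here, returning {4, 4, 2}.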
};
////////////////////////////////////////////////////////////////////////////////
// Parameters for SM90 persistent stream-K scheduler
struct PersistentTileSchedulerSm90StreamKParams {
// Strategies for computing reductions between CTAs computing portions of a given output tile
enum class ReductionMode {
// Participating CTAs perform reduction in a turnstile fashion in order of the K extent
    // covered by each CTA. This requires a lock to be held exclusively by the CTA that is
// currently accumulating.
//
// Turnstile accumulation ensures deterministic numeric behavior when using this mode.
Deterministic,
// Participating CTAs perform reduction atomically to the same workspace (mostly) without locking.
// Locks are used only to wait for the first CTA to write its partial values (to initialize the
// workspace), and for all but the final CTA to have accumulated (so that the final CTA can load
// the accumulated value and accumulate it into registers on top of which the epilogue will
// be performed).
//
    // Due to the nondeterministic ordering of accumulation, deterministic numeric behavior cannot
// be guaranteed with this mode (e.g., floating-point rounding error will depend on the order
// of accumulation)
Nondeterministic
};
// Strategies for decomposing the problem
enum class DecompositionMode {
// Use a heuristic to determine whether data-parallel, split-K, or stream-K decomposition should be performed
Heuristic,
// Force a data-parallel decomposition
DataParallel,
// Force a split-K decomposition. This should be paired with setting the `splits` parameter
SplitK,
// Force a stream-K decomposition
StreamK
};
using UnderlyingParams = PersistentTileSchedulerSm90Params;
using RasterOrder = UnderlyingParams::RasterOrder;
using RasterOrderOptions = UnderlyingParams::RasterOrderOptions;
// Cluster dimensions are typically always a power of 2, so use
// the power-of-two variants of FastDivmod for these.
FastDivmodU64Pow2 divmod_cluster_shape_major_{};
FastDivmodU64Pow2 divmod_cluster_shape_minor_{};
FastDivmodU64 divmod_batch_{};
FastDivmodU64 divmod_cluster_blk_major_{};
// Total number of cluster-sized output tiles (i.e., not including any
// splitting factors). This is primarily used for split-K decompositions,
// and may be overridden in other decompositions.
FastDivmodU64 divmod_clusters_mnl_{};
// We divide up the number of stream-K tiles amongst G groups of stream-K units.
  // The stream-K units within a group collaborate to compute over the `sk_tiles / G`
  // tiles assigned to that group. Non-unit group sizes can help to preserve L2 locality of
  // partial chunks computed by stream-K units -- unit 0 in each group will compute identical K extents
// of tiles that would be assigned in the same wave according to the rasterization order of the
// data-parallel formulation of the problem.
FastDivmodU64 divmod_sk_groups_{};
// Number of stream-K units in each group
FastDivmodU64 divmod_sk_units_per_group_{};
uint64_t units_per_problem_ = 0;
FastDivmod divmod_tiles_per_output_tile_{};
int32_t log_swizzle_size_ = 0;
RasterOrder raster_order_ = RasterOrder::AlongN;
// The splitting factor to be used in a split-K decomposition of the problem.
// If this is set to a value greater than 1, stream-K decomposition logic
// is bypassed in favor of a split-K decomposition.
uint32_t splits_ = 1;
// Number of stream-K or split-K work units that compute an extra k iteration.
// This is done to handle residuals in dividing up the k iteration space.
// For stream-K, since the actual assignment of work to stream-K units will be done
// at the granularity of a cluster, we store only the number of big clusters.
uint32_t big_units_ = 0;
// The number of groups of stream-K units that will process an extra stream-K tile cluster.
uint32_t big_groups_ = 0;
// Workspace for holding partial accumulators to be reduced across stream-K/split-K units
void* reduction_workspace_ = nullptr;
// Number of tiles covered by stream-K work units
uint32_t sk_tiles_ = 0;
// Number of work units computing stream-K tiles
uint32_t sk_units_ = 0;
// Number of tiled k iterations computed by each stream-K work unit. This
// can potentially cover more than one output tile.
uint32_t k_tiles_per_sk_unit_ = 0;
// Strategy to use when reducing between collaborating CTAs
ReductionMode reduction_mode_ = ReductionMode::Deterministic;
// The number of sub blocks in the kernel epilogue
FastDivmodU64 divmod_epilogue_subtile_{};
  // The number of blocks launched for performing separate reduction
uint32_t separate_reduction_units_ = 0;
// Minimum number of k tiles that can be assigned to a stream-K unit
static constexpr uint32_t min_iters_per_sk_unit_ = 8u;
// Maximum number of groups of stream-K units
static constexpr uint32_t max_sk_groups_ = 8u;
// Divides dividend by the cluster size
CUTLASS_HOST_DEVICE
uint64_t
div_cluster_size(uint64_t dividend) const {
// Use each underlying fast divmod rather than performing integer division
// by the multiplication of major.divisor * minor.divisor
return divmod_cluster_shape_minor_.divide(
divmod_cluster_shape_major_.divide(dividend)
);
}
CUTLASS_HOST_DEVICE
uint64_t
get_cluster_size() const {
return divmod_cluster_shape_minor_.divisor * divmod_cluster_shape_major_.divisor;
}
// Returns whether the kernel uses separate reduction
CUTLASS_HOST_DEVICE
bool
requires_separate_reduction() const {
return separate_reduction_units_ > 0;
}
// Returns the maximum number of peers that can collaborate on a given output tile
CUTLASS_HOST_DEVICE
static uint32_t
max_peers_per_tile(uint64_t sk_units, uint64_t sk_tiles) {
// When we can divide up our SK units to SK tiles evenly, the number of peers
// per SK tile is exactly (sk_units_ / sk_tiles_). In cases where this division
// is not exact, some tiles will need to be covered by additional SK units. Because
// the extra work can occur at both the beginning and the end of the SK tile, at
// most 2 extra peers will be needed.
return static_cast<uint32_t>(sk_units / sk_tiles + 2);
}
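  // Illustrative example (assumed values): with sk_units = 10 and sk_tiles = 4,
  // max_peers_per_tile returns 10 / 4 + 2 = 4.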
// Initializes members. This variant of the method should only be used when
// problem_shape and tile_shape contain modes of only rank 1.
void
initialize(
BatchedGemmCoord problem_shape,
GemmCoord tile_shape,
GemmCoord cluster_shape,
KernelHardwareInfo hw_info,
int splits,
int max_swizzle,
RasterOrderOptions raster_order_option,
ReductionMode reduction_mode,
DecompositionMode decomposition_mode,
void* workspace,
const uint32_t epilogue_subtile = 1
) {
dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(
problem_shape, tile_shape, cluster_shape);
// Number of k tiles in each output tile
uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k();
initialize(
problem_blocks,
k_tiles_per_output_tile,
cluster_shape,
hw_info,
splits,
max_swizzle,
raster_order_option,
reduction_mode,
decomposition_mode,
workspace,
epilogue_subtile
);
}
// Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions.
// This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1,
// for which using CuTe algebra for calculating tile shapes is easiest.
void
initialize(
dim3 problem_blocks,
uint32_t k_tiles_per_output_tile,
GemmCoord cluster_shape,
KernelHardwareInfo hw_info,
int splits,
int max_swizzle,
RasterOrderOptions raster_order_option,
ReductionMode reduction_mode,
DecompositionMode decomposition_mode,
void* workspace,
const uint32_t epilogue_subtile = 1
) {
UnderlyingParams underlying_params;
underlying_params.initialize(
problem_blocks,
cluster_shape,
hw_info,
max_swizzle,
raster_order_option
);
auto problem_blocks_l = problem_blocks.z;
auto problem_blocks_m = round_up(problem_blocks.x, (1 << underlying_params.log_swizzle_size_) * cluster_shape.m());
auto problem_blocks_n = round_up(problem_blocks.y, (1 << underlying_params.log_swizzle_size_) * cluster_shape.n());
uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l;
// Reduction workspace is at the beginning of the workspace. Lock workspace follows.
void* reduction_workspace = workspace;
if (decomposition_mode == DecompositionMode::SplitK ||
(decomposition_mode == DecompositionMode::Heuristic && splits > 1)) {
// Short circuit to basic split-K decomposition
// Don't split by more than the available number of SMs
if (splits > hw_info.sm_count) {
splits = hw_info.sm_count;
}
// Don't split by more than the K tile iterations
//
// splits is almost certainly nonnegative here (e.g., hw_info.sm_count,
// despite being an int, is a count), so it can safely be converted to unsigned
// in the comparison to avoid a signed-unsigned comparison warning-as-error.
if (static_cast<decltype(k_tiles_per_output_tile)>(splits) > k_tiles_per_output_tile) {
splits = k_tiles_per_output_tile;
}
set_params_basic(
underlying_params,
problem_blocks_m,
problem_blocks_n,
problem_blocks_l,
splits,
k_tiles_per_output_tile,
reduction_workspace,
reduction_mode
);
return;
}
// Calculate the maximum number of blocks from clusters of shape cluster_shape that we
// can fit within sm_count SMs.
dim3 grid = get_grid_shape(
problem_blocks,
cluster_shape,
hw_info,
max_swizzle,
raster_order_option
);
uint64_t ctas_per_wave = grid.x * grid.y;
// The number of output tiles to be computed in stream-K and data-parallel fashion, respectively.
uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, k_tiles_per_output_tile, decomposition_mode);
uint64_t dp_tiles = output_tiles - sk_tiles;
// Calculate the number of work units covering the data-parallel and stream-K tiles.
// A "work unit" is a single index in the linearized ID space used by the scheduler.
// We distinguish it from a "block," which is typically tied to a hardware unit
// (e.g., the callers into this scheduler will be persistent thread blocks).
// A work unit can encompass multiple output tiles worth of work (as will be the
// case for stream-K blocks).
// Since splitting is not required for data-parallel tiles, only one data-parallel unit
// is needed per data-parallel tile.
uint64_t dp_units = dp_tiles;
// Number of k iterations computed by the stream-K units as a whole
uint64_t k_tiles_sk_total = k_tiles_per_output_tile * sk_tiles;
// If there are stream-K tiles to compute and a sufficiently large number of k iterations
// across them, they will be covered by a single wave of persistent threadblocks. Thus, there
// will be as many work units as there are threadblocks in a single wave.
//
// When the total k iterations across stream-K tiles is too small to justify distributing
// across an entire wave of blocks, we instead distribute the iterations over a smaller
// set of blocks.
// Calculate the number of stream-K units that would be needed if each stream-K unit
// computed the minimum allowable k iterations. Truncate this to be in units of clusters.
auto cluster_size = cluster_shape.m() * cluster_shape.n();
uint64_t min_sized_sk_units = (k_tiles_sk_total / min_iters_per_sk_unit_);
min_sized_sk_units = (min_sized_sk_units / cluster_size) * cluster_size;
uint64_t sk_units = platform::min(ctas_per_wave, min_sized_sk_units);
if (decomposition_mode == DecompositionMode::DataParallel ||
(decomposition_mode == DecompositionMode::Heuristic && sk_tiles == 0) ||
sk_units == 0) {
// Short circuit to basic data-parallel decomposition
set_params_basic(
underlying_params,
problem_blocks_m,
problem_blocks_n,
problem_blocks_l,
/* splits = */ 1,
k_tiles_per_output_tile,
reduction_workspace,
reduction_mode
);
return;
}
bool do_separate_reduction = should_perform_separate_reduction(
epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave);
// Determine the number of stream-K groups that will be used. We currently use
// max_sk_groups_ unless this extends beyond the extent of the dimension over
// which the problem is rasterized. For example, if the tiled problem shape
// (in CTA_M x CTA_N representation) when using 1x1 clusters is 4x16,
// and we rasterize along the M dimension, we choose 4 groups, rather than 8.
// If the cluster shape is 2x1, we choose 2 groups (CTA_M / CLUSTER_M).
uint32_t max_groups_problem;
if (underlying_params.raster_order_ == RasterOrder::AlongM) {
max_groups_problem = problem_blocks_m / cluster_shape.m();
}
else {
max_groups_problem = problem_blocks_n / cluster_shape.n();
}
    // Select the number of groups that will be used. We start with the maximum
// number of potential groups, and iterate down looking for a group size that
// evenly divides the stream-K units and tiles, and for which the resulting
// number of K tiles per stream-K unit remains above min_iters_per_sk_unit_
uint32_t groups = platform::min(max_groups_problem, uint32_t(max_sk_groups_));
// Grouping is disabled when separate reduction is used
if (do_separate_reduction) {
groups = 1;
}
uint32_t fallback_groups = 0;
auto sk_cluster_tiles = sk_tiles / cluster_size;
auto sk_cluster_units = sk_units / cluster_size;
auto sk_splits_too_small = [&](uint32_t g) {
// Check whether the number of K tiles computed per stream-K unit is less
// than min_iters_per_sk_unit_
auto total_sk_k_tiles = (sk_tiles / g) * k_tiles_per_output_tile;
auto k_tiles_per_sk_unit = total_sk_k_tiles / (sk_units / g);
return k_tiles_per_sk_unit < min_iters_per_sk_unit_;
};
auto is_ideal_grouping = [&](uint32_t g) {
// An ideal grouping will evenly divide stream-K clusters, evenly divide
// stream-K tiles, and not result in stream-K splits that are too small.
return (sk_cluster_units % g == 0) && (sk_cluster_tiles % g == 0) && !sk_splits_too_small(g);
};
auto is_valid_grouping = [&](uint32_t g) {
// A grouping is valid, but not ideal, if it evenly divides the
// stream-K clusters and does not result in stream-K splits that are
// too small. Such a setting can be used as a fallback option in the
// case that an ideal grouping is not achievable
return sk_cluster_units % g == 0 && !sk_splits_too_small(g);
};
while (groups > 1 && !is_ideal_grouping(groups)) {
if (fallback_groups == 0 && is_valid_grouping(groups)) {
// Set fallback groups once in preference for a larger number of groups.
fallback_groups = groups;
}
--groups;
}
// If groups == 1, we did not find a group count that satisfies all criteria. If we have
// found a fallback group count, use this instead.
if (groups == 1 && fallback_groups > 0) {
groups = fallback_groups;
}
auto sk_units_per_group = sk_units / groups;
// sk_tiles is guaranteed to be divisible by cluster_size because it is calculated as:
// sk_tiles = (waves <= 2) ? total_tiles : (sm_count + (total_tiles % sm_count))
// Both total_tiles and sm_count are multiples of cluster size due to padding added
// prior to kernel launch.
uint64_t sk_clustered_tiles = sk_tiles / cluster_size;
uint64_t sk_clustered_tiles_per_group = sk_clustered_tiles / groups;
uint64_t sk_tiles_per_group = sk_clustered_tiles_per_group * cluster_size;
// Groups that will process an extra stream-K tile cluster. These differ from "big_units," which
// are stream-K units within a group that process an extra K chunk.
uint64_t sk_big_groups = sk_clustered_tiles % groups;
uint64_t k_tiles_per_group = k_tiles_per_output_tile * sk_tiles_per_group;
// Number of k tiles computed per stream-K unit
uint64_t k_tiles_per_sk_unit = k_tiles_per_group / sk_units_per_group;
uint32_t reduction_units = 0;
// Use separate reduction when we have less than one wave of output tiles (dp_tiles == 0)
// and when each tile will be operated on by at least two stream-K units (sk_units > 2 * sk_tiles)
if (do_separate_reduction) {
// Each reduction unit will reduce the partials of an epilogue subtile for
// a given output tile and compute the epilogue. Thus, there are as many reduction
// units as there are epilogue subtiles.
reduction_units = sk_tiles * epilogue_subtile;
}
else if (decomposition_mode == DecompositionMode::Heuristic && sk_tiles < sk_units && sk_units % sk_tiles == 0) {
// If the number of stream-K units is a multiple of the number of stream-K tiles, then
// the problem can leverage a basic split-K decomposition for the stream-K tiles.
      // This case happens when separate reduction is disabled.
uint32_t sk_splits = static_cast<uint32_t>(sk_units / sk_tiles);
set_params_basic(
underlying_params,
problem_blocks_m,
problem_blocks_n,
problem_blocks_l,
sk_splits,
k_tiles_per_output_tile,
reduction_workspace,
reduction_mode
);
return;
}
divmod_cluster_shape_major_ = underlying_params.divmod_cluster_shape_major_;
divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_;
divmod_batch_ = underlying_params.divmod_batch_;
divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile);
divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_;
divmod_sk_groups_ = FastDivmodU64(static_cast<uint64_t>(groups));
divmod_sk_units_per_group_ = FastDivmodU64(static_cast<uint64_t>(sk_units / groups));
// Override divmod_clusters_mnl_ to be the number of cluster-sized stream-K units.
// This setting ensures that the use of this divmod for stream-K decompositions
// is essentially a no-op.
divmod_clusters_mnl_ = FastDivmodU64(sk_units / cluster_size);
splits_ = 1;
log_swizzle_size_ = underlying_params.log_swizzle_size_;
units_per_problem_ = static_cast<uint32_t>(dp_units + sk_units);
raster_order_ = underlying_params.raster_order_;
// Assign big_units_ assuming that group count == 1. This is unused by stream-K
// when group count > 1.
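    // Illustrative example (assumed values): with k_tiles_per_group = 100 and
    // sk_units_per_group = 12, k_tiles_per_sk_unit = 100 / 12 = 8 and
    // big_units_ = 100 % 8 = 4, i.e., four cluster-sized stream-K units each
    // process one extra K tile (12 * 8 + 4 = 100).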
big_units_ = static_cast<uint32_t>(k_tiles_per_group % k_tiles_per_sk_unit);
big_groups_ = static_cast<uint32_t>(sk_big_groups);
reduction_workspace_ = reduction_workspace;
sk_tiles_ = sk_tiles;
sk_units_ = static_cast<uint32_t>(sk_units);
k_tiles_per_sk_unit_ = static_cast<uint32_t>(k_tiles_per_sk_unit);
reduction_mode_ = reduction_mode;
divmod_epilogue_subtile_ = FastDivmodU64(epilogue_subtile);
separate_reduction_units_ = reduction_units;
}
// Given the inputs, computes the physical grid we should launch.
// This variant of the method should only be used when
// problem_shape and tile_shape contain modes of only rank 1.
CUTLASS_HOST_DEVICE
static dim3
get_grid_shape(
BatchedGemmCoord problem_shape,
GemmCoord cta_shape,
GemmCoord cluster_shape,
KernelHardwareInfo hw_info,
int max_swizzle_size,
RasterOrderOptions raster_order_option
) {
dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, cta_shape, cluster_shape);
return get_grid_shape(
problem_blocks,
cluster_shape,
hw_info,
max_swizzle_size,
raster_order_option
);
}
// Version of get_grid_shape that takes in as input the number of CTAs in the M and N and L dimensions.
// This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1,
// for which using CuTe algebra for calculating tile shapes is easiest.
CUTLASS_HOST_DEVICE
static dim3
get_grid_shape(
dim3 problem_blocks,
GemmCoord cluster_shape,
KernelHardwareInfo hw_info,
int max_swizzle_size,
RasterOrderOptions raster_order_option
) {
// Call into the underlying get_grid_shape method, but do not allow the grid shape returned
// to be truncated based on the number of output tiles in the problem.
return UnderlyingParams::get_grid_shape(
problem_blocks,
cluster_shape,
hw_info,
max_swizzle_size,
raster_order_option,
/* truncate_by_problem_size = */false
);
}
// Returns the number of stream-K tiles that will be computed amongst `output_tiles` total
// output tiles on a device with `ctas_per_wave` CTAs in each wave.
static uint32_t
get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave, uint32_t k_tiles_per_output_tile, DecompositionMode decomposition_mode) {
uint32_t full_waves = static_cast<uint32_t>(output_tiles / ctas_per_wave);
uint32_t total_waves = static_cast<uint32_t>((output_tiles + ctas_per_wave - 1) / ctas_per_wave);
if (decomposition_mode == DecompositionMode::DataParallel ||
decomposition_mode == DecompositionMode::SplitK) {
return 0;
}
if (decomposition_mode == DecompositionMode::Heuristic) {
if (full_waves == total_waves || k_tiles_per_output_tile <= min_iters_per_sk_unit_) {
// All tiles will be data-parallel tiles if there is either no quantization
// or if there is no work to be split.
return 0;
}
//
// The final wave is not full. Perform some stream-K work.
//
// Rudimentary heuristic: prefer data-parallel decomposition if we have more than
// one wave and the tail wave is more than half full. This is subject to change.
uint64_t tail_tiles = output_tiles - (full_waves * ctas_per_wave);
if (tail_tiles >= (ctas_per_wave / 2)) {
return 0;
}
}
// If there is wave quantization, assign the first two waves worth of tiles to be
// covered by stream-K work and the remainder to be data-parallel. Since we know
// that full_waves == total_waves - 1 in this case, the number of data-parallel
// waves is simply full_waves-1 (unless full_waves == 0).
uint32_t dp_waves = full_waves > 0 ? full_waves - 1 : 0;
uint64_t dp_tiles = dp_waves * ctas_per_wave;
return static_cast<uint32_t>(output_tiles - dp_tiles);
}
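  // Illustrative examples of get_num_sk_tiles (assumed values: ctas_per_wave = 132,
  // k_tiles_per_output_tile = 16, DecompositionMode::Heuristic): 264 output tiles form two
  // full waves, so 0 stream-K tiles; 200 output tiles leave a tail of 68 >= 132 / 2, so the
  // heuristic also returns 0; 150 output tiles leave a tail of 18, so dp_waves = 0 and all
  // 150 tiles become stream-K tiles.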
CUTLASS_HOST_DEVICE
static uint64_t
get_num_sk_units(GemmCoord cluster_shape, uint64_t ctas_per_wave, uint32_t sk_tiles, uint32_t k_tiles_per_output_tile) {
// Number of k iterations computed by the stream-K units as a whole
uint64_t k_tiles_sk_total = k_tiles_per_output_tile * sk_tiles;
// Calculate the number of stream-K units that would be needed if each stream-K unit
// computed the minimum allowable k iterations. Truncate this to be in units of clusters.
auto cluster_size = cluster_shape.m() * cluster_shape.n();
uint64_t min_sized_sk_units = (k_tiles_sk_total / min_iters_per_sk_unit_);
min_sized_sk_units = (min_sized_sk_units / cluster_size) * cluster_size;
uint64_t sk_units = platform::min(ctas_per_wave, min_sized_sk_units);
return sk_units;
}
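  // Illustrative examples of get_num_sk_units (assumed values: a 2x1 cluster,
  // ctas_per_wave = 132, k_tiles_per_output_tile = 16): with sk_tiles = 150, the
  // minimum-sized unit count is (150 * 16) / 8 = 300, so sk_units = min(132, 300) = 132;
  // with sk_tiles = 4, it is 64 / 8 = 8, so sk_units = 8.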
// Calculates the size of the workspace needed for holding reduction barriers
CUTLASS_HOST_DEVICE
static int
get_barrier_workspace_size(uint64_t num_tiles, uint32_t mma_warp_groups, uint32_t barrier_bits) {
auto workspace_bits = num_tiles * mma_warp_groups * barrier_bits;
return round_up_to_l2_alignment(bits_to_bytes(static_cast<int>(workspace_bits)));
}
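  // Illustrative example (assumed values): num_tiles = 128, mma_warp_groups = 2, and
  // barrier_bits = 32 give 128 * 2 * 32 = 8192 bits = 1024 bytes, already a multiple of the
  // 128-byte L2 alignment, so get_barrier_workspace_size returns 1024.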
// Calculates the size of the workspace needed for holding partial outputs from splits
CUTLASS_HOST_DEVICE
static int
get_reduction_workspace_size(uint64_t num_tiles, GemmCoord tile_shape, uint32_t accumulator_bits) {
auto output_tile_size = tile_shape.m() * tile_shape.n();
auto workspace_bits = accumulator_bits * output_tile_size * num_tiles;
return round_up_to_l2_alignment(bits_to_bytes(static_cast<int>(workspace_bits)));
}
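  // Illustrative example (assumed values): a 128x128 tile with 32-bit accumulators and
  // num_tiles = 2 requires 128 * 128 * 32 * 2 = 1,048,576 bits = 131,072 bytes, already a
  // multiple of the 128-byte L2 alignment, so get_reduction_workspace_size returns 131072.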
#if !defined(__CUDACC_RTC__)
static void
get_workspace_component_sizes(
dim3 problem_blocks,
uint32_t k_tiles_per_output_tile,
GemmCoord tile_shape,
GemmCoord cluster_shape,
int& barrier_workspace_size,
int& reduction_workspace_size,
KernelHardwareInfo const& hw_info,
int splits,
int max_swizzle,
RasterOrderOptions raster_order_option,
DecompositionMode decomposition_mode,
uint32_t mma_warp_groups,
uint32_t barrier_bits,
uint32_t accumulator_bits,
uint32_t epilogue_subtile = 1) {
auto log_swizzle_size = UnderlyingParams::get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle);
problem_blocks.x = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m());
problem_blocks.y = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n());
// Workspace is needed only for output tiles that will be split. Thus, we first determine the number
// of output tiles that will be split, and then calculate the workspace needed to cover these.
uint64_t output_tiles = problem_blocks.x * problem_blocks.y * problem_blocks.z;
if (decomposition_mode == DecompositionMode::DataParallel) {
barrier_workspace_size = 0;
reduction_workspace_size = 0;
}
else if (decomposition_mode == DecompositionMode::SplitK ||
(decomposition_mode == DecompositionMode::Heuristic && splits > 1)) {
// Basic split-K variant requires workspace for all output tiles
barrier_workspace_size = get_barrier_workspace_size(output_tiles, mma_warp_groups, barrier_bits);
reduction_workspace_size = get_reduction_workspace_size(output_tiles, tile_shape, accumulator_bits);
}
else {
KernelHardwareInfo new_hw_info;
new_hw_info.device_id = hw_info.device_id;
new_hw_info.sm_count = hw_info.sm_count;
if (new_hw_info.sm_count <= 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
new_hw_info.sm_count = KernelHardwareInfo::query_device_multiprocessor_count(new_hw_info.device_id);
}
dim3 grid = get_grid_shape(
problem_blocks,
cluster_shape,
new_hw_info,
max_swizzle,
raster_order_option
);
uint64_t ctas_per_wave = grid.x * grid.y;
uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, static_cast<uint32_t>(k_tiles_per_output_tile), decomposition_mode);
uint64_t sk_units = get_num_sk_units(cluster_shape, ctas_per_wave, sk_tiles, k_tiles_per_output_tile);
uint64_t dp_tiles = output_tiles - sk_tiles;
uint64_t reduction_tiles = sk_tiles;
if (should_perform_separate_reduction(epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave)) {
// In separate reduction, each peer writes to its own location in scratch space.
// Thus, for separate reduction, we need as many reduction tiles per output tile
// as there are the maximum number of peers that can collaborate on an output tile.
reduction_tiles *= max_peers_per_tile(sk_units, sk_tiles);
}
// Though separate reduction requires a larger reduction workspace, only one barrier
// is needed per output tile. Each peer will increment the barrier by one once the peer has
// written its accumulator to scratch space. The separate reduction unit will only begin
// performing the reduction when the barrier has reached the number of peers for the output tile.
barrier_workspace_size = get_barrier_workspace_size(sk_tiles, mma_warp_groups, barrier_bits);
reduction_workspace_size = get_reduction_workspace_size(reduction_tiles, tile_shape, accumulator_bits);
}
}
#endif // !defined(__CUDACC_RTC__)
// Returns whether the kernel is configured in a manner for which separate reduction should be used
CUTLASS_HOST_DEVICE
static bool
should_perform_separate_reduction(uint32_t epilogue_subtile, uint64_t sk_units, uint64_t sk_tiles, uint64_t dp_tiles, uint64_t ctas_per_wave) {
    // We perform separate reduction if we have fewer than one wave of output tiles
    // and each output tile is covered by at least two stream-K units. When sk_units is a
    // multiple of sk_tiles, the basic split-K path is chosen instead of separate reduction for now.
return (epilogue_subtile != 1) && (dp_tiles == 0) && (sk_units > 2u * sk_tiles) &&
(sk_units + sk_tiles * epilogue_subtile <= ctas_per_wave);
}
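  // Illustrative example (assumed values): epilogue_subtile = 2, sk_units = 20, sk_tiles = 4,
  // dp_tiles = 0, and ctas_per_wave = 132 satisfy all four conditions (20 > 8 and 20 + 8 <= 132),
  // so should_perform_separate_reduction returns true; with epilogue_subtile = 1 it never does.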
// Get the amount of scratch workspace needed for the kernel. This variant of the method should only be used when
// problem_shape and tile_shape contain modes of only rank 1.
static int
get_workspace_size(
BatchedGemmCoord problem_shape,
GemmCoord tile_shape,
GemmCoord cluster_shape,
KernelHardwareInfo const& hw_info,
int splits,
int max_swizzle,
RasterOrderOptions raster_order_option,
DecompositionMode decomposition_mode,
uint32_t mma_warp_groups,
uint32_t barrier_bits,
uint32_t element_accumulator_bits,
uint32_t epilogue_subtile) {
dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape);
uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k();
return get_workspace_size(
problem_blocks,
k_tiles_per_output_tile,
tile_shape,
cluster_shape,
hw_info,
splits,
max_swizzle,
raster_order_option,
decomposition_mode,
mma_warp_groups,
barrier_bits,
element_accumulator_bits,
epilogue_subtile
);
}
// Version of get_workspace_size that takes in as input the number of CTAs in the M and N dimensions.
// This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1,
// for which using CuTe algebra for calculating tile shapes is easiest.
static int
get_workspace_size(
dim3 problem_blocks,
uint32_t k_tiles_per_output_tile,
GemmCoord tile_shape,
GemmCoord cluster_shape,
KernelHardwareInfo const& hw_info,
int splits,
int max_swizzle,
RasterOrderOptions raster_order_option,
DecompositionMode decomposition_mode,
uint32_t mma_warp_groups,
uint32_t barrier_bits,
uint32_t element_accumulator_bits,
uint32_t epilogue_subtile = 1) {
int barrier_workspace_size = 0;
int reduction_workspace_size = 0;
#if !defined(__CUDACC_RTC__)
get_workspace_component_sizes(
problem_blocks,
k_tiles_per_output_tile,
tile_shape,
cluster_shape,
barrier_workspace_size,
reduction_workspace_size,
hw_info,
splits,
max_swizzle,
raster_order_option,
decomposition_mode,
mma_warp_groups,
barrier_bits,
element_accumulator_bits,
epilogue_subtile
);
#endif
return barrier_workspace_size + reduction_workspace_size;
}
// Initialize the workspace to be used for the kernel. This variant of the method should only be used when
// problem_shape and tile_shape contain modes of only rank 1.
static cutlass::Status
initialize_workspace(
void* workspace,
cudaStream_t stream,
BatchedGemmCoord problem_shape,
GemmCoord tile_shape,
GemmCoord cluster_shape,
KernelHardwareInfo const& hw_info,
int splits,
int max_swizzle,
RasterOrderOptions raster_order_option,
DecompositionMode decomposition_mode,
uint32_t mma_warp_groups,
uint32_t barrier_bits,
uint32_t element_accumulator_bits,
uint32_t epilogue_subtile) {
dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape);
uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k();
return initialize_workspace(
workspace,
stream,
problem_blocks,
k_tiles_per_output_tile,
tile_shape,
cluster_shape,
hw_info,
splits,
max_swizzle,
raster_order_option,
decomposition_mode,
mma_warp_groups,
barrier_bits,
element_accumulator_bits,
epilogue_subtile
);
}
// Version of initialize_workspace that takes in as input the number of CTAs in the M and N dimensions.
// This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1,
// for which using CuTe algebra for calculating tile shapes is easiest.
static cutlass::Status
initialize_workspace(
void* workspace,
cudaStream_t stream,
dim3 problem_blocks,
uint32_t k_tiles_per_output_tile,
GemmCoord tile_shape,
GemmCoord cluster_shape,
KernelHardwareInfo const& hw_info,
int splits,
int max_swizzle,
RasterOrderOptions raster_order_option,
DecompositionMode decomposition_mode,
uint32_t mma_warp_groups,
uint32_t barrier_bits,
uint32_t element_accumulator_bits,
uint32_t epilogue_subtile = 1) {
#if !defined(__CUDACC_RTC__)
int barrier_workspace_size = 0;
int reduction_workspace_size = 0;
get_workspace_component_sizes(
problem_blocks,
k_tiles_per_output_tile,
tile_shape,
cluster_shape,
barrier_workspace_size,
reduction_workspace_size,
hw_info,
splits,
max_swizzle,
raster_order_option,
decomposition_mode,
mma_warp_groups,
barrier_bits,
element_accumulator_bits,
epilogue_subtile
);
if (barrier_workspace_size > 0) {
if (workspace == nullptr) {
return Status::kErrorWorkspaceNull;
}
// Only the barrier workspace needs to be cleared for stream-K.
// Barrier workspace follows reduction workspace.
uint8_t* barrier_workspace = reinterpret_cast<uint8_t*>(workspace) + reduction_workspace_size;
return zero_workspace(static_cast<void*>(barrier_workspace), barrier_workspace_size, stream);
}
#endif // !defined(__CUDACC_RTC__)
return Status::kSuccess;
}
void
set_params_basic(
UnderlyingParams const& underlying_params,
uint32_t blocks_m,
uint32_t blocks_n,
uint32_t blocks_l,
uint32_t splits,
uint32_t k_tiles_per_output_tile,
void* reduction_workspace,
ReductionMode reduction_mode) {
divmod_cluster_shape_major_ = underlying_params.divmod_cluster_shape_major_;
divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_;
divmod_batch_ = FastDivmodU64(blocks_m * blocks_n);
divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile);
divmod_sk_groups_ = FastDivmodU64(1u);
auto cluster_size = underlying_params.divmod_cluster_shape_major_.divisor * underlying_params.divmod_cluster_shape_minor_.divisor;
divmod_clusters_mnl_ = FastDivmodU64((blocks_m * blocks_n * blocks_l) / cluster_size);
splits_ = splits;
divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_;
log_swizzle_size_ = underlying_params.log_swizzle_size_;
units_per_problem_ = blocks_m * blocks_n * blocks_l;
raster_order_ = underlying_params.raster_order_;
big_units_ = k_tiles_per_output_tile % splits;
reduction_workspace_ = reduction_workspace;
reduction_mode_ = reduction_mode;
k_tiles_per_sk_unit_ = k_tiles_per_output_tile / splits;
// No stream-K work is performed for "basic" data-parallel and split-K decompositions
sk_tiles_ = 0;
sk_units_ = 0;
divmod_sk_units_per_group_ = FastDivmodU64(1u);
separate_reduction_units_ = 0;
}
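  // Illustrative example (assumed values): a split-K call to set_params_basic with splits = 4
  // and k_tiles_per_output_tile = 10 sets k_tiles_per_sk_unit_ = 2 and big_units_ = 2, so two
  // of the four splits cover 3 K tiles and the other two cover 2 (3 + 3 + 2 + 2 = 10).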
private:
// Round up number of bytes to the nearest multiple of L2 cache line alignment
CUTLASS_HOST_DEVICE
static int
round_up_to_l2_alignment(int bytes) {
constexpr static uint32_t L2CacheLineSizeBytes = 128;
return (bytes + L2CacheLineSizeBytes - 1) / L2CacheLineSizeBytes * L2CacheLineSizeBytes;
}
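  // Illustrative example: round_up_to_l2_alignment(800) returns 896, the next multiple of 128.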
};
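// A minimal host-side usage sketch of the stream-K scheduler parameters (illustrative only;
// in practice these entry points are driven by a kernel's tile scheduler and device adapter
// rather than called directly). The problem, tile, and cluster shapes, warp-group count, and
// bit widths below are assumptions chosen for illustration.
//
//   using StreamKParams = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams;
//
//   cutlass::gemm::BatchedGemmCoord problem(4096, 4096, 4096, /*batch=*/1);
//   cutlass::gemm::GemmCoord tile(128, 128, 64);
//   cutlass::gemm::GemmCoord cluster(2, 1, 1);
//
//   cutlass::KernelHardwareInfo hw_info;
//   hw_info.device_id = 0;
//   hw_info.sm_count  = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
//
//   // 1. Query the scratch workspace size and allocate it on the device.
//   int workspace_bytes = StreamKParams::get_workspace_size(
//       problem, tile, cluster, hw_info,
//       /*splits=*/1, /*max_swizzle=*/1,
//       StreamKParams::RasterOrderOptions::Heuristic,
//       StreamKParams::DecompositionMode::Heuristic,
//       /*mma_warp_groups=*/2, /*barrier_bits=*/32,
//       /*element_accumulator_bits=*/32, /*epilogue_subtile=*/1);
//   void* workspace = nullptr;   // e.g., cudaMalloc(&workspace, workspace_bytes);
//
//   // 2. Zero the barrier portion of the workspace on the chosen stream.
//   StreamKParams::initialize_workspace(
//       workspace, /*stream=*/nullptr, problem, tile, cluster, hw_info,
//       /*splits=*/1, /*max_swizzle=*/1,
//       StreamKParams::RasterOrderOptions::Heuristic,
//       StreamKParams::DecompositionMode::Heuristic,
//       /*mma_warp_groups=*/2, /*barrier_bits=*/32,
//       /*element_accumulator_bits=*/32, /*epilogue_subtile=*/1);
//
//   // 3. Populate the parameters that the persistent kernel will consume.
//   StreamKParams params;
//   params.initialize(
//       problem, tile, cluster, hw_info,
//       /*splits=*/1, /*max_swizzle=*/1,
//       StreamKParams::RasterOrderOptions::Heuristic,
//       StreamKParams::ReductionMode::Deterministic,
//       StreamKParams::DecompositionMode::Heuristic,
//       workspace, /*epilogue_subtile=*/1);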
////////////////////////////////////////////////////////////////////////////////
// Parameters for SM90 persistent group scheduler (only used for Grouped Gemms)
template<class ProblemShape>
struct PersistentTileSchedulerSm90GroupParams {
enum class RasterOrder {
AlongM,
AlongN
};
enum class RasterOrderOptions {
Heuristic,
AlongM,
AlongN
};
FastDivmodU64Pow2 divmod_cluster_shape_major_{};
FastDivmodU64Pow2 divmod_cluster_shape_minor_{};
FastDivmodU64 divmod_cta_shape_m_{};
FastDivmodU64 divmod_cta_shape_n_{};
uint64_t blocks_across_problem_ = 0;
bool pre_processed_problem_shapes = true;
int32_t log_swizzle_size_ = 0;
RasterOrder raster_order_ = RasterOrder::AlongN;
int32_t groups_ = 0;
ProblemShape* problem_shapes_ = nullptr;
GemmCoord cta_shape_;
GemmCoord cluster_shape_;
// Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions.
// This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1,
// for which using CuTe algebra for calculating tile shapes is easiest.
void
initialize(
dim3 problem_blocks,
int32_t groups,
ProblemShape* problem_shapes,
ProblemShape const* host_problem_shapes,
GemmCoord cta_shape,
GemmCoord cluster_shape,
KernelHardwareInfo const& hw_info,
int max_swizzle_size,
RasterOrderOptions raster_order_option
) {
CUTLASS_UNUSED(hw_info);
// Round up to nearest multiple of swizzle_size along each mode
auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size);
auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m());
auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n());
RasterOrder raster_order = get_rasterization_order(
problem_blocks_m,
problem_blocks_n,
raster_order_option
);
//
// Set members
//
groups_ = groups;
problem_shapes_ = problem_shapes;
cta_shape_ = cta_shape;
cluster_shape_ = cluster_shape;
blocks_across_problem_ = problem_blocks.x * problem_blocks.y * problem_blocks.z;
    pre_processed_problem_shapes = (host_problem_shapes != nullptr);
log_swizzle_size_ = log_swizzle_size;
raster_order_ = raster_order;
if (raster_order == RasterOrder::AlongN) {
divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.n());
divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.m());
}
else {
divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.m());
divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.n());
}
divmod_cta_shape_m_ = FastDivmodU64(cta_shape_.m());
divmod_cta_shape_n_ = FastDivmodU64(cta_shape_.n());
}
// Version of get_tiled_cta_shape_mnl that takes in as input the number of CTAs in the M and N dimensions.
// This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1,
// for which using CuTe algebra for calculating tile shapes is easiest.
CUTLASS_HOST_DEVICE
static dim3
get_tiled_cta_shape_mnl(GemmCoord cluster_shape, uint32_t cta_m, uint32_t cta_n) {
// Round up to nearest multiple of cluster dim along each mode
auto problem_blocks_m = ((cta_m + cluster_shape.m() - 1) / cluster_shape.m()) * cluster_shape.m();
auto problem_blocks_n = ((cta_n + cluster_shape.n() - 1) / cluster_shape.n()) * cluster_shape.n();
return {
      static_cast<uint32_t>(problem_blocks_m),
      static_cast<uint32_t>(problem_blocks_n),
static_cast<uint32_t>(1) // Only a single batch per group is currently supported
};
}
// Version of get_grid_shape that takes in as input the number of CTAs in the M and N and L dimensions.
// This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1,
// for which using CuTe algebra for calculating tile shapes is easiest.
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(
dim3 problem_blocks,
GemmCoord cluster_shape,
KernelHardwareInfo hw_info,
int max_swizzle_size,
RasterOrderOptions raster_order_option,
bool truncate_by_problem_size=true) {
int const sm_count = hw_info.sm_count;
// Round up to nearest multiple of swizzle_size along each mode
auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size);
auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m());
auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n());
int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks.z;
RasterOrder raster_order = get_rasterization_order(
problem_blocks_m,
problem_blocks_n,
raster_order_option
);
dim3 launch_grid;
if (raster_order == RasterOrder::AlongN) {
launch_grid = dim3(cluster_shape.m(), 1, 1);
}
else {
launch_grid = dim3(1, cluster_shape.n(), 1);
}
auto possibly_truncate = [&](int x, int y) {
if (truncate_by_problem_size) {
return platform::min(x, y);
}
else {
return x;
}
};
    // The else path is generic; however, we can avoid some divs if we know the cluster size is 1
auto cluster_size = cluster_shape.m() * cluster_shape.n();
if (cluster_size == 1) {
if (raster_order == RasterOrder::AlongN) {
launch_grid.y = possibly_truncate(sm_count, problem_blocks_total);
}
else {
launch_grid.x = possibly_truncate(sm_count, problem_blocks_total);
}
}
else {
// Optimal grid size calculation is based on
// GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU
// Hence, maximum SMs per GPC = 18
constexpr int max_sm_per_gpc = 18;
// Provided SM count could possibly be less than the assumed maximum SMs per GPC
auto cluster_size = cluster_shape.m() * cluster_shape.n();
int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc;
int const max_cta_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % cluster_size);
int cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc;
// The calculation below allows for larger grid size launch for different GPUs.
int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc;
int const max_cta_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % cluster_size);
cta_per_device += max_cta_occupancy_per_residual_gpc;
cta_per_device = sm_count < cta_per_device ? sm_count : cta_per_device;
if (raster_order == RasterOrder::AlongN) {
launch_grid.y = possibly_truncate(
cta_per_device / cluster_shape.m(),
problem_blocks_total / cluster_shape.m());
}
else {
launch_grid.x = possibly_truncate(
cta_per_device / cluster_shape.n(),
problem_blocks_total / cluster_shape.n());
}
}
return launch_grid;
}
CUTLASS_HOST_DEVICE
static int32_t
get_log_swizzle_size(int problem_ctas_m, int problem_ctas_n, int max_swizzle_size) {
int min_cta_dim = platform::min(problem_ctas_m, problem_ctas_n);
if (max_swizzle_size >= 8 && min_cta_dim >= 6) {
return 3;
}
else if (max_swizzle_size >= 4 && min_cta_dim >= 3) {
return 2;
}
else if (max_swizzle_size >= 2 && min_cta_dim >= 2) {
return 1;
}
else {
return 0;
}
}
CUTLASS_HOST_DEVICE
static RasterOrder
get_rasterization_order(
uint32_t tiles_m,
uint32_t tiles_n,
RasterOrderOptions raster_order_option
) {
if (raster_order_option == RasterOrderOptions::Heuristic) {
if (tiles_n > tiles_m) {
return RasterOrder::AlongM;
}
else {
return RasterOrder::AlongN;
}
}
else {
switch (raster_order_option) {
case RasterOrderOptions::AlongN:
return RasterOrder::AlongN;
break;
default:
return RasterOrder::AlongM;
}
}
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace detail
} // namespace kernel
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////