/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/fast_math.h"
#include "cute/layout.hpp"

namespace cutlass::gemm::kernel::detail {

///////////////////////////////////////////////////////////////////////////////

// Persistent Thread Block (TB) scheduler
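//
// The physical grid is sized to fill the device once; each CTA then walks the
// logical grid of output tiles by repeatedly striding its linear work index by
// the total number of CTAs launched, until every tile has been covered.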
class PersistentTileSchedulerSm90 {
  //
  // Data members
  //

private:
  uint32_t blocks_per_problem_;
  uint32_t current_work_linear_idx_;
  uint32_t grid_blocks_total_;

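  // Fast div/mod helpers used by get_current_work() to unpack a linear work
  // index into (M, N, L) block coordinates without runtime integer division.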
  FastDivmod divmod_batch_;
  FastDivmod divmod_grid_y_;
  FastDivmod divmod_blk_m_;

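  // Block coordinates for one unit of work, plus a validity flag:
  // is_valid_tile reads false once a CTA's linear index has run past
  // blocks_per_problem_, which lets the persistent loop stop.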
  struct WorkTileInfo {
    int32_t M_idx = 0;
    int32_t N_idx = 0;
    int32_t L_idx = 0;
    uint32_t is_valid_tile = false;
  };

  //
  // Methods
  //

public:

  template<class ProblemShapeMNKL, class TileShape, class ClusterShape>
  CUTLASS_DEVICE
  PersistentTileSchedulerSm90(ProblemShapeMNKL problem_shape_mnkl, TileShape tile_shape, ClusterShape cluster_shape) {
    // We only need the tile and cluster shape during scheduler setup, so let
    // function template argument deduction do the magic
    static_assert(is_static<TileShape>::value);
    static_assert(is_static<ClusterShape>::value);

    // Round up to nearest multiple of cluster dim along each mode
    auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = get_tiled_blk_shape_mnl(
        problem_shape_mnkl, tile_shape, cluster_shape);

    blocks_per_problem_ = problem_blocks_m * problem_blocks_n * problem_blocks_l;
    current_work_linear_idx_ = (int(blockIdx.x) * int(gridDim.y)) + int(blockIdx.y);
    grid_blocks_total_ = int(gridDim.x) * int(gridDim.y);

    // Pre-compute our fast div/mods for rasterization so we don't have to pay for DIVs
    divmod_batch_ = FastDivmod(problem_blocks_m * problem_blocks_n);
    divmod_grid_y_ = FastDivmod(size<1>(cluster_shape));
    divmod_blk_m_ = FastDivmod(problem_blocks_m);
  }
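
  // Decomposes this CTA's current linear work index into (M, N, L) block
  // coordinates. With B = problem_blocks_m * problem_blocks_n, and assuming
  // the grid is launched with gridDim.y == size<1>(cluster_shape) (which is
  // what divmod_grid_y_ encodes), the mapping below works out to:
  //   L_idx = idx / B
  //   t     = (idx % B) / gridDim.y
  //   M_idx = t % problem_blocks_m
  //   N_idx = (t / problem_blocks_m) * gridDim.y + blockIdx.y
  // For example, with problem_blocks = (4, 4, 1) and gridDim = (8, 2), the CTA
  // at blockIdx (3, 1) starts at idx 7 and computes tile (M=3, N=1, L=0).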
  CUTLASS_DEVICE
  WorkTileInfo
  get_current_work() const {
    // Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices
    int work_idx_l, remainder;
    divmod_batch_(work_idx_l, remainder, current_work_linear_idx_);

    int blk_per_grid_dim, dontcare;
    divmod_grid_y_(blk_per_grid_dim, dontcare, remainder);

    int block_idx_m, block_idx_n;
    divmod_blk_m_(block_idx_n, block_idx_m, blk_per_grid_dim);
    int work_idx_m = block_idx_m;
    int work_idx_n = (block_idx_n * gridDim.y) + blockIdx.y;

    return {work_idx_m, work_idx_n, work_idx_l, current_work_linear_idx_ < blocks_per_problem_};
  }
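
  // Advance to this CTA's next unit of work, advance_count grids ahead in the
  // linear order. Every CTA strides by the same whole-grid count, so different
  // CTAs' index sequences stay disjoint while jointly covering the problem.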
  CUTLASS_DEVICE
  void
  advance_to_next_work(uint32_t advance_count = 1) {
    current_work_linear_idx_ += grid_blocks_total_ * advance_count;
  }

  // Given the inputs, computes the total number of output blocks this problem will compute over
  // Note that this is only the logical size of our grid, not the physical grid we will actually launch.
  template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
  CUTLASS_HOST_DEVICE constexpr static
  dim3
  get_tiled_blk_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape blk_shape, ClusterShape cluster_shape) {
    // Across M and N is our Cluster tile, so we must round up the blocks to the nearest whole number of Cluster tiles
    auto blk_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(blk_shape)));
    auto blk_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(blk_shape)));

    // Round up to nearest multiple of cluster dim along each mode
    int problem_blocks_m = round_up(blk_m, cute::size<0>(cluster_shape));
    int problem_blocks_n = round_up(blk_n, cute::size<1>(cluster_shape));

    // Cluster tile does not span the batch mode, so no extra rounding up required for it
    int problem_blocks_l = int(cute::size<3>(problem_shape_mnkl));
    return {uint32_t(problem_blocks_m), uint32_t(problem_blocks_n), uint32_t(problem_blocks_l)};
  }
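
  // For example, a 300 x 200 x K problem with L = 3 batches, a 128 x 128 CTA
  // tile, and a 2 x 1 cluster yields ceil(300/128) = 3 M-blocks rounded up to
  // 4, ceil(200/128) = 2 N-blocks (already a multiple of 1), and 3 L-blocks,
  // i.e. a logical grid of (4, 2, 3).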
};
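
// A minimal sketch of the persistent loop a kernel would run with this
// scheduler (illustrative only; process_tile and the *_shape variables are
// hypothetical stand-ins, not part of this header):
//
//   PersistentTileSchedulerSm90 scheduler(problem_shape_mnkl, tile_shape, cluster_shape);
//   auto work = scheduler.get_current_work();
//   while (work.is_valid_tile) {
//     process_tile(work.M_idx, work.N_idx, work.L_idx);
//     scheduler.advance_to_next_work();
//     work = scheduler.get_current_work();
//   }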

} // namespace cutlass::gemm::kernel::detail