Updates for 3.2 release (#1065)

This commit is contained in:
ANIKET SHIVAM 2023-08-25 17:05:46 -10:00 committed by GitHub
parent 27de343535
commit a88c41cf8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 904 additions and 257 deletions

View File

@ -2,10 +2,14 @@
## 2023
- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, and Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
- ["FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"](https://arxiv.org/abs/2307.08691). Tri Dao. _Technical Report_, July 2023.
- ["ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs"](https://arxiv.org/abs/2210.03052). Yujia Zhai, Chengquan Jiang, Leyuan Wang, Xiaoying Jia, Shang Zhang, Zizhong Chen, Xin Liu, Yibo Zhu. _Proceedings of the 37th IEEE International Parallel & Distributed Processing Symposium (Best Paper)_, May 2023.
- ["A Framework for Fine-Grained Synchronization of Dependent GPU Kernels"](https://arxiv.org/abs/2305.13450). Abhinav Jangda, Saeed Maleki, Maryam Mehri Dehnavi, Madan Musuvathi, Olli Saarikivi. _Computing Research Repository_, May 2023.
- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
- ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023.
## 2022

View File

@ -71,7 +71,7 @@ struct Options {
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "52_fp8_hopper_warp_specialized_gemm\n\n"
out << "54_fp8_hopper_warp_specialized_gemm\n\n"
<< " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
@ -93,7 +93,7 @@ struct Options {
out
<< "\n\nExamples:\n\n"
<< "$ " << "52_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
<< "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}

View File

@ -54,6 +54,13 @@ struct SyncthreadsSync {
}
};
struct SyncwarpSync {
CUTLASS_DEVICE
static void sync() {
__syncwarp();
}
};
template <
int ThreadCount,
int BarrierId
@ -311,6 +318,60 @@ private:
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/** Structure for synchronizing via contiguous barriers (e.g., __syncwarp, __syncthreads)
* via an API that mirrors that of NamedBarrierManager
*
* @param Synchronizer Synchronization helper exposing a `sync()` method to perform synchronization
**/
template <
class Synchronizer,
uint32_t ThreadCount_
>
struct SyncManager {
// Number of threads participating in the barrier
static constexpr uint32_t ThreadCount = ThreadCount_;
using BarrierSync = cutlass::GenericBarrier<Synchronizer>;
// Underlying type used by all barriers for synchronization.
using T = typename BarrierSync::T;
CUTLASS_DEVICE
static
void wait_lt(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int count) {
BarrierSync::wait_lt_helper(lock_ptr, thread_idx, flag_idx, count);
}
CUTLASS_DEVICE
static void
wait_eq(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val);
}
CUTLASS_DEVICE
static void
wait_eq_reset(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val);
}
CUTLASS_DEVICE
static void
arrive_inc(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) {
BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val);
}
CUTLASS_DEVICE
static void
arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) {
BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -67,7 +67,7 @@ sm90_get_tma_dispatch_policy() {
constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{}));
constexpr int FragmentSize = size(EpilogueTileMN{}) / (detail::sm90_is_cooperative_v<Schedule> ? 256 : 128);
constexpr int ReuseSmemC = sizeof_bits_v<ElementC> == sizeof_bits_v<ElementD>;
constexpr int ReuseSmemC = (sizeof_bits_v<ElementC> == sizeof_bits_v<ElementD>) && (sizeof_bits_v<ElementD> > 8);
constexpr int StagesD = 2;
constexpr int StagesC = ReuseSmemC ? cute::max(EpiTiles, StagesD + 1) : EpiTiles;
@ -98,7 +98,7 @@ sm90_get_epilogue_smem_swizzle_layout_atom() {
}
// Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one.
template <class Element, class EpilogueTileType, class Schedule>
template <class ElementD, class EpilogueTileType, class Schedule>
constexpr auto
sm90_compute_tile_shape_or_override() {
if constexpr (cute::is_same_v<EpilogueTileType, EpilogueTileAuto>) {
@ -107,7 +107,12 @@ sm90_compute_tile_shape_or_override() {
return Shape<_128,_32>{};
}
else if constexpr (detail::sm90_is_warp_specialized_v<Schedule>) {
return Shape<_64,_32>{};
if constexpr (sizeof_bits_v<ElementD> == 8) {
return Shape<_64,_64>{};
}
else {
return Shape<_64,_32>{};
}
}
else {
static_assert(cutlass::detail::dependent_false<Schedule>, "Unsupported schedule.");

View File

@ -34,6 +34,7 @@
#include "cutlass/kernel_hardware_info.hpp"
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/arch/cluster_sm90.hpp"
namespace cutlass::gemm::kernel::detail {
@ -205,18 +206,14 @@ public:
uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0;
divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim);
// MSVC requires protecting use of CUDA-specific nonstandard syntax,
// like blockIdx and gridDim, with __CUDA_ARCH__.
#if defined(__CUDA_ARCH__)
auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
if (raster_order == RasterOrder::AlongN) {
cluster_minor_offset = blockIdx.x;
cluster_minor_offset = cta_m_in_cluster;
}
else {
cluster_minor_offset = blockIdx.y;
cluster_minor_offset = cta_n_in_cluster;
}
#else
CUTLASS_ASSERT(false && "This line should never be reached");
#endif
uint64_t cluster_idx_minor, cluster_idx_major;

View File

@ -141,7 +141,7 @@ public:
uint32_t splits_ = 1;
// Number of tiled k iterations required to compute a single output tile.
uint32_t k_iter_per_tile_ = 0;
uint32_t k_tiles_per_output_tile_ = 0;
// Number of stream-K or split-K work units that compute an extra k iteration.
// This is done to handle residuals in dividing up the k iteration space.
@ -160,7 +160,7 @@ public:
// Number of tiled k iterations computed by each stream-K work unit. This
// can potentially cover more than one output tile.
uint32_t k_iter_per_sk_unit_ = 0;
uint32_t k_tiles_per_sk_unit_ = 0;
};
// Sink scheduler params as a member
@ -189,9 +189,9 @@ public:
uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l;
// Number of k iterations each tile computes (this is just the number of k iterations
// in the problem's K dimension)
uint32_t k_iter_per_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(tile_shape) - 1) / cute::size<2>(tile_shape);
// Number of k tile iterations in each output tile
uint32_t k_tiles_per_output_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(tile_shape) - 1) /
cute::size<2>(tile_shape);
UnderlyingArguments underlying_args;
underlying_args.max_swizzle_size = 1;
@ -216,11 +216,11 @@ public:
// splits is almost certainly nonnegative here (e.g., hw_info.sm_count,
// despite being an int, is a count), so it can safely be converted to unsigned
// in the comparison to avoid a signed-unsigned comparison warning-as-error.
splits = static_cast<decltype(k_iter_per_tile)>(splits) > k_iter_per_tile ? k_iter_per_tile : splits;
splits = static_cast<decltype(k_tiles_per_output_tile)>(splits) > k_tiles_per_output_tile ? k_tiles_per_output_tile : splits;
return get_params_basic(
underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape,
splits, k_iter_per_tile, reduction_workspace);
splits, k_tiles_per_output_tile, reduction_workspace);
}
// Calculate the maximum number of blocks from clusters of shape cluster_shape that we
@ -229,7 +229,7 @@ public:
uint64_t ctas_per_wave = grid.x * grid.y;
// The number of output tiles to be computed in stream-K and data-parallel fashion, respectively.
uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave);
uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, k_tiles_per_output_tile);
uint64_t dp_tiles = output_tiles - sk_tiles;
// Calculate the number of work units covering the data-parallel and stream-K tiles.
@ -243,7 +243,7 @@ public:
uint64_t dp_units = dp_tiles;
// Number of k iterations computed by the stream-K units as a whole
uint64_t k_iter_sk_total = k_iter_per_tile * sk_tiles;
uint64_t k_tiles_sk_total = k_tiles_per_output_tile * sk_tiles;
// If there are stream-K tiles to compute and a sufficiently large number of k iterations
// across them, they will be covered by a single wave of persistent threadblocks. Thus, there
@ -255,7 +255,7 @@ public:
// Calculate the number of stream-K units that would be needed if each stream-K unit
// computed the minimum allowable k iterations. Truncate this to be in units of clusters.
uint64_t min_sized_sk_units = (k_iter_sk_total / min_iters_per_sk_unit_);
uint64_t min_sized_sk_units = (k_tiles_sk_total / min_iters_per_sk_unit_);
min_sized_sk_units = (min_sized_sk_units / cute::size(cluster_shape)) * cute::size(cluster_shape);
uint64_t sk_units = min(ctas_per_wave, min_sized_sk_units);
@ -264,7 +264,7 @@ public:
// Short circuit to basic data-parallel decomposition
return get_params_basic(
underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape,
1, k_iter_per_tile, reduction_workspace);
1, k_tiles_per_output_tile, reduction_workspace);
}
// If the number of stream-K units is a multiple of the number of stream-K tiles, then
@ -274,24 +274,24 @@ public:
uint32_t sk_splits = static_cast<uint32_t>(sk_units / sk_tiles);
return get_params_basic(
underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape,
sk_splits, k_iter_per_tile, reduction_workspace);
sk_splits, k_tiles_per_output_tile, reduction_workspace);
}
// Number of k iterations computed per stream-K units
uint64_t k_iter_per_sk_unit = k_iter_sk_total / sk_units;
uint64_t k_tiles_per_sk_unit = k_tiles_sk_total / sk_units;
// Number of stream-K units that need to compute extra iterations in order to cover
// the residual k iterations. This assumes that each such unit computes one additional
// iteration.
uint64_t sk_big_units = k_iter_sk_total - (k_iter_per_sk_unit * sk_units);
uint64_t sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units);
// The division below is guaranteed to be exact because sk_big_units is guaranteed
// to be a multiple of cluster_size (cute::size(cluster_shape)). This is useful because
// it allows us to use a block's linearized cluster ID to determine whether it is
// a big block. The reasoning behind this guarnatee is explained as follows:
// sk_big_units = k_iter_sk_total - (k_iter_per_sk_unit * sk_units);
// sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units);
//
// - k_iter_sk_total is a multiple of cluster_size because it is the product
// - k_tiles_sk_total is a multiple of cluster_size because it is the product
// of number of tail tiles and the number of k iterations per tile. Because
// both the number of output tiles and number of available SMs are rounded
// to be multiples of cluster shape, the number of tail tiles
@ -313,12 +313,12 @@ public:
underlying_params.raster_order_,
cluster_shape,
1, // Static k-splitting factor. Unused for stream-K.
k_iter_per_tile,
k_tiles_per_output_tile,
static_cast<uint32_t>(sk_big_units_per_cluster),
reduction_workspace,
sk_tiles,
static_cast<uint32_t>(sk_units),
static_cast<uint32_t>(k_iter_per_sk_unit)
static_cast<uint32_t>(k_tiles_per_sk_unit)
};
}
@ -338,105 +338,32 @@ public:
CUTLASS_DEVICE
WorkTileInfo
get_current_work() const {
return get_current_work_for_linear_idx(current_work_linear_idx_);
return get_current_work_for_linear_idx(current_work_linear_idx_, scheduler_params);
}
CUTLASS_DEVICE
WorkTileInfo
get_current_work_for_linear_idx(uint64_t linear_idx) const {
if (linear_idx >= scheduler_params.units_per_problem_) {
static WorkTileInfo
get_current_work_for_linear_idx(uint64_t linear_idx, Params const& params) {
if (linear_idx >= params.units_per_problem_) {
// Invalid work. Return an empty result.
return {0, 0, 0, 0, false, 0};
}
// Determine whether this work unit is a data-parallel or stream-K work unit
bool is_stream_k_unit = linear_idx < scheduler_params.sk_units_;
bool is_stream_k_unit = linear_idx < params.sk_units_;
bool is_split_k = scheduler_params.splits_ > 1;
bool is_split_k = params.splits_ > 1;
// Bypass the stream-K scheduling logic for basic data-parallel or split-K work
if (is_split_k || !is_stream_k_unit) {
// The linearized ID space is in terms of work units, rather than tiles. However,
// to compute the correct block offset for a data-parallel tile, we must convert
// the current ID to the data-parallel tile it corresponds to. Each data-parallel
// unit maps to a single data-parallel tile, but each stream-K unit can map to more
// than one tile. Thus, we must offset the work-unit ID among the data-parallel units
// by the total number of output tiles that will be computed by stream-K units.
//
// The logic below also works for the split-K case, in which sk_units_ and sk_tiles_
// are each 0.
uint64_t linear_work_idx = linear_idx - scheduler_params.sk_units_ + scheduler_params.sk_tiles_;
// Map worker's linear index into the CTA-tiled problem shape to the corresponding MNL indices
uint64_t work_idx_l, remainder;
scheduler_params.divmod_batch_(work_idx_l, remainder, linear_work_idx);
uint64_t work_idx_k = 0;
if (is_split_k) {
scheduler_params.divmod_k_(work_idx_k, remainder, remainder);
}
uint64_t cta_per_grid_dim, dontcare;
scheduler_params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n(
cta_per_grid_dim,
scheduler_params.divmod_cluster_shape_major_,
scheduler_params.divmod_cluster_shape_minor_,
scheduler_params.divmod_cluster_blk_major_,
scheduler_params.log_swizzle_size_,
scheduler_params.raster_order_);
bool is_final_split = (work_idx_k == scheduler_params.splits_ - 1);
uint32_t k_iter = scheduler_params.k_iter_per_tile_;
if (is_split_k) {
// Determine the number of iterations and starting iteration of this split.
// Doing so requires accounting for residual iterations, which are handled
// by the first big_units_ splits (with big_units_ = tiles % sm_count).
// Offsets for "normal" units. No additional k iterations are performed,
// and big_units_ "big" units preceded us, each of which performed one
// additional iteration. Thus, we must increase our split starting offset
// by big_units_.
int additional_k_iter = 0;
int split_start_offset = scheduler_params.big_units_;
if (work_idx_k < scheduler_params.big_units_) {
// Offsets for "big" units. One additional k iteration is performed,
// and each split preceding us was a big unit, so we must increase
// our split starting offset by our split ID (work_idx_k).
additional_k_iter = 1;
split_start_offset = work_idx_k;
}
// Set up k iteration count and split starting iteration assuming the
// iteration space is evenly split.
k_iter /= scheduler_params.splits_;
work_idx_k *= k_iter;
// Apply any fixup needed to handle residuals
work_idx_k += split_start_offset;
k_iter += additional_k_iter;
}
return {
work_idx_m,
work_idx_n,
static_cast<int32_t>(work_idx_k),
static_cast<int32_t>(work_idx_l),
true,
scheduler_params.k_iter_per_tile_,
k_iter,
k_iter, // remaining iterations
is_final_split
};
// Bypass the stream-K scheduling logic for basic data-parallel or split-K work
return set_non_stream_k_work(linear_idx, params, is_split_k);
}
else {
// This is a stream-K work unit
WorkTileInfo work_tile_info;
set_stream_k_work(params, linear_idx, work_tile_info, /*new_unit = */ true);
return work_tile_info;
}
// This is a stream-K work unit
WorkTileInfo work_tile_info;
set_stream_k_work(linear_idx, work_tile_info, /*new_unit = */ true);
return work_tile_info;
}
// Returns whether the current work_tile_info passed in should continue to be used. This
@ -446,13 +373,24 @@ public:
CUTLASS_DEVICE
bool
continue_current_work(WorkTileInfo& work_tile_info) const {
return continue_current_work_for_linear_idx(
current_work_linear_idx_, work_tile_info, scheduler_params);
}
CUTLASS_DEVICE static
bool
continue_current_work_for_linear_idx(
uint64_t linear_idx,
WorkTileInfo& work_tile_info,
Params const& params) {
work_tile_info.k_tile_remaining -= work_tile_info.k_tile_count;
if (work_tile_info.k_tile_remaining == 0) {
return false;
}
set_stream_k_work(current_work_linear_idx_, work_tile_info, /* new_unit = */ false);
set_stream_k_work(params, linear_idx, work_tile_info, /* new_unit = */ false);
return true;
}
@ -495,6 +433,14 @@ public:
/*truncate_by_problem_size=*/false);
}
// Returns whether fixup is needed for `work_tile_info`.
CUTLASS_HOST_DEVICE
static bool
requires_fixup(Params const& params, WorkTileInfo const& work_tile_info) {
// Fixup is not needed for data-parallel tiles
return work_tile_info.k_tile_count != params.k_tiles_per_output_tile_;
}
// Performs the reduction across splits for a given output tile.
template <class FrgTensorC>
CUTLASS_DEVICE
@ -505,13 +451,25 @@ public:
FrgTensorC& accumulators,
uint32_t num_barriers,
uint32_t barrier_idx) {
using BarrierManager = NamedBarrierManager<NumThreadsPerWarpGroup, 2>;
return fixup_helper<FrgTensorC, BarrierManager>(
params, work_tile_info, accumulators, num_barriers, barrier_idx);
}
// Helper for performing the reduction across splits for a given output tile.
template <class FrgTensorC, class BarrierManager>
CUTLASS_DEVICE
static void
fixup_helper(
Params const& params,
WorkTileInfo const& work_tile_info,
FrgTensorC& accumulators,
uint32_t num_barriers,
uint32_t barrier_idx) {
using ElementAccumulator = typename FrgTensorC::value_type;
using BarrierManager = NamedBarrierManager<NumThreadsPerWarpGroup, 2>;
if (work_tile_info.k_tile_count == params.k_iter_per_tile_) {
// Fixup is not needed for data-parallel tiles
if (!requires_fixup(params, work_tile_info)) {
return;
}
@ -619,21 +577,23 @@ public:
}
}
else {
auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
uint64_t cta_per_grid_dim;
uint64_t cluster_dim_idx;
if (params.raster_order_ == RasterOrder::AlongN) {
uint64_t block_idx_m = (work_tile_info.M_idx - blockIdx.x) / gridDim.x;
uint64_t block_idx_m = (work_tile_info.M_idx - cta_m_in_cluster) / cute::size<0>(params.cluster_shape_);
uint64_t block_idx_n = work_tile_info.N_idx;
cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor *
params.divmod_cluster_blk_major_.divisor * block_idx_m) + block_idx_n;
cluster_dim_idx = blockIdx.x;
cluster_dim_idx = cta_m_in_cluster;
}
else {
uint64_t block_idx_m = work_tile_info.M_idx;
uint64_t block_idx_n = (work_tile_info.N_idx - blockIdx.y) / gridDim.y;
uint64_t block_idx_n = (work_tile_info.N_idx - cta_n_in_cluster) / cute::size<1>(params.cluster_shape_);
cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor *
params.divmod_cluster_blk_major_.divisor * block_idx_n) + block_idx_m;
cluster_dim_idx = blockIdx.y;
cluster_dim_idx = cta_n_in_cluster;
}
uint64_t tile_in_batch = params.divmod_cluster_shape_minor_.divisor * cta_per_grid_dim;
@ -646,7 +606,7 @@ public:
get_workspace_size(
Arguments const& args,
ProblemShape problem_shape,
KernelHardwareInfo const& hw_info,
KernelHardwareInfo const& hw_info,
uint32_t mma_warp_groups) {
int barrier_workspace_size = 0;
@ -715,7 +675,7 @@ private:
// Construct a layout for the indexed tensor. The main purpose of this new layout is to
// override the k extent to support cases in which the split computes a number of iterations
// not equal to total_tile_k_iter / splits. A common example of this is in stream-K is when a
// not equal to total_k_tiles / splits. A common example of this is in stream-K is when a
// unit computes the final 20 of the total 32 k iterations of the output tile. In this case,
// set splits = 32 and the split index (K_idx) to 11. The zipped divide above results in each
// of the splits computing only one k iteration.
@ -728,12 +688,13 @@ private:
// Returns the number of stream-K tiles that will be computed amongst `output_tiles` total
// output tiles on a device with `ctas_per_wave` CTAs in each wave.
static uint32_t
get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave) {
get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave, uint32_t k_tiles_per_output_tile) {
uint32_t full_waves = static_cast<uint32_t>(output_tiles / ctas_per_wave);
uint32_t total_waves = static_cast<uint32_t>((output_tiles + ctas_per_wave - 1) / ctas_per_wave);
if (full_waves == total_waves) {
// No quantization. All tiles will be data-parallel tiles.
if (full_waves == total_waves || k_tiles_per_output_tile == 1) {
// All tiles will be data-parallel tiles if there is either no quantization
// or if there is no work to be split.
return 0;
}
@ -811,9 +772,12 @@ private:
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
uint32_t k_tiles_per_output_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(TileShape{}) - 1) /
cute::size<2>(TileShape{});
dim3 grid = get_grid_shape(problem_shape_mnkl, TileShape{}, cluster_shape, {0, sm_count}, args);
uint64_t ctas_per_wave = grid.x * grid.y;
uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave);
uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, k_tiles_per_output_tile);
barrier_workspace_size = get_barrier_workspace_size(sk_tiles, mma_warp_groups);
reduction_workspace_size = get_reduction_workspace_size<ElementAccumulator>(sk_tiles);
@ -829,10 +793,10 @@ private:
uint32_t blocks_l,
ClusterShape cluster_shape,
uint32_t splits,
uint32_t k_iter_per_tile,
uint32_t k_tiles_per_output_tile,
void* reduction_workspace) {
uint32_t big_units = k_iter_per_tile % splits;
uint32_t big_units = k_tiles_per_output_tile % splits;
return {
underlying_params.divmod_cluster_shape_major_,
@ -845,7 +809,7 @@ private:
underlying_params.raster_order_,
cluster_shape,
splits,
k_iter_per_tile,
k_tiles_per_output_tile,
big_units,
reduction_workspace
};
@ -855,8 +819,12 @@ private:
// is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining
// iterations) is used to find the next tile in the current work unit.
CUTLASS_DEVICE
void
set_stream_k_work(uint64_t linear_idx, WorkTileInfo& work_tile_info, bool new_unit) const {
static void
set_stream_k_work(
Params const& params,
uint64_t linear_idx,
WorkTileInfo& work_tile_info,
bool new_unit) {
// In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K
// threadblock individually. For the most part, the set of K iterations corresponding to stream-K
// work was divided amongst stream-K threadblocks, and a threadblock determined which tile
@ -872,15 +840,15 @@ private:
//
// To do so, we divide up the linearized stream-K units into clusters and share the same K
// offsets for work within clusters.
auto cluster_linear_work_idx = linear_idx / size(scheduler_params.cluster_shape_);
auto cluster_linear_work_idx = linear_idx / size(params.cluster_shape_);
// Determine the starting k iteration computed by this stream-K work unit
uint32_t unit_iter_start = scheduler_params.k_iter_per_sk_unit_ * cluster_linear_work_idx;
uint32_t unit_iter_start = params.k_tiles_per_sk_unit_ * cluster_linear_work_idx;
// Adjust the starting position and number of k iterations for "big units," which
// compute one extra iteration. These are the first big_units_ units in the
// linearized ID space.
bool is_big_unit = cluster_linear_work_idx < scheduler_params.big_units_;
bool is_big_unit = cluster_linear_work_idx < params.big_units_;
if (is_big_unit) {
// Since the "big units" are the first units in the linearized ID space, each
// of the units preceding this big unit computed one extra iteration. Thus,
@ -889,16 +857,16 @@ private:
unit_iter_start += cluster_linear_work_idx;
} else {
// Increment by one for each of the big clusters (since all big units precede this unit)
unit_iter_start += scheduler_params.big_units_;
unit_iter_start += params.big_units_;
}
uint32_t unit_iters;
if (new_unit) {
unit_iters = scheduler_params.k_iter_per_sk_unit_;
unit_iters = params.k_tiles_per_sk_unit_;
// Only adjust iteration count for big unit if we are initializing this
// work unit. For existing work units, the extra iteration for big units
// has already been accounted for in k_iter_reamaining
// has already been accounted for in k_tiles_reamaining
if (is_big_unit) {
++unit_iters;
}
@ -917,22 +885,21 @@ private:
// for them to be computed later, so as to reduce the likelihood of blocking
// on other work.
uint32_t unit_iter_end = unit_iter_start + unit_iters - 1;
uint32_t true_tile_id = unit_iter_end / scheduler_params.k_iter_per_tile_;
uint32_t true_tile_iter_start = true_tile_id * scheduler_params.k_iter_per_tile_;
uint32_t true_tile_iter_end = true_tile_iter_start + scheduler_params.k_iter_per_tile_;
uint32_t true_tile_id = unit_iter_end / params.k_tiles_per_output_tile_;
uint32_t true_tile_iter_start = true_tile_id * params.k_tiles_per_output_tile_;
uint32_t true_tile_iter_end = true_tile_iter_start + params.k_tiles_per_output_tile_;
// Bring the linearized tile ID back into the space of tiles, rather than clusters
true_tile_id *= size(scheduler_params.cluster_shape_);
true_tile_id *= size(params.cluster_shape_);
auto cluster_dim0 = cute::size<0>(scheduler_params.cluster_shape_);
auto cluster_dim1 = cute::size<1>(scheduler_params.cluster_shape_);
auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
// The final linearized tile ID is in units of the cluster dimension over which we rasterize.
if (scheduler_params.raster_order_ == RasterOrder::AlongN) {
true_tile_id += (blockIdx.y % cluster_dim1) * cluster_dim0;
if (params.raster_order_ == RasterOrder::AlongN) {
true_tile_id += cta_n_in_cluster * cute::size<0>(params.cluster_shape_);
}
else {
true_tile_id += (blockIdx.x % cluster_dim0) * cluster_dim1;
true_tile_id += cta_m_in_cluster * cute::size<1>(params.cluster_shape_);
}
// The unit's starting k iteration in the current tile is either the starting
@ -948,19 +915,18 @@ private:
uint32_t tile_iters = tile_iter_end - tile_iter_start;
uint64_t work_idx_l, remainder;
scheduler_params.divmod_batch_(work_idx_l, remainder, true_tile_id);
params.divmod_batch_(work_idx_l, remainder, true_tile_id);
uint64_t cta_per_grid_dim, dontcare;
scheduler_params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n(
cta_per_grid_dim,
scheduler_params.divmod_cluster_shape_major_,
scheduler_params.divmod_cluster_shape_minor_,
scheduler_params.divmod_cluster_blk_major_,
scheduler_params.log_swizzle_size_,
scheduler_params.raster_order_);
params.divmod_cluster_shape_major_,
params.divmod_cluster_shape_minor_,
params.divmod_cluster_blk_major_,
params.log_swizzle_size_,
params.raster_order_);
//
// Update the work_tile_info
@ -971,11 +937,11 @@ private:
work_tile_info.N_idx = work_idx_n;
work_tile_info.L_idx = static_cast<int32_t>(work_idx_l);
// Set the k offset to be the starting k iteration for this tile
// Set the k offset to be the starting k tile for this output tile
work_tile_info.K_idx = static_cast<int32_t>(tile_iter_start - true_tile_iter_start);
// Set the split count to be the number of k iterations in the tile
work_tile_info.splits = scheduler_params.k_iter_per_tile_;
// Set the split count to be the number of k tiles in the output tile
work_tile_info.splits = params.k_tiles_per_output_tile_;
// Any checks for invalid work units should be done prior to this call
work_tile_info.is_valid_tile = true;
@ -987,6 +953,89 @@ private:
// the output tile in question
work_tile_info.is_final_split = (tile_iter_end == true_tile_iter_end);
}
// Returns a WorkTileInfo to be computed for either the data-parallel or split-K
// work unit identified by the provided linear ID.
CUTLASS_DEVICE
static WorkTileInfo
set_non_stream_k_work(uint64_t linear_idx, Params const& params, bool is_split_k) {
// The linearized ID space is in terms of work units, rather than tiles. However,
// to compute the correct block offset for a data-parallel tile, we must convert
// the current ID to the data-parallel tile it corresponds to. Each data-parallel
// unit maps to a single data-parallel tile, but each stream-K unit can map to more
// than one tile. Thus, we must offset the work-unit ID among the data-parallel units
// by the total number of output tiles that will be computed by stream-K units.
//
// The logic below also works for the split-K case, in which sk_units_ and sk_tiles_
// are each 0.
uint64_t linear_work_idx = linear_idx - params.sk_units_ + params.sk_tiles_;
// Map worker's linear index into the CTA-tiled problem shape to the corresponding MNL indices
uint64_t work_idx_l, remainder;
params.divmod_batch_(work_idx_l, remainder, linear_work_idx);
uint64_t work_idx_k = 0;
if (is_split_k) {
params.divmod_k_(work_idx_k, remainder, remainder);
}
uint64_t cta_per_grid_dim, dontcare;
params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n(
cta_per_grid_dim,
params.divmod_cluster_shape_major_,
params.divmod_cluster_shape_minor_,
params.divmod_cluster_blk_major_,
params.log_swizzle_size_,
params.raster_order_);
bool is_final_split = (work_idx_k == params.splits_ - 1);
uint32_t k_tiles = params.k_tiles_per_output_tile_;
if (is_split_k) {
// Determine the number of iterations and starting iteration of this split.
// Doing so requires accounting for residual iterations, which are handled
// by the first big_units_ splits (with big_units_ = tiles % sm_count).
// Offsets for "normal" units. No additional k iterations are performed,
// and big_units_ "big" units preceded us, each of which performed one
// additional iteration. Thus, we must increase our split starting offset
// by big_units_.
int additional_k_tiles = 0;
int split_start_offset = params.big_units_;
if (work_idx_k < params.big_units_) {
// Offsets for "big" units. One additional k iteration is performed,
// and each split preceding us was a big unit, so we must increase
// our split starting offset by our split ID (work_idx_k).
additional_k_tiles = 1;
split_start_offset = work_idx_k;
}
// Set up k iteration count and split starting iteration assuming the
// iteration space is evenly split.
k_tiles /= params.splits_;
work_idx_k *= k_tiles;
// Apply any fixup needed to handle residuals
work_idx_k += split_start_offset;
k_tiles += additional_k_tiles;
}
return {
work_idx_m,
work_idx_n,
static_cast<int32_t>(work_idx_k),
static_cast<int32_t>(work_idx_l),
true,
params.k_tiles_per_output_tile_,
k_tiles,
k_tiles, // remaining iterations
is_final_split
};
}
};
} // namespace cutlass::gemm::kernel::detail

View File

@ -28,7 +28,7 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
/*!
\file
\brief Boost-like numeric conversion operator for CUTLASS numeric types
*/
@ -55,7 +55,7 @@ enum class FloatRoundStyle {
round_indeterminate, ///< rounding mode unknown
round_toward_zero, ///< round toward zero
round_to_nearest, ///< round to nearest even
round_to_nearest_satfinite, ///< round to nearest even, capping value to min and max of destination type
round_to_nearest_satfinite, ///< round to nearest even, capping value to min and max of destination type
round_toward_infinity, ///< round toward infinity
round_toward_neg_infinity, ///< round toward negative infinity
round_half_ulp_truncate, ///< add 0.5ulp to integer representation then round toward zero
@ -561,7 +561,7 @@ struct NumericConverter<tfloat32_t, float, FloatRoundStyle::round_to_nearest> {
// Note, the following is intentionally commented out. TF32
// does not define the low order bits, so they may be left in
// an undefined state.
// an undefined state.
//
// By not truncating these bit explicitly, we avoid an extra logical
// operation.
@ -657,7 +657,7 @@ template <
struct NumericConverterFastF32 {
// result_type holds big tfloat32_t at idx(0) and small tfloat32_t at idx(1)
using result_type = Array<tfloat32_t, 2>;
using result_type = Array<tfloat32_t, 2>;
// source data type
using source_type = float;
@ -708,7 +708,7 @@ struct NumericConverterClamp {
NumericConverter<result_type, source_type> convert_op;
result_type const kClamp_max = platform::numeric_limits<result_type>::max();
result_type const kClamp_min = platform::numeric_limits<result_type>::lowest();
if (s < (source_type)kClamp_min)
if (s < (source_type)kClamp_min)
return kClamp_min;
if (s > (source_type)kClamp_max)
return kClamp_max;
@ -848,7 +848,7 @@ struct NumericArrayConverter<half_t, float, 2, FloatRoundStyle::round_to_nearest
result[0] = convert_(source[0]);
result[1] = convert_(source[1]);
#endif
return result;
}
@ -878,7 +878,7 @@ struct NumericArrayConverter<float, half_t, 2, Round> {
result[0] = convert_(source[0]);
result[1] = convert_(source[1]);
#endif
return result;
}
@ -1044,7 +1044,7 @@ struct NumericArrayConverter<bfloat16_t, float, N, Round> {
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conditional guards to enable partial specialization for packed integers
// Conditional guards to enable partial specialization for packed integers
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \
((__CUDACC_VER_MAJOR__ > 10) || \
((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
@ -1066,7 +1066,7 @@ struct NumericArrayConverter<int8_t, int, 1, Round> {
result_type result;
result[0] = convert_element_(source[0]);
return result;
}
@ -1189,7 +1189,7 @@ struct NumericArrayConverter<uint8_t, int, 1, Round> {
result_type result;
result[0] = convert_element_(source[0]);
return result;
}

View File

@ -215,10 +215,17 @@ class GemmArguments2x(ArgumentBase):
else:
self.batch_count = 1
self.batched_stride_A = self.problem_size.m() * self.problem_size.k()
self.batched_stride_B = self.problem_size.n() * self.problem_size.k()
self.batched_stride_C = self.problem_size.m() * self.problem_size.n()
self.batched_stride_D = self.problem_size.m() * self.problem_size.n()
if "batch_strides" in kwargs:
self.batched_stride_A = kwargs["batch_strides"]["A"]
self.batched_stride_B = kwargs["batch_strides"]["B"]
self.batched_stride_C = kwargs["batch_strides"]["C"]
self.batched_stride_D = kwargs["batch_strides"]["D"]
else:
self.batched_stride_A = self.problem_size.m() * self.problem_size.k()
self.batched_stride_B = self.problem_size.n() * self.problem_size.k()
self.batched_stride_C = self.problem_size.m() * self.problem_size.n()
self.batched_stride_D = self.problem_size.m() * self.problem_size.n()
if self.bias:
self.batched_stride_C = self.problem_size.n()

View File

@ -132,9 +132,9 @@ class KernelsForDataType:
"""
# Determine the leading dimension of the shape
if layout == cutlass.LayoutType.ColumnMajor:
ld = shape[0]
ld = shape[-2]
elif layout == cutlass.LayoutType.RowMajor:
ld = shape[1]
ld = shape[-1]
elif layout == cutlass.LayoutType.TensorNHWC:
ld = shape[-1]
else:

View File

@ -114,6 +114,8 @@
args.sync()
"""
from math import prod
import cutlass_bindings
import cutlass
@ -442,6 +444,113 @@ class Gemm(OperationBase):
compiler.add_module([self.operation,])
return self.operation
def _verify_rank(self, tensor):
"""
Verifies that ``tensor`` has rank greater than 1
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
:type tensor: numpy/cupy/torch array/tensor object
"""
if len(tensor.shape) < 2:
raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}")
def _get_batch_count(self, A, B, C, D) -> int:
"""
Returns the batch count specified by the tensors A, B, C, and D and verifies that these
tensors match in batch size. Presence of a batch dimension is detected by one of the
tensors being rank 3. If a batch dimension is present, it must be present in one of
operands A, B, or C (but need not be in all), and must be present in D.
:param A: tensor A
:type A: numpy/cupy/torch array/tensor object
:param B: tensor B
:type B: numpy/cupy/torch array/tensor object
:param C: tensor C
:type C: numpy/cupy/torch array/tensor object
:param D: tensor D
:type D: numpy/cupy/torch array/tensor object
:return: tuple of batch count dimensions
:rtype: tuple
"""
A_batch = A.shape[:-2] if len(A.shape) > 2 else tuple()
B_batch = B.shape[:-2] if len(B.shape) > 2 else tuple()
C_batch = C.shape[:-2] if len(C.shape) > 2 else tuple()
D_batch = D.shape[:-2] if len(D.shape) > 2 else tuple()
if len(D_batch) > 0 and D_batch not in [A_batch, B_batch, C_batch]:
raise Exception(f"Batch count in D must be present in one of operands A, B, and C. "
f"Batch counts are: A={A_batch}, B={B_batch}, C={C_batch}, D={D_batch}")
for batch_shape in [A_batch, B_batch, C_batch]:
if len(batch_shape) > 0 and batch_shape != D_batch:
raise Exception(f"Batch count for all other operands must either match that of D or be zero."
f"Received batch shape of {batch_shape}, which does not match that of D of {D_batch}.")
return D_batch
def _get_batch_stride(self, tensor) -> int:
"""
Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0.
:param tensor: tensor object to process
:type tensor: numpy/cupy/torch array/tensor object
:return: stride between each matrix in the batch
:rtype: int
"""
if len(tensor.shape) > 2:
return tensor.shape[-2] * tensor.shape[-1]
else:
return 0
def _get_problem_args(self, A, B, C, D) -> tuple:
"""
Returns the problem size and GEMM universal mode to use for the
given operands.
:param A: tensor A
:type A: numpy/cupy/torch array/tensor object
:param B: tensor B
:type B: numpy/cupy/torch array/tensor object
:param C: tensor C
:type C: numpy/cupy/torch array/tensor object
:param D: tensor D
:type D: numpy/cupy/torch array/tensor object
:return: tuple containing the problem size (cutlass_bindings.gemm.GemmCoord), the GEMM mode (cutlass_bindings.gemm.Mode), and the batch count (int)
:rtype: tuple
"""
M, K = A.shape[-2:]
N = B.shape[-1]
mode = cutlass_bindings.gemm.Mode.Gemm
batch_count = self._get_batch_count(A, B, C, D)
returned_batch_count = prod(batch_count) if len(batch_count) > 0 else 1
# If we are running a batched GEMM in which there is a nonzero batch stride
# only for A, then we can fold the batched dimension of A into the M dimension
# (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A
# and C are row major. A similar operation can be performed if only B has a nonzero
# batch dimension
if len(batch_count) > 0:
A_row = self._layout_a == cutlass.LayoutType.RowMajor
B_row = self._layout_b == cutlass.LayoutType.RowMajor
C_row = self._layout_c == cutlass.LayoutType.RowMajor
batched = lambda x : len(x.shape) == 2 + len(batch_count)
if batched(A) and not batched(B) and batched(C) and A_row and C_row:
M *= prod(batch_count)
returned_batch_count = 1
elif not batched(A) and batched(B) and batched(C) and not B_row and not C_row:
N *= prod(batch_count)
returned_batch_count = 1
else:
mode = cutlass_bindings.gemm.Mode.Batched
return cutlass_bindings.gemm.GemmCoord(M, N, K), mode, returned_batch_count
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
"""
Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
@ -461,8 +570,7 @@ class Gemm(OperationBase):
f'layout of ({ref_type}, {ref_layout}).')
def run(self, A=None, B=None, C=None, D=None,
alpha=None, beta=None, batch_count: int = 1,
sync: bool = True, print_module: bool = False) -> GemmArguments:
alpha=None, beta=None, sync: bool = True, print_module: bool = False) -> GemmArguments:
"""
Runs the kernel currently specified. If it has not already been, the kernel is emitted and
compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@ -481,8 +589,6 @@ class Gemm(OperationBase):
:param D: tensor representing data type and layout of operand D
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
:param beta: scalar parameter beta from GEMM operation that scales operand C
:param batch_count: number of GEMMs in the batch
:type batch_count: int
:param sync: whether the call should wait for the kernel to complete before returning
:type sync: bool
:param print_module: whether to print the emitted C++ code
@ -491,9 +597,6 @@ class Gemm(OperationBase):
:return: arguments passed in to the kernel
:rtype: cutlass.backend.GemmArguments
"""
if batch_count < 1:
raise Exception(f"Invalid batch count {batch_count}. Value must be an integer >= 1.")
A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
@ -501,20 +604,31 @@ class Gemm(OperationBase):
alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
self._verify_rank(A)
self._verify_rank(B)
self._verify_rank(C)
self._verify_rank(D)
alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a)
alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b)
alignment_c = self.possible_operations.find_alignment(C.shape, self._layout_c)
self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
alignment_C=alignment_c, print_module=print_module)
problem_size = cutlass_bindings.gemm.GemmCoord(A.shape[0], B.shape[1], A.shape[1])
problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)
if batch_count == 1:
mode = cutlass_bindings.gemm.Mode.Gemm
if mode == cutlass_bindings.gemm.Mode.Gemm or batch_count == 1:
kwargs = {'split_k_slices': 1}
else:
mode = cutlass_bindings.gemm.Mode.Batched
kwargs = {'batch': batch_count}
kwargs = {
'batch': batch_count,
'batch_strides': {
'A': self._get_batch_stride(A),
'B': self._get_batch_stride(B),
'C': self._get_batch_stride(C),
'D': self._get_batch_stride(D)
}
}
arguments = GemmArguments(
operation=self.operation, problem_size=problem_size,

View File

@ -0,0 +1,139 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
High-level tests for running batched GEMMs
"""
from functools import partial
from math import prod
import cutlass
import logging
import torch
import unittest
from cutlass.backend.test.utils import LayoutCombination, add_test_gemm
from cutlass.backend.utils.device import device_cc
cutlass.set_log_level(logging.WARNING)
torch.manual_seed(2023)
def pytorch_reference(A, B, C, alpha, beta):
# Get the batch count. Assume that any of A, B, and C
# with a batch dimension ahve matching batch count. Thus,
# we break out of the loop once we have found the first
# tensor containing a batch dimension.
batch_count = (1,)
for tensor in [A, B, C]:
if len(tensor.shape) > 2:
batch_count = tensor.shape[:-2]
break
int_batch_count = prod(batch_count)
def add_batch(tensor):
if len(tensor.shape) == 2:
return tensor.unsqueeze(0).repeat(int_batch_count, 1, 1)
else:
return tensor.reshape(-1, tensor.size(-2), tensor.size(-1))
# Reshape tensors to have batch dimension
A = add_batch(A)
B = add_batch(B)
C = add_batch(C)
ret = (torch.bmm(A, B) * alpha) + (C * beta)
reshape_vals = batch_count + C.shape[-2:]
return ret.reshape(*reshape_vals)
def initialize(rows, cols, batch):
tensor = torch.randint(-3, 3, size=(rows*cols*prod(batch),), device='cuda').half()
if len(batch) > 0 and prod(batch) > 1:
reshape_vals = batch + (rows, cols)
return tensor.reshape(*reshape_vals)
else:
return tensor.reshape(rows, cols)
class GemmF16Batched(unittest.TestCase):
def run_batched(self, batch_count: tuple, batch_A: bool, batch_B: bool, batch_C: bool):
M = 512
N = 256
K = 128
alpha = 1.
beta = 2.
A = initialize(M, K, batch_count if batch_A else (1,))
B = initialize(K, N, batch_count if batch_B else (1,))
C = initialize(M, N, batch_count if batch_C else (1,))
D = initialize(M, N, batch_count)
plan = cutlass.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass.DataType.f32)
plan.run(A, B, C, D, alpha, beta)
reference = pytorch_reference(A, B, C, alpha, beta)
assert reference.equal(D)
def test_batched_ABC(self):
self.run_batched((3,), True, True, True)
self.run_batched((2, 3), True, True, True)
def test_batched_AB(self):
self.run_batched((3,), True, True, False)
self.run_batched((2, 3), True, True, False)
def test_batched_AC(self):
self.run_batched((3,), True, False, True)
self.run_batched((2, 3), True, False, True)
def test_batched_BC(self):
self.run_batched((3,), False, True, True)
self.run_batched((2, 3), False, True, True)
def test_batched_A(self):
self.run_batched((3,), True, False, False)
self.run_batched((2, 3), True, False, False)
def test_batched_B(self):
self.run_batched((3,), False, True, False)
self.run_batched((2, 3), False, True, False)
def test_batched_C(self):
self.run_batched((3,), False, False, True)
self.run_batched((2, 3), False, False, True)
if __name__ == '__main__':
unittest.main()

View File

@ -61,7 +61,7 @@ __global__ void convert(
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Destination, typename Source, int Count>
void run_test() {
void run_test(const char dest_name[], const char source_name[]) {
const int kN = Count;
dim3 grid(1, 1);
@ -84,7 +84,10 @@ void run_test() {
destination.sync_host();
for (int i = 0; i < kN; ++i) {
EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]));
EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]))
<< "Destination type: " << dest_name
<< ", Source type: " << source_name
<< ", Count: " << Count;
}
}
@ -97,15 +100,19 @@ void run_test() {
TEST(NumericConversion, f32_to_f16_rn) {
int const kN = 1;
using Source = float;
const char source_name[] = "float";
using Destination = cutlass::half_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "half_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, f32x8_to_f16x8_rn) {
int const kN = 8;
using Source = float;
const char source_name[] = "float";
using Destination = cutlass::half_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "half_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -113,15 +120,19 @@ TEST(NumericConversion, f32x8_to_f16x8_rn) {
TEST(NumericConversion, f16_to_f32_rn) {
int const kN = 1;
using Source = cutlass::half_t;
const char source_name[] = "half_t";
using Destination = float;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, f16x8_to_f32x8_rn) {
int const kN = 8;
using Source = cutlass::half_t;
const char source_name[] = "half_t";
using Destination = float;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -129,86 +140,109 @@ TEST(NumericConversion, f16x8_to_f32x8_rn) {
TEST(NumericConversion, f32_to_fe4m3_rn) {
int const kN = 1;
using Source = float;
const char source_name[] = "float";
using Destination = cutlass::float_e4m3_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e4m3_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, f32_to_fe4m3_rn_array) {
int const kN = 27;
using Source = float;
const char source_name[] = "float";
using Destination = cutlass::float_e4m3_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e4m3_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, f32_to_fe5m2_rn) {
int const kN = 1;
using Source = float;
const char source_name[] = "float";
using Destination = cutlass::float_e5m2_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e5m2_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, f32_to_fe5m2_rn_array) {
int const kN = 27;
using Source = float;
const char source_name[] = "float";
using Destination = cutlass::float_e5m2_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e5m2_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, f16_to_fe4m3_rn) {
int const kN = 1;
using Source = cutlass::half_t;
const char source_name[] = "half_t";
using Destination = cutlass::float_e4m3_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e4m3_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, f16_to_fe4m3_rn_array) {
int const kN = 27;
using Source = cutlass::half_t;
const char source_name[] = "half_t";
using Destination = cutlass::float_e4m3_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e4m3_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, f16_to_fe5m2_rn) {
int const kN = 1;
using Source = cutlass::half_t;
const char source_name[] = "half_t";
using Destination = cutlass::float_e5m2_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e5m2_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, f16_to_fe5m2_rn_array) {
int const kN = 27;
using Source = cutlass::half_t;
const char source_name[] = "half_t";
using Destination = cutlass::float_e5m2_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e5m2_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, bf16_to_fe4m3_rn) {
int const kN = 1;
using Source = cutlass::bfloat16_t;
const char source_name[] = "bfloat16_t";
using Destination = cutlass::float_e4m3_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e4m3_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, bf16_to_fe4m3_rn_array) {
int const kN = 27;
using Source = cutlass::bfloat16_t;
const char source_name[] = "bfloat16_t";
using Destination = cutlass::float_e4m3_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e4m3_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, bf16_to_fe5m2_rn) {
int const kN = 1;
using Source = cutlass::bfloat16_t;
const char source_name[] = "bfloat16_t";
using Destination = cutlass::float_e5m2_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e5m2_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, bf16_to_fe5m2_rn_array) {
int const kN = 27;
using Source = cutlass::bfloat16_t;
const char source_name[] = "bfloat16_t";
using Destination = cutlass::float_e5m2_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e5m2_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -216,36 +250,46 @@ TEST(NumericConversion, bf16_to_fe5m2_rn_array) {
TEST(NumericConversion, fe4m3_to_fe5m2_rn) {
int const kN = 1;
using Source = cutlass::float_e4m3_t;
const char source_name[] = "float_e4m3_t";
using Destination = cutlass::float_e5m2_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e5m2_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe4m3_to_fe5m2_array) {
int const kN = 27;
using Source = cutlass::float_e4m3_t;
const char source_name[] = "float_e4m3_t";
using Destination = cutlass::float_e5m2_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e5m2_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe5m2_to_fe4m3_rn) {
int const kN = 1;
using Source = cutlass::float_e5m2_t;
const char source_name[] = "float_e5m2_t";
using Destination = cutlass::float_e4m3_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e4m3_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe5m2_to_fe4m3_array) {
int const kN = 27;
using Source = cutlass::float_e5m2_t;
const char source_name[] = "float_e5m2_t";
using Destination = cutlass::float_e4m3_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float_e4m3_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe4m3_to_f32_rn) {
int const kN = 1;
using Source = cutlass::float_e4m3_t;
const char source_name[] = "float_e4m3_t";
using Destination = float;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -254,78 +298,100 @@ TEST(NumericConversion, f32x8_to_s8x8_rn) {
int const kN = 8;
using Source = float;
const char source_name[] = "float";
using Destination = int8_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "int8_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe4m3_to_f32_array) {
int const kN = 27;
using Source = cutlass::float_e4m3_t;
const char source_name[] = "float_e4m3_t";
using Destination = float;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe5m2_to_f32_array) {
int const kN = 27;
using Source = cutlass::float_e5m2_t;
const char source_name[] = "float_e5m2_t";
using Destination = float;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "float";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe4m3_to_f16_rn) {
int const kN = 1;
using Source = cutlass::float_e4m3_t;
const char source_name[] = "float_e4m3_t";
using Destination = cutlass::half_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "half_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe4m3_to_f16_array) {
int const kN = 27;
using Source = cutlass::float_e4m3_t;
const char source_name[] = "float_e4m3_t";
using Destination = cutlass::half_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "half_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe5m2_to_f16_rn) {
int const kN = 1;
using Source = cutlass::float_e5m2_t;
const char source_name[] = "float_e5m2_t";
using Destination = cutlass::half_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "half_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe5m2_to_f16_array) {
int const kN = 27;
using Source = cutlass::float_e5m2_t;
const char source_name[] = "float_e5m2_t";
using Destination = cutlass::half_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "half_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe4m3_to_bf16_rn) {
int const kN = 1;
using Source = cutlass::float_e4m3_t;
const char source_name[] = "float_e4m3_t";
using Destination = cutlass::bfloat16_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "bfloat16_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe4m3_to_bf16_array) {
int const kN = 27;
using Source = cutlass::float_e4m3_t;
const char source_name[] = "float_e4m3_t";
using Destination = cutlass::bfloat16_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "bfloat16_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe5m2_to_bf16_rn) {
int const kN = 1;
using Source = cutlass::float_e5m2_t;
const char source_name[] = "float_e5m2_t";
using Destination = cutlass::bfloat16_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "bfloat16_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
TEST(NumericConversion, fe5m2_to_bf16_array) {
int const kN = 27;
using Source = cutlass::float_e5m2_t;
const char source_name[] = "float_e5m2_t";
using Destination = cutlass::bfloat16_t;
test::core::kernel::run_test<Destination, Source, kN>();
const char dest_name[] = "bfloat16_t";
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -36,6 +36,7 @@ cutlass_test_unit_add_executable(
compare.cpp
complement.cpp
composition.cpp
constant_arithmetic.cpp
core_unit.cpp
inverse_left.cpp
inverse_right.cpp

View File

@ -0,0 +1,106 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include "cutlass_unit_test.h"
#include <cutlass/trace.h>
#include <cute/swizzle.hpp>
TEST(CuTe_core, ConstantArithmetic) {
using namespace cute;
constexpr cute::integral_constant<uint32_t, 0> uzero{};
// This extra test exists historically as part of the diagnosis
// of a possible Clang 14 bug. However, it's a nice test for
// cute::integral_constant's arithmetic operators, so it's saved here.
// It also demonstrates how to work with cute::integral_constant
// and lambda captures. Microsoft Visual Studio ("MSVC") tends to
// disagree with other compilers about the meaning of decltype
// for variables captured by reference. MSVC and GCC 8.3.0
// also tend to disagree with other compilers (and other GCC versions)
// about whether expressions involving such variables
// are constant expressions.
//
// A typical CuTe idiom is to do lambda captures by reference [&].
// This test changes them to capture by value, except for
// the innermost lambda's capture of S1, which is by reference.
// The point is to show that MSVC and GCC 8 have issues with this
// that other compilers do not. For example,
//
// 1. MSVC needs remove_cvref_t around decltype(S1)
// in order to access decltype(S1)::value, and
// 2. MSVC and GCC 8.3.0 both report a build error with S1()
// (that is, calling operator() on S1, which returns the
// same thing as S1.value).
//
// The reason for (2) is that neither compiler thinks
// that S1() is a constant expression.
//
// This leaves S1.value as the most concise portable expression
// for the "value" member of a cute::integral_constant.
for_each(make_integer_sequence<uint32_t, 8>{}, [uzero](auto S0) {
for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0](auto F0) {
for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0,F0](auto S1) {
for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0,F0,&S1](auto F1) {
static_assert((decltype(S0)::value & decltype(F0)::value) == decltype(S0 & F0)::value);
// Using S1.value means you don't have to use remove_cvref_t
// with a captured-by-reference variable.
static_assert((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) == decltype(S1 & F1)::value);
static_assert((S1.value & decltype(F1)::value) == decltype(S1 & F1)::value);
// S1() _should_ work, but does not with Visual Studio 2022,
// which emits C2131 ("expression did not evaluate to a constant").
// It also does not with GCC 8.3.0, which emits an error with messages
// "non-constant condition for static assertion" and
// "'this' is not a constant expression."
//
//static_assert((S1() & decltype(F1)::value) == decltype(S1 & F1)::value);
static_assert(decltype((S0 & F0) != uzero)::value == ((decltype(S0)::value & decltype(F0)::value) != 0));
static_assert(decltype((S1 & F1) != uzero)::value == ((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) != 0));
static_assert(decltype((S1 & F1) != uzero)::value == ((S1.value & decltype(F1)::value) != 0));
constexpr bool left = decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value;
constexpr bool right =
((decltype(S0)::value & decltype(F0)::value) != 0) ||
((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) != 0);
constexpr bool right2 =
((decltype(S0)::value & decltype(F0)::value) != 0) ||
((S1.value & decltype(F1)::value) != 0);
static_assert(right == right2);
static_assert(left == right);
constexpr bool left2 = decltype((S0 & F0) != uzero)::value || decltype((S1 & F1) != uzero)::value;
static_assert(left == left2);
});
});
});
});
}

View File

@ -31,9 +31,58 @@
#include "cutlass_unit_test.h"
// C<uint32_t(something)>::value_type is not uint32_t for GCC 7.5.0.
// This test is thus disabled for GCC < 8.
#if defined(__GNUC__) && (__GNUC__ < 8)
#include <cutlass/trace.h>
#include <cute/swizzle.hpp>
namespace { // (anonymous)
// This function exists to work around a Clang 14 issue, in which
// the compiler tries to instantiate code that lives inside the
// "else" branch of an "if constexpr," even when the "else" branch
// is false. That triggers a spurious static_assert in MixedBits.
// The work-around is to make the body of the "else" branch a
// function, rather than leaving it in line.
//
// Some compilers strangely deduce the first two terms of
// make_integer_sequence<uint32_t, 8> as C<false> and C<true>, and
// the remaining terms as C<2>, C<3>, etc. Making this function take
// cute::integral_constant<uint32_t, S0_value>, etc. doesn't work
// with those compilers.
template<class S0_type, S0_type S0_value,
class F0_type, F0_type F0_value,
class S1_type, S1_type S1_value,
class F1_type, F1_type F1_value>
void clang14_workaround(cute::integral_constant<S0_type, S0_value>,
cute::integral_constant<F0_type, F0_value>,
cute::integral_constant<S1_type, S1_value>,
cute::integral_constant<F1_type, F1_value>)
{
constexpr cute::C<static_cast<uint32_t>(S0_value)> S0{};
constexpr cute::C<static_cast<uint32_t>(F0_value)> F0{};
constexpr cute::C<static_cast<uint32_t>(S1_value)> S1{};
constexpr cute::C<static_cast<uint32_t>(F1_value)> F1{};
for (uint32_t d0 = 0; d0 < 8; ++d0) {
if ((d0 & F0) != d0) { continue; } // Skip repeats
for (uint32_t d1 = 0; d1 < 8; ++d1) {
if ((d1 & F1) != d1) { continue; } // Skip repeats
auto m0 = make_mixed_bits(S0, d0, F0);
auto m1 = make_mixed_bits(S1, d1, F1);
//print(m0); print(" & "); print(m1); print(" = "); print(m0 & m1); print("\n");
EXPECT_EQ(uint32_t(m0 & m1), uint32_t(m0) & uint32_t(m1));
//print(m0); print(" | "); print(m1); print(" = "); print(m0 | m1); print("\n");
EXPECT_EQ(uint32_t(m0 | m1), uint32_t(m0) | uint32_t(m1));
//print(m0); print(" ^ "); print(m1); print(" = "); print(m0 ^ m1); print("\n");
EXPECT_EQ(uint32_t(m0 ^ m1), uint32_t(m0) ^ uint32_t(m1));
}
}
}
} // namespace (anonymous)
TEST(CuTe_core, MixedBits) {
using namespace cute;
@ -48,23 +97,21 @@ TEST(CuTe_core, MixedBits) {
} else if constexpr (decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value) {
return;
} else {
for (uint32_t d0 = 0; d0 < 8; ++d0) {
if ((d0 & F0) != d0) { continue; } // Skip repeats
for (uint32_t d1 = 0; d1 < 8; ++d1) {
if ((d1 & F1) != d1) { continue; } // Skip repeats
auto m0 = make_mixed_bits(S0, d0, F0);
auto m1 = make_mixed_bits(S1, d1, F1);
//print(m0); print(" & "); print(m1); print(" = "); print(m0 & m1); print("\n");
EXPECT_EQ(uint32_t(m0 & m1), uint32_t(m0) & uint32_t(m1));
//print(m0); print(" | "); print(m1); print(" = "); print(m0 | m1); print("\n");
EXPECT_EQ(uint32_t(m0 | m1), uint32_t(m0) | uint32_t(m1));
//print(m0); print(" ^ "); print(m1); print(" = "); print(m0 ^ m1); print("\n");
EXPECT_EQ(uint32_t(m0 ^ m1), uint32_t(m0) ^ uint32_t(m1));
}
}
clang14_workaround(S0, F0, S1, F1);
}
});
});
});
});
}
TEST(CuTe_core, MakeIntegerSequence) {
cute::for_each(cute::make_integer_sequence<uint32_t, 8>{}, [](auto c) {
using c_type = decltype(c);
constexpr auto c_value = c_type::value;
using expected_type = cute::integral_constant<uint32_t, c_value>;
static_assert(cute::is_same_v<c_type, expected_type>);
});
}
#endif // defined(__GNUC__) && (__GNUC__ < 8)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

View File

@ -32,6 +32,7 @@
\brief Tests that the stream-K scheduler covers the entire problem space.
*/
#include "cutlass/cluster_launch.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp"
#include "cutlass/util/device_memory.h"
@ -39,6 +40,10 @@
#include "../../common/cutlass_unit_test.h"
// Grids are launched with clusters enabled in these tests,
// so the CTK version must support cluster launching.
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
using namespace cute;
using ProblemShape_MNKL = Shape<int, int, int, int>;
@ -60,7 +65,7 @@ run_scheduler(int* visit_counters, typename Scheduler::Params params, TileShape
while (work_tile_info.is_valid_tile) {
// Increment counters to indicate coverage
auto tile_idx = Scheduler::output_tile_index(params, work_tile_info);
auto offset = tile_idx * params.k_iter_per_tile_ + work_tile_info.K_idx;
auto offset = tile_idx * params.k_tiles_per_output_tile_ + work_tile_info.K_idx;
for (auto i = 0; i < work_tile_info.k_tile_count; ++i) {
// Use atomicAdd because the visit counters are shared by multiple thread blocks.
// While having more than one block increment the same counter indicates failure,
@ -103,7 +108,7 @@ test_scheduler(
// Allocate counters indicating the number of times each k iteration of each output tile has been visited
auto [blk_m, blk_n, blk_l] = Scheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
auto total_counters = blk_m * blk_n * blk_l * params.k_iter_per_tile_;
auto total_counters = blk_m * blk_n * blk_l * params.k_tiles_per_output_tile_;
cutlass::DeviceAllocation<int> visit_counters(total_counters);
// Initialize counters to zero
@ -118,12 +123,55 @@ test_scheduler(
// Set up the grid for the problem
dim3 grid = Scheduler::get_grid_shape(problem_shape_mnkl, tile_shape, cluster_shape, hw_info, args);
// Set up cluster and cluster launch. This is needed even for this simple kernel because
// the SM90 scheduler needs to be able to query the CTA id within a cluster, which requires
// explicitly launching with clusters.
dim3 cluster{
static_cast<uint32_t>(cute::get<0>(ClusterShape{})),
static_cast<uint32_t>(cute::get<1>(ClusterShape{})),
static_cast<uint32_t>(cute::get<2>(ClusterShape{}))
};
cudaLaunchConfig_t launch_config;
launch_config.gridDim = grid;
launch_config.blockDim = {1, 1, 1};
launch_config.dynamicSmemBytes = 0;
launch_config.stream = NULL;
cudaLaunchAttribute launch_attribute[1];
launch_attribute[0].id = cudaLaunchAttributeClusterDimension;
launch_attribute[0].val.clusterDim.x = cluster.x;
launch_attribute[0].val.clusterDim.y = cluster.y;
launch_attribute[0].val.clusterDim.z = cluster.z;
launch_config.attrs = launch_attribute;
launch_config.numAttrs = 1;
void const* kernel = (void const*) run_scheduler<Scheduler, TileShape, ClusterShape>;
int* counters_ptr = visit_counters.get();
void* kernel_params[] = {
&counters_ptr,
&params,
&tile_shape,
&cluster_shape,
&problem_shape_mnkl
};
// Run the scheduler to completion and log visits to each k iteration
run_scheduler<Scheduler, TileShape, ClusterShape><<<grid, 1>>>(
visit_counters.get(), params, tile_shape, cluster_shape, problem_shape_mnkl);
err = cudaLaunchKernelExC(&launch_config, kernel, kernel_params);
if (err != cudaSuccess) {
std::cerr << __FILE__ << ":" << __LINE__
<< " cudaLaunchKernelExC failed with error: "
<< cudaGetErrorString(err) << std::endl;
return false;
}
err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
std::cerr << __FILE__ << ":" << __LINE__ << " scheduler kernel failed with error: " << cudaGetErrorString(err) << std::endl;
std::cerr << __FILE__ << ":" << __LINE__
<< " scheduler kernel failed with error: "
<< cudaGetErrorString(err) << std::endl;
return false;
}
@ -143,11 +191,11 @@ test_scheduler(
<< " and grid size " << grid.x << "x"
<< grid.y << "x" << grid.z
<< " splits=" << params.splits_
<< " k_iter=" << params.k_iter_per_tile_
<< " k_iter=" << params.k_tiles_per_output_tile_
<< " big_units=" << params.big_units_
<< " sk_tiles=" << params.sk_tiles_
<< " sk_units=" << params.sk_units_
<< " k_iter_per_sk_unit=" << params.k_iter_per_sk_unit_ << std::endl;
<< " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_ << std::endl;
std::cout << "Error at idx: " << i << ". Got count " << host_visit_counts[i] << std::endl;
return false;
}
@ -274,4 +322,6 @@ TEST(SM90_Device_Gemm_stream_k_scheduler, 128x128x64_2x1x1) {
EXPECT_TRUE(test_scheduler({128, 512, 2048, 1}, tile_shape, cluster_shape, 114));
}
#endif // defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -179,7 +179,7 @@ class GemmOperation:
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
if self.arch >= 90:
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}{t}"
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{t}{k}{e}"
return kernel_name_template.format(
p = self.prefix,
ar = self.arch,
@ -194,9 +194,9 @@ class GemmOperation:
l = self.tile_description.stages,
s = self.layout_name_3x(),
al = str(max(self.A.alignment, self.B.alignment)),
t = TileSchedulerSuffixes[self.tile_scheduler],
k = self.kernel_schedule_name_3x(),
e = self.epilogue_schedule_name_3x(),
t = TileSchedulerSuffixes[self.tile_scheduler])
e = self.epilogue_schedule_name_3x())
else:
threadblock = self.tile_description.procedural_name()
return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format(
@ -661,8 +661,7 @@ using ${operation_name}_mainloop =
${element_accumulator},
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
cutlass::gemm::collective::StageCountAutoCarveout<
sizeof(typename ${operation_name}_epilogue::SharedStorage)>,
${stages},
${kernel_schedule}
>::CollectiveOp;
@ -697,7 +696,7 @@ ${compile_guard_end}
if operation.tile_description.stages > 0:
stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>"
else:
stage_count_string = "cutlass::gemm::collective::StageCountAuto"
stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage)>"
warp_shape = [tile_shape[idx] // warp_count[idx] for idx in range(3)]
instance_layout_A, instance_layout_B, instance_layout_C , instance_layout_D = \

View File

@ -4218,6 +4218,8 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
layout[2][1] = 8
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, schedules)
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK])
# persistent kernels with TMA epilogues
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed,