Updates for 3.2 release (#1065)

ANIKET SHIVAM 2023-08-25 17:05:46 -10:00 committed by GitHub
parent 27de343535
commit a88c41cf8d
20 changed files with 904 additions and 257 deletions


@@ -2,10 +2,14 @@
 ## 2023
-- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, and Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
+- ["FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"](https://arxiv.org/abs/2307.08691). Tri Dao. _Technical Report_, July 2023.
 - ["ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs"](https://arxiv.org/abs/2210.03052). Yujia Zhai, Chengquan Jiang, Leyuan Wang, Xiaoying Jia, Shang Zhang, Zizhong Chen, Xin Liu, Yibo Zhu. _Proceedings of the 37th IEEE International Parallel & Distributed Processing Symposium (Best Paper)_, May 2023.
+- ["A Framework for Fine-Grained Synchronization of Dependent GPU Kernels"](https://arxiv.org/abs/2305.13450). Abhinav Jangda, Saeed Maleki, Maryam Mehri Dehnavi, Madan Musuvathi, Olli Saarikivi. _Computing Research Repository_, May 2023.
+- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
 - ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023.
 ## 2022


@@ -71,7 +71,7 @@ struct Options {
   /// Prints the usage statement.
   std::ostream & print_usage(std::ostream &out) const {
-    out << "52_fp8_hopper_warp_specialized_gemm\n\n"
+    out << "54_fp8_hopper_warp_specialized_gemm\n\n"
       << "  Hopper FP8 GEMM using a Warp Specialized kernel.\n\n"
       << "Options:\n\n"
       << "  --help                      If specified, displays this usage statement\n\n"
@@ -93,7 +93,7 @@ struct Options {
     out
       << "\n\nExamples:\n\n"
-      << "$ " << "52_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
+      << "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
     return out;
   }


@@ -54,6 +54,13 @@ struct SyncthreadsSync {
   }
 };

+struct SyncwarpSync {
+  CUTLASS_DEVICE
+  static void sync() {
+    __syncwarp();
+  }
+};
+
 template <
   int ThreadCount,
   int BarrierId
@@ -311,6 +318,60 @@ private:
   }
 };

+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/** Structure for synchronizing via contiguous barriers (e.g., __syncwarp, __syncthreads)
+ *  via an API that mirrors that of NamedBarrierManager
+ *
+ * @param Synchronizer Synchronization helper exposing a `sync()` method to perform synchronization
+ **/
+template <
+  class Synchronizer,
+  uint32_t ThreadCount_
+>
+struct SyncManager {
+
+  // Number of threads participating in the barrier
+  static constexpr uint32_t ThreadCount = ThreadCount_;
+
+  using BarrierSync = cutlass::GenericBarrier<Synchronizer>;
+
+  // Underlying type used by all barriers for synchronization.
+  using T = typename BarrierSync::T;
+
+  CUTLASS_DEVICE
+  static
+  void wait_lt(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int count) {
+    BarrierSync::wait_lt_helper(lock_ptr, thread_idx, flag_idx, count);
+  }
+
+  CUTLASS_DEVICE
+  static void
+  wait_eq(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
+    BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val);
+  }
+
+  CUTLASS_DEVICE
+  static void
+  wait_eq_reset(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
+    BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val);
+  }
+
+  CUTLASS_DEVICE
+  static void
+  arrive_inc(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) {
+    BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val);
+  }
+
+  CUTLASS_DEVICE
+  static void
+  arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) {
+    BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val);
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
 } // namespace cutlass

 /////////////////////////////////////////////////////////////////////////////////////////////////


@@ -67,7 +67,7 @@ sm90_get_tma_dispatch_policy() {
   constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{}));
   constexpr int FragmentSize = size(EpilogueTileMN{}) / (detail::sm90_is_cooperative_v<Schedule> ? 256 : 128);
-  constexpr int ReuseSmemC = sizeof_bits_v<ElementC> == sizeof_bits_v<ElementD>;
+  constexpr int ReuseSmemC = (sizeof_bits_v<ElementC> == sizeof_bits_v<ElementD>) && (sizeof_bits_v<ElementD> > 8);
   constexpr int StagesD = 2;
   constexpr int StagesC = ReuseSmemC ? cute::max(EpiTiles, StagesD + 1) : EpiTiles;
@@ -98,7 +98,7 @@ sm90_get_epilogue_smem_swizzle_layout_atom() {
 }

 // Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one.
-template <class Element, class EpilogueTileType, class Schedule>
+template <class ElementD, class EpilogueTileType, class Schedule>
 constexpr auto
 sm90_compute_tile_shape_or_override() {
   if constexpr (cute::is_same_v<EpilogueTileType, EpilogueTileAuto>) {
@@ -107,7 +107,12 @@ sm90_compute_tile_shape_or_override() {
       return Shape<_128,_32>{};
     }
     else if constexpr (detail::sm90_is_warp_specialized_v<Schedule>) {
-      return Shape<_64,_32>{};
+      if constexpr (sizeof_bits_v<ElementD> == 8) {
+        return Shape<_64,_64>{};
+      }
+      else {
+        return Shape<_64,_32>{};
+      }
     }
     else {
       static_assert(cutlass::detail::dependent_false<Schedule>, "Unsupported schedule.");


@@ -34,6 +34,7 @@
 #include "cutlass/kernel_hardware_info.hpp"
 #include "cute/layout.hpp"
 #include "cute/tensor.hpp"
+#include "cute/arch/cluster_sm90.hpp"

 namespace cutlass::gemm::kernel::detail {
@@ -205,18 +206,14 @@ public:
     uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0;
     divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim);

-    // MSVC requires protecting use of CUDA-specific nonstandard syntax,
-    // like blockIdx and gridDim, with __CUDA_ARCH__.
-#if defined(__CUDA_ARCH__)
+    auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
     if (raster_order == RasterOrder::AlongN) {
-      cluster_minor_offset = blockIdx.x;
+      cluster_minor_offset = cta_m_in_cluster;
     }
     else {
-      cluster_minor_offset = blockIdx.y;
+      cluster_minor_offset = cta_n_in_cluster;
     }
-#else
-    CUTLASS_ASSERT(false && "This line should never be reached");
-#endif

     uint64_t cluster_idx_minor, cluster_idx_major;


@@ -141,7 +141,7 @@ public:
     uint32_t splits_ = 1;

     // Number of tiled k iterations required to compute a single output tile.
-    uint32_t k_iter_per_tile_ = 0;
+    uint32_t k_tiles_per_output_tile_ = 0;

     // Number of stream-K or split-K work units that compute an extra k iteration.
     // This is done to handle residuals in dividing up the k iteration space.
@@ -160,7 +160,7 @@ public:
     // Number of tiled k iterations computed by each stream-K work unit. This
     // can potentially cover more than one output tile.
-    uint32_t k_iter_per_sk_unit_ = 0;
+    uint32_t k_tiles_per_sk_unit_ = 0;
   };

@@ -189,9 +189,9 @@ public:
     uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l;

-    // Number of k iterations each tile computes (this is just the number of k iterations
-    // in the problem's K dimension)
-    uint32_t k_iter_per_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(tile_shape) - 1) / cute::size<2>(tile_shape);
+    // Number of k tile iterations in each output tile
+    uint32_t k_tiles_per_output_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(tile_shape) - 1) /
+                                       cute::size<2>(tile_shape);

     UnderlyingArguments underlying_args;
     underlying_args.max_swizzle_size = 1;
@@ -216,11 +216,11 @@ public:
       // splits is almost certainly nonnegative here (e.g., hw_info.sm_count,
       // despite being an int, is a count), so it can safely be converted to unsigned
       // in the comparison to avoid a signed-unsigned comparison warning-as-error.
-      splits = static_cast<decltype(k_iter_per_tile)>(splits) > k_iter_per_tile ? k_iter_per_tile : splits;
+      splits = static_cast<decltype(k_tiles_per_output_tile)>(splits) > k_tiles_per_output_tile ? k_tiles_per_output_tile : splits;

       return get_params_basic(
         underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape,
-        splits, k_iter_per_tile, reduction_workspace);
+        splits, k_tiles_per_output_tile, reduction_workspace);
     }

     // Calculate the maximum number of blocks from clusters of shape cluster_shape that we
@@ -229,7 +229,7 @@ public:
     uint64_t ctas_per_wave = grid.x * grid.y;

     // The number of output tiles to be computed in stream-K and data-parallel fashion, respectively.
-    uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave);
+    uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, k_tiles_per_output_tile);
     uint64_t dp_tiles = output_tiles - sk_tiles;

     // Calculate the number of work units covering the data-parallel and stream-K tiles.
@@ -243,7 +243,7 @@ public:
     uint64_t dp_units = dp_tiles;

     // Number of k iterations computed by the stream-K units as a whole
-    uint64_t k_iter_sk_total = k_iter_per_tile * sk_tiles;
+    uint64_t k_tiles_sk_total = k_tiles_per_output_tile * sk_tiles;

     // If there are stream-K tiles to compute and a sufficiently large number of k iterations
     // across them, they will be covered by a single wave of persistent threadblocks. Thus, there
@@ -255,7 +255,7 @@ public:
     // Calculate the number of stream-K units that would be needed if each stream-K unit
     // computed the minimum allowable k iterations. Truncate this to be in units of clusters.
-    uint64_t min_sized_sk_units = (k_iter_sk_total / min_iters_per_sk_unit_);
+    uint64_t min_sized_sk_units = (k_tiles_sk_total / min_iters_per_sk_unit_);
     min_sized_sk_units = (min_sized_sk_units / cute::size(cluster_shape)) * cute::size(cluster_shape);

     uint64_t sk_units = min(ctas_per_wave, min_sized_sk_units);
@@ -264,7 +264,7 @@ public:
       // Short circuit to basic data-parallel decomposition
       return get_params_basic(
         underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape,
-        1, k_iter_per_tile, reduction_workspace);
+        1, k_tiles_per_output_tile, reduction_workspace);
     }

     // If the number of stream-K units is a multiple of the number of stream-K tiles, then
@@ -274,24 +274,24 @@ public:
       uint32_t sk_splits = static_cast<uint32_t>(sk_units / sk_tiles);
       return get_params_basic(
         underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape,
-        sk_splits, k_iter_per_tile, reduction_workspace);
+        sk_splits, k_tiles_per_output_tile, reduction_workspace);
     }

     // Number of k iterations computed per stream-K unit
-    uint64_t k_iter_per_sk_unit = k_iter_sk_total / sk_units;
+    uint64_t k_tiles_per_sk_unit = k_tiles_sk_total / sk_units;

     // Number of stream-K units that need to compute extra iterations in order to cover
     // the residual k iterations. This assumes that each such unit computes one additional
     // iteration.
-    uint64_t sk_big_units = k_iter_sk_total - (k_iter_per_sk_unit * sk_units);
+    uint64_t sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units);

     // The division below is guaranteed to be exact because sk_big_units is guaranteed
     // to be a multiple of cluster_size (cute::size(cluster_shape)). This is useful because
     // it allows us to use a block's linearized cluster ID to determine whether it is
     // a big block. The reasoning behind this guarantee is explained as follows:
-    //    sk_big_units = k_iter_sk_total - (k_iter_per_sk_unit * sk_units);
+    //    sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units);
     //
-    // - k_iter_sk_total is a multiple of cluster_size because it is the product
+    // - k_tiles_sk_total is a multiple of cluster_size because it is the product
     //   of number of tail tiles and the number of k iterations per tile. Because
     //   both the number of output tiles and number of available SMs are rounded
     //   to be multiples of cluster shape, the number of tail tiles
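To make the partitioning arithmetic above concrete, the following standalone Python sketch (not part of the diff; the sizes are illustrative assumptions chosen so the cluster-divisibility guarantee holds) walks through how the stream-K k tiles are divided among units and how the residual is absorbed by the "big" units:

    # Illustrative stream-K partitioning arithmetic (not library code).
    # Assumed example: 4 stream-K output tiles, 8 k tiles per output tile,
    # 6 stream-K units, cluster size 2.
    sk_tiles = 4
    k_tiles_per_output_tile = 8
    sk_units = 6                        # already truncated to a multiple of the cluster size
    cluster_size = 2

    k_tiles_sk_total = k_tiles_per_output_tile * sk_tiles               # 32
    k_tiles_per_sk_unit = k_tiles_sk_total // sk_units                  # 5
    sk_big_units = k_tiles_sk_total - k_tiles_per_sk_unit * sk_units    # 2 "big" units

    # The residual is a multiple of the cluster size, so a linearized cluster ID
    # can decide whether a unit is "big".
    assert sk_big_units % cluster_size == 0

    # Big units compute one extra k tile; together, all k tiles are covered.
    covered = (sk_big_units * (k_tiles_per_sk_unit + 1)
               + (sk_units - sk_big_units) * k_tiles_per_sk_unit)
    assert covered == k_tiles_sk_total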
@@ -313,12 +313,12 @@ public:
       underlying_params.raster_order_,
       cluster_shape,
       1, // Static k-splitting factor. Unused for stream-K.
-      k_iter_per_tile,
+      k_tiles_per_output_tile,
       static_cast<uint32_t>(sk_big_units_per_cluster),
       reduction_workspace,
       sk_tiles,
       static_cast<uint32_t>(sk_units),
-      static_cast<uint32_t>(k_iter_per_sk_unit)
+      static_cast<uint32_t>(k_tiles_per_sk_unit)
     };
   }
@@ -338,105 +338,32 @@ public:
   CUTLASS_DEVICE
   WorkTileInfo
   get_current_work() const {
-    return get_current_work_for_linear_idx(current_work_linear_idx_);
+    return get_current_work_for_linear_idx(current_work_linear_idx_, scheduler_params);
   }

   CUTLASS_DEVICE
-  WorkTileInfo
-  get_current_work_for_linear_idx(uint64_t linear_idx) const {
-    if (linear_idx >= scheduler_params.units_per_problem_) {
+  static WorkTileInfo
+  get_current_work_for_linear_idx(uint64_t linear_idx, Params const& params) {
+    if (linear_idx >= params.units_per_problem_) {
       // Invalid work. Return an empty result.
       return {0, 0, 0, 0, false, 0};
     }

     // Determine whether this work unit is a data-parallel or stream-K work unit
-    bool is_stream_k_unit = linear_idx < scheduler_params.sk_units_;
+    bool is_stream_k_unit = linear_idx < params.sk_units_;

-    bool is_split_k = scheduler_params.splits_ > 1;
+    bool is_split_k = params.splits_ > 1;

-    // Bypass the stream-K scheduling logic for basic data-parallel or split-K work
     if (is_split_k || !is_stream_k_unit) {
-      // The linearized ID space is in terms of work units, rather than tiles. However,
-      // to compute the correct block offset for a data-parallel tile, we must convert
-      // the current ID to the data-parallel tile it corresponds to. Each data-parallel
-      // unit maps to a single data-parallel tile, but each stream-K unit can map to more
-      // than one tile. Thus, we must offset the work-unit ID among the data-parallel units
-      // by the total number of output tiles that will be computed by stream-K units.
-      //
-      // The logic below also works for the split-K case, in which sk_units_ and sk_tiles_
-      // are each 0.
-      uint64_t linear_work_idx = linear_idx - scheduler_params.sk_units_ + scheduler_params.sk_tiles_;
-
-      // Map worker's linear index into the CTA-tiled problem shape to the corresponding MNL indices
-      uint64_t work_idx_l, remainder;
-      scheduler_params.divmod_batch_(work_idx_l, remainder, linear_work_idx);
-
-      uint64_t work_idx_k = 0;
-      if (is_split_k) {
-        scheduler_params.divmod_k_(work_idx_k, remainder, remainder);
-      }
-
-      uint64_t cta_per_grid_dim, dontcare;
-      scheduler_params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
-
-      auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n(
-                                        cta_per_grid_dim,
-                                        scheduler_params.divmod_cluster_shape_major_,
-                                        scheduler_params.divmod_cluster_shape_minor_,
-                                        scheduler_params.divmod_cluster_blk_major_,
-                                        scheduler_params.log_swizzle_size_,
-                                        scheduler_params.raster_order_);
-
-      bool is_final_split = (work_idx_k == scheduler_params.splits_ - 1);
-
-      uint32_t k_iter = scheduler_params.k_iter_per_tile_;
-      if (is_split_k) {
-        // Determine the number of iterations and starting iteration of this split.
-        // Doing so requires accounting for residual iterations, which are handled
-        // by the first big_units_ splits (with big_units_ = tiles % sm_count).
-
-        // Offsets for "normal" units. No additional k iterations are performed,
-        // and big_units_ "big" units preceded us, each of which performed one
-        // additional iteration. Thus, we must increase our split starting offset
-        // by big_units_.
-        int additional_k_iter = 0;
-        int split_start_offset = scheduler_params.big_units_;
-
-        if (work_idx_k < scheduler_params.big_units_) {
-          // Offsets for "big" units. One additional k iteration is performed,
-          // and each split preceding us was a big unit, so we must increase
-          // our split starting offset by our split ID (work_idx_k).
-          additional_k_iter = 1;
-          split_start_offset = work_idx_k;
-        }
-
-        // Set up k iteration count and split starting iteration assuming the
-        // iteration space is evenly split.
-        k_iter /= scheduler_params.splits_;
-        work_idx_k *= k_iter;
-
-        // Apply any fixup needed to handle residuals
-        work_idx_k += split_start_offset;
-        k_iter += additional_k_iter;
-      }
-
-      return {
-        work_idx_m,
-        work_idx_n,
-        static_cast<int32_t>(work_idx_k),
-        static_cast<int32_t>(work_idx_l),
-        true,
-        scheduler_params.k_iter_per_tile_,
-        k_iter,
-        k_iter, // remaining iterations
-        is_final_split
-      };
+      // Bypass the stream-K scheduling logic for basic data-parallel or split-K work
+      return set_non_stream_k_work(linear_idx, params, is_split_k);
+    }
+    else {
+      // This is a stream-K work unit
+      WorkTileInfo work_tile_info;
+      set_stream_k_work(params, linear_idx, work_tile_info, /*new_unit = */ true);
+      return work_tile_info;
     }
-
-    // This is a stream-K work unit
-    WorkTileInfo work_tile_info;
-    set_stream_k_work(linear_idx, work_tile_info, /*new_unit = */ true);
-    return work_tile_info;
   }

   // Returns whether the current work_tile_info passed in should continue to be used. This
@@ -446,13 +373,24 @@ public:
   CUTLASS_DEVICE
   bool
   continue_current_work(WorkTileInfo& work_tile_info) const {
+    return continue_current_work_for_linear_idx(
+      current_work_linear_idx_, work_tile_info, scheduler_params);
+  }
+
+  CUTLASS_DEVICE static
+  bool
+  continue_current_work_for_linear_idx(
+    uint64_t linear_idx,
+    WorkTileInfo& work_tile_info,
+    Params const& params) {
+
     work_tile_info.k_tile_remaining -= work_tile_info.k_tile_count;

     if (work_tile_info.k_tile_remaining == 0) {
       return false;
     }
-    set_stream_k_work(current_work_linear_idx_, work_tile_info, /* new_unit = */ false);
+    set_stream_k_work(params, linear_idx, work_tile_info, /* new_unit = */ false);
     return true;
   }
@@ -495,6 +433,14 @@ public:
       /*truncate_by_problem_size=*/false);
   }

+  // Returns whether fixup is needed for `work_tile_info`.
+  CUTLASS_HOST_DEVICE
+  static bool
+  requires_fixup(Params const& params, WorkTileInfo const& work_tile_info) {
+    // Fixup is not needed for data-parallel tiles
+    return work_tile_info.k_tile_count != params.k_tiles_per_output_tile_;
+  }
+
   // Performs the reduction across splits for a given output tile.
   template <class FrgTensorC>
   CUTLASS_DEVICE
@@ -505,13 +451,25 @@ public:
     FrgTensorC& accumulators,
     uint32_t num_barriers,
     uint32_t barrier_idx) {
+    using BarrierManager = NamedBarrierManager<NumThreadsPerWarpGroup, 2>;
+    return fixup_helper<FrgTensorC, BarrierManager>(
+      params, work_tile_info, accumulators, num_barriers, barrier_idx);
+  }
+
+  // Helper for performing the reduction across splits for a given output tile.
+  template <class FrgTensorC, class BarrierManager>
+  CUTLASS_DEVICE
+  static void
+  fixup_helper(
+    Params const& params,
+    WorkTileInfo const& work_tile_info,
+    FrgTensorC& accumulators,
+    uint32_t num_barriers,
+    uint32_t barrier_idx) {

     using ElementAccumulator = typename FrgTensorC::value_type;
-    using BarrierManager = NamedBarrierManager<NumThreadsPerWarpGroup, 2>;
-    if (work_tile_info.k_tile_count == params.k_iter_per_tile_) {
-      // Fixup is not needed for data-parallel tiles
+
+    if (!requires_fixup(params, work_tile_info)) {
       return;
     }
@@ -619,21 +577,23 @@ public:
       }
     }
     else {
+      auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
       uint64_t cta_per_grid_dim;
       uint64_t cluster_dim_idx;
       if (params.raster_order_ == RasterOrder::AlongN) {
-        uint64_t block_idx_m = (work_tile_info.M_idx - blockIdx.x) / gridDim.x;
+        uint64_t block_idx_m = (work_tile_info.M_idx - cta_m_in_cluster) / cute::size<0>(params.cluster_shape_);
         uint64_t block_idx_n = work_tile_info.N_idx;
         cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor *
           params.divmod_cluster_blk_major_.divisor * block_idx_m) + block_idx_n;
-        cluster_dim_idx = blockIdx.x;
+        cluster_dim_idx = cta_m_in_cluster;
       }
       else {
         uint64_t block_idx_m = work_tile_info.M_idx;
-        uint64_t block_idx_n = (work_tile_info.N_idx - blockIdx.y) / gridDim.y;
+        uint64_t block_idx_n = (work_tile_info.N_idx - cta_n_in_cluster) / cute::size<1>(params.cluster_shape_);
         cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor *
           params.divmod_cluster_blk_major_.divisor * block_idx_n) + block_idx_m;
-        cluster_dim_idx = blockIdx.y;
+        cluster_dim_idx = cta_n_in_cluster;
       }

       uint64_t tile_in_batch = params.divmod_cluster_shape_minor_.divisor * cta_per_grid_dim;
@@ -646,7 +606,7 @@ public:
   get_workspace_size(
     Arguments const& args,
     ProblemShape problem_shape,
     KernelHardwareInfo const& hw_info,
     uint32_t mma_warp_groups) {
     int barrier_workspace_size = 0;
@@ -715,7 +675,7 @@ private:
     // Construct a layout for the indexed tensor. The main purpose of this new layout is to
     // override the k extent to support cases in which the split computes a number of iterations
-    // not equal to total_tile_k_iter / splits. A common example of this in stream-K is when a
+    // not equal to total_k_tiles / splits. A common example of this in stream-K is when a
     // unit computes the final 20 of the total 32 k iterations of the output tile. In this case,
     // set splits = 32 and the split index (K_idx) to 11. The zipped divide above results in each
     // of the splits computing only one k iteration.
@@ -728,12 +688,13 @@ private:
   // Returns the number of stream-K tiles that will be computed amongst `output_tiles` total
   // output tiles on a device with `ctas_per_wave` CTAs in each wave.
   static uint32_t
-  get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave) {
+  get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave, uint32_t k_tiles_per_output_tile) {
     uint32_t full_waves = static_cast<uint32_t>(output_tiles / ctas_per_wave);
     uint32_t total_waves = static_cast<uint32_t>((output_tiles + ctas_per_wave - 1) / ctas_per_wave);

-    if (full_waves == total_waves) {
-      // No quantization. All tiles will be data-parallel tiles.
+    if (full_waves == total_waves || k_tiles_per_output_tile == 1) {
+      // All tiles will be data-parallel tiles if there is either no quantization
+      // or if there is no work to be split.
       return 0;
     }
@@ -811,9 +772,12 @@ private:
       sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
     }

+    uint32_t k_tiles_per_output_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(TileShape{}) - 1) /
+                                       cute::size<2>(TileShape{});
+
     dim3 grid = get_grid_shape(problem_shape_mnkl, TileShape{}, cluster_shape, {0, sm_count}, args);
     uint64_t ctas_per_wave = grid.x * grid.y;
-    uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave);
+    uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, k_tiles_per_output_tile);

     barrier_workspace_size = get_barrier_workspace_size(sk_tiles, mma_warp_groups);
     reduction_workspace_size = get_reduction_workspace_size<ElementAccumulator>(sk_tiles);
@@ -829,10 +793,10 @@ private:
       uint32_t blocks_l,
       ClusterShape cluster_shape,
       uint32_t splits,
-      uint32_t k_iter_per_tile,
+      uint32_t k_tiles_per_output_tile,
       void* reduction_workspace) {

-    uint32_t big_units = k_iter_per_tile % splits;
+    uint32_t big_units = k_tiles_per_output_tile % splits;
     return {
       underlying_params.divmod_cluster_shape_major_,
@@ -845,7 +809,7 @@ private:
       underlying_params.raster_order_,
       cluster_shape,
       splits,
-      k_iter_per_tile,
+      k_tiles_per_output_tile,
       big_units,
       reduction_workspace
     };
@@ -855,8 +819,12 @@ private:
   // is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining
   // iterations) is used to find the next tile in the current work unit.
   CUTLASS_DEVICE
-  void
-  set_stream_k_work(uint64_t linear_idx, WorkTileInfo& work_tile_info, bool new_unit) const {
+  static void
+  set_stream_k_work(
+    Params const& params,
+    uint64_t linear_idx,
+    WorkTileInfo& work_tile_info,
+    bool new_unit) {
     // In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K
     // threadblock individually. For the most part, the set of K iterations corresponding to stream-K
     // work was divided amongst stream-K threadblocks, and a threadblock determined which tile
@@ -872,15 +840,15 @@ private:
     //
     // To do so, we divide up the linearized stream-K units into clusters and share the same K
     // offsets for work within clusters.
-    auto cluster_linear_work_idx = linear_idx / size(scheduler_params.cluster_shape_);
+    auto cluster_linear_work_idx = linear_idx / size(params.cluster_shape_);

     // Determine the starting k iteration computed by this stream-K work unit
-    uint32_t unit_iter_start = scheduler_params.k_iter_per_sk_unit_ * cluster_linear_work_idx;
+    uint32_t unit_iter_start = params.k_tiles_per_sk_unit_ * cluster_linear_work_idx;

     // Adjust the starting position and number of k iterations for "big units," which
     // compute one extra iteration. These are the first big_units_ units in the
     // linearized ID space.
-    bool is_big_unit = cluster_linear_work_idx < scheduler_params.big_units_;
+    bool is_big_unit = cluster_linear_work_idx < params.big_units_;
     if (is_big_unit) {
       // Since the "big units" are the first units in the linearized ID space, each
       // of the units preceding this big unit computed one extra iteration. Thus,
@@ -889,16 +857,16 @@ private:
       unit_iter_start += cluster_linear_work_idx;
     } else {
       // Increment by one for each of the big clusters (since all big units precede this unit)
-      unit_iter_start += scheduler_params.big_units_;
+      unit_iter_start += params.big_units_;
     }

     uint32_t unit_iters;
     if (new_unit) {
-      unit_iters = scheduler_params.k_iter_per_sk_unit_;
+      unit_iters = params.k_tiles_per_sk_unit_;

       // Only adjust iteration count for big unit if we are initializing this
       // work unit. For existing work units, the extra iteration for big units
-      // has already been accounted for in k_iter_remaining
+      // has already been accounted for in k_tiles_remaining
       if (is_big_unit) {
         ++unit_iters;
       }
@@ -917,22 +885,21 @@ private:
     // for them to be computed later, so as to reduce the likelihood of blocking
     // on other work.
     uint32_t unit_iter_end = unit_iter_start + unit_iters - 1;
-    uint32_t true_tile_id = unit_iter_end / scheduler_params.k_iter_per_tile_;
-    uint32_t true_tile_iter_start = true_tile_id * scheduler_params.k_iter_per_tile_;
-    uint32_t true_tile_iter_end = true_tile_iter_start + scheduler_params.k_iter_per_tile_;
+    uint32_t true_tile_id = unit_iter_end / params.k_tiles_per_output_tile_;
+    uint32_t true_tile_iter_start = true_tile_id * params.k_tiles_per_output_tile_;
+    uint32_t true_tile_iter_end = true_tile_iter_start + params.k_tiles_per_output_tile_;

     // Bring the linearized tile ID back into the space of tiles, rather than clusters
-    true_tile_id *= size(scheduler_params.cluster_shape_);
+    true_tile_id *= size(params.cluster_shape_);

-    auto cluster_dim0 = cute::size<0>(scheduler_params.cluster_shape_);
-    auto cluster_dim1 = cute::size<1>(scheduler_params.cluster_shape_);
+    auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();

     // The final linearized tile ID is in units of the cluster dimension over which we rasterize.
-    if (scheduler_params.raster_order_ == RasterOrder::AlongN) {
-      true_tile_id += (blockIdx.y % cluster_dim1) * cluster_dim0;
+    if (params.raster_order_ == RasterOrder::AlongN) {
+      true_tile_id += cta_n_in_cluster * cute::size<0>(params.cluster_shape_);
     }
     else {
-      true_tile_id += (blockIdx.x % cluster_dim0) * cluster_dim1;
+      true_tile_id += cta_m_in_cluster * cute::size<1>(params.cluster_shape_);
     }

     // The unit's starting k iteration in the current tile is either the starting
@@ -948,19 +915,18 @@ private:
     uint32_t tile_iters = tile_iter_end - tile_iter_start;

     uint64_t work_idx_l, remainder;
-    scheduler_params.divmod_batch_(work_idx_l, remainder, true_tile_id);
+    params.divmod_batch_(work_idx_l, remainder, true_tile_id);

     uint64_t cta_per_grid_dim, dontcare;
-    scheduler_params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
+    params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);

     auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n(
                                       cta_per_grid_dim,
-                                      scheduler_params.divmod_cluster_shape_major_,
-                                      scheduler_params.divmod_cluster_shape_minor_,
-                                      scheduler_params.divmod_cluster_blk_major_,
-                                      scheduler_params.log_swizzle_size_,
-                                      scheduler_params.raster_order_);
+                                      params.divmod_cluster_shape_major_,
+                                      params.divmod_cluster_shape_minor_,
+                                      params.divmod_cluster_blk_major_,
+                                      params.log_swizzle_size_,
+                                      params.raster_order_);

     //
     // Update the work_tile_info
@@ -971,11 +937,11 @@ private:
     work_tile_info.N_idx = work_idx_n;
     work_tile_info.L_idx = static_cast<int32_t>(work_idx_l);

-    // Set the k offset to be the starting k iteration for this tile
+    // Set the k offset to be the starting k tile for this output tile
     work_tile_info.K_idx = static_cast<int32_t>(tile_iter_start - true_tile_iter_start);

-    // Set the split count to be the number of k iterations in the tile
-    work_tile_info.splits = scheduler_params.k_iter_per_tile_;
+    // Set the split count to be the number of k tiles in the output tile
+    work_tile_info.splits = params.k_tiles_per_output_tile_;

     // Any checks for invalid work units should be done prior to this call
     work_tile_info.is_valid_tile = true;
@@ -987,6 +953,89 @@ private:
     // the output tile in question
     work_tile_info.is_final_split = (tile_iter_end == true_tile_iter_end);
   }

+  // Returns a WorkTileInfo to be computed for either the data-parallel or split-K
+  // work unit identified by the provided linear ID.
+  CUTLASS_DEVICE
+  static WorkTileInfo
+  set_non_stream_k_work(uint64_t linear_idx, Params const& params, bool is_split_k) {
+    // The linearized ID space is in terms of work units, rather than tiles. However,
+    // to compute the correct block offset for a data-parallel tile, we must convert
+    // the current ID to the data-parallel tile it corresponds to. Each data-parallel
+    // unit maps to a single data-parallel tile, but each stream-K unit can map to more
+    // than one tile. Thus, we must offset the work-unit ID among the data-parallel units
+    // by the total number of output tiles that will be computed by stream-K units.
+    //
+    // The logic below also works for the split-K case, in which sk_units_ and sk_tiles_
+    // are each 0.
+    uint64_t linear_work_idx = linear_idx - params.sk_units_ + params.sk_tiles_;
+
+    // Map worker's linear index into the CTA-tiled problem shape to the corresponding MNL indices
+    uint64_t work_idx_l, remainder;
+    params.divmod_batch_(work_idx_l, remainder, linear_work_idx);
+
+    uint64_t work_idx_k = 0;
+    if (is_split_k) {
+      params.divmod_k_(work_idx_k, remainder, remainder);
+    }
+
+    uint64_t cta_per_grid_dim, dontcare;
+    params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
+
+    auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n(
+                                      cta_per_grid_dim,
+                                      params.divmod_cluster_shape_major_,
+                                      params.divmod_cluster_shape_minor_,
+                                      params.divmod_cluster_blk_major_,
+                                      params.log_swizzle_size_,
+                                      params.raster_order_);
+
+    bool is_final_split = (work_idx_k == params.splits_ - 1);
+
+    uint32_t k_tiles = params.k_tiles_per_output_tile_;
+    if (is_split_k) {
+      // Determine the number of iterations and starting iteration of this split.
+      // Doing so requires accounting for residual iterations, which are handled
+      // by the first big_units_ splits (with big_units_ = tiles % sm_count).
+
+      // Offsets for "normal" units. No additional k iterations are performed,
+      // and big_units_ "big" units preceded us, each of which performed one
+      // additional iteration. Thus, we must increase our split starting offset
+      // by big_units_.
+      int additional_k_tiles = 0;
+      int split_start_offset = params.big_units_;
+
+      if (work_idx_k < params.big_units_) {
+        // Offsets for "big" units. One additional k iteration is performed,
+        // and each split preceding us was a big unit, so we must increase
+        // our split starting offset by our split ID (work_idx_k).
+        additional_k_tiles = 1;
+        split_start_offset = work_idx_k;
+      }
+
+      // Set up k iteration count and split starting iteration assuming the
+      // iteration space is evenly split.
+      k_tiles /= params.splits_;
+      work_idx_k *= k_tiles;
+
+      // Apply any fixup needed to handle residuals
+      work_idx_k += split_start_offset;
+      k_tiles += additional_k_tiles;
+    }
+
+    return {
+      work_idx_m,
+      work_idx_n,
+      static_cast<int32_t>(work_idx_k),
+      static_cast<int32_t>(work_idx_l),
+      true,
+      params.k_tiles_per_output_tile_,
+      k_tiles,
+      k_tiles, // remaining iterations
+      is_final_split
+    };
+  }
 };

 } // namespace cutlass::gemm::kernel::detail
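The split-K path in set_non_stream_k_work above gives the first big_units_ splits one extra k tile and shifts each split's starting offset accordingly. A standalone Python sketch of that assignment (not part of the diff; the sizes are illustrative assumptions):

    # Illustrative split-K work assignment with residual ("big") splits (not library code).
    # Assumed example: 10 k tiles per output tile divided among 4 splits.
    k_tiles_per_output_tile = 10
    splits = 4
    big_units = k_tiles_per_output_tile % splits   # first 2 splits compute one extra k tile

    assignments = []
    for split_idx in range(splits):
        if split_idx < big_units:
            additional_k_tiles = 1
            split_start_offset = split_idx         # every preceding split was "big"
        else:
            additional_k_tiles = 0
            split_start_offset = big_units         # all "big" splits precede this one

        k_tiles = k_tiles_per_output_tile // splits
        start = split_idx * k_tiles + split_start_offset
        k_tiles += additional_k_tiles
        assignments.append((start, k_tiles))

    # Splits cover [0,3), [3,6), [6,8), [8,10): contiguous and exhaustive.
    assert assignments == [(0, 3), (3, 3), (6, 2), (8, 2)]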


@@ -28,7 +28,7 @@
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  **************************************************************************************************/
/*!
  \file
  \brief Boost-like numeric conversion operator for CUTLASS numeric types
*/
@@ -55,7 +55,7 @@ enum class FloatRoundStyle {
  round_indeterminate,          ///< rounding mode unknown
  round_toward_zero,            ///< round toward zero
  round_to_nearest,             ///< round to nearest even
  round_to_nearest_satfinite,   ///< round to nearest even, capping value to min and max of destination type
  round_toward_infinity,        ///< round toward infinity
  round_toward_neg_infinity,    ///< round toward negative infinity
  round_half_ulp_truncate,      ///< add 0.5ulp to integer representation then round toward zero
@@ -561,7 +561,7 @@ struct NumericConverter<tfloat32_t, float, FloatRoundStyle::round_to_nearest> {
    // Note, the following is intentionally commented out. TF32
    // does not define the low order bits, so they may be left in
    // an undefined state.
    //
    // By not truncating these bits explicitly, we avoid an extra logical
    // operation.
@@ -657,7 +657,7 @@ template <
struct NumericConverterFastF32 {
  // result_type holds big tfloat32_t at idx(0) and small tfloat32_t at idx(1)
  using result_type = Array<tfloat32_t, 2>;
  // source data type
  using source_type = float;
@@ -708,7 +708,7 @@ struct NumericConverterClamp {
    NumericConverter<result_type, source_type> convert_op;
    result_type const kClamp_max = platform::numeric_limits<result_type>::max();
    result_type const kClamp_min = platform::numeric_limits<result_type>::lowest();
    if (s < (source_type)kClamp_min)
      return kClamp_min;
    if (s > (source_type)kClamp_max)
      return kClamp_max;
@@ -848,7 +848,7 @@ struct NumericArrayConverter<half_t, float, 2, FloatRoundStyle::round_to_nearest> {
    result[0] = convert_(source[0]);
    result[1] = convert_(source[1]);
  #endif
    return result;
  }
@@ -878,7 +878,7 @@ struct NumericArrayConverter<float, half_t, 2, Round> {
    result[0] = convert_(source[0]);
    result[1] = convert_(source[1]);
  #endif
    return result;
  }
@@ -1044,7 +1044,7 @@ struct NumericArrayConverter<bfloat16_t, float, N, Round> {
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conditional guards to enable partial specialization for packed integers
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \
    ((__CUDACC_VER_MAJOR__ > 10) || \
    ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
@@ -1066,7 +1066,7 @@ struct NumericArrayConverter<int8_t, int, 1, Round> {
    result_type result;
    result[0] = convert_element_(source[0]);
    return result;
  }
@@ -1189,7 +1189,7 @@ struct NumericArrayConverter<uint8_t, int, 1, Round> {
    result_type result;
    result[0] = convert_element_(source[0]);
    return result;
  }


@@ -215,10 +215,17 @@ class GemmArguments2x(ArgumentBase):
         else:
             self.batch_count = 1

-        self.batched_stride_A = self.problem_size.m() * self.problem_size.k()
-        self.batched_stride_B = self.problem_size.n() * self.problem_size.k()
-        self.batched_stride_C = self.problem_size.m() * self.problem_size.n()
-        self.batched_stride_D = self.problem_size.m() * self.problem_size.n()
+        if "batch_strides" in kwargs:
+            self.batched_stride_A = kwargs["batch_strides"]["A"]
+            self.batched_stride_B = kwargs["batch_strides"]["B"]
+            self.batched_stride_C = kwargs["batch_strides"]["C"]
+            self.batched_stride_D = kwargs["batch_strides"]["D"]
+        else:
+            self.batched_stride_A = self.problem_size.m() * self.problem_size.k()
+            self.batched_stride_B = self.problem_size.n() * self.problem_size.k()
+            self.batched_stride_C = self.problem_size.m() * self.problem_size.n()
+            self.batched_stride_D = self.problem_size.m() * self.problem_size.n()

         if self.bias:
             self.batched_stride_C = self.problem_size.n()
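For reference, the default batched strides above are simply the per-matrix extents, while the new batch_strides kwarg lets the caller supply its own values. A small sketch under assumed sizes (illustrative only, not taken from the diff):

    # Default batch strides used when no "batch_strides" kwarg is passed
    # (illustrative problem size; not library code).
    m, n, k = 128, 64, 32
    default_batch_strides = {
        "A": m * k,  # elements between consecutive A matrices
        "B": n * k,  # elements between consecutive B matrices
        "C": m * n,  # elements between consecutive C matrices
        "D": m * n,  # elements between consecutive D matrices
    }

    # An explicit override is forwarded verbatim via kwargs["batch_strides"].
    kwargs = {"batch": 3, "batch_strides": default_batch_strides}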


@@ -132,9 +132,9 @@ class KernelsForDataType:
         """
         # Determine the leading dimension of the shape
         if layout == cutlass.LayoutType.ColumnMajor:
-            ld = shape[0]
+            ld = shape[-2]
         elif layout == cutlass.LayoutType.RowMajor:
-            ld = shape[1]
+            ld = shape[-1]
         elif layout == cutlass.LayoutType.TensorNHWC:
             ld = shape[-1]
         else:
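The switch to negative indexing above makes the leading-dimension lookup work for both rank-2 and batched rank-3 shapes. A minimal sketch (the helper name and shapes are illustrative, not from the library):

    # Leading dimension taken from the last two extents, so (m, k) and (b, m, k)
    # shapes are handled the same way (illustrative helper, not library code).
    def leading_dim(shape, row_major):
        # Row-major: stride between rows is the column count (shape[-1]);
        # column-major: stride between columns is the row count (shape[-2]).
        return shape[-1] if row_major else shape[-2]

    assert leading_dim((128, 64), row_major=True) == 64
    assert leading_dim((8, 128, 64), row_major=True) == 64   # batch dim ignored
    assert leading_dim((8, 128, 64), row_major=False) == 128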


@ -114,6 +114,8 @@
args.sync() args.sync()
""" """
from math import prod
import cutlass_bindings import cutlass_bindings
import cutlass import cutlass
@ -442,6 +444,113 @@ class Gemm(OperationBase):
compiler.add_module([self.operation,]) compiler.add_module([self.operation,])
return self.operation return self.operation
def _verify_rank(self, tensor):
"""
Verifies that ``tensor`` has rank greater than 1
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
:type tensor: numpy/cupy/torch array/tensor object
"""
if len(tensor.shape) < 2:
raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}")
def _get_batch_count(self, A, B, C, D) -> int:
"""
Returns the batch count specified by the tensors A, B, C, and D and verifies that these
tensors match in batch size. Presence of a batch dimension is detected by one of the
tensors being rank 3. If a batch dimension is present, it must be present in one of
operands A, B, or C (but need not be in all), and must be present in D.
:param A: tensor A
:type A: numpy/cupy/torch array/tensor object
:param B: tensor B
:type B: numpy/cupy/torch array/tensor object
:param C: tensor C
:type C: numpy/cupy/torch array/tensor object
:param D: tensor D
:type D: numpy/cupy/torch array/tensor object
:return: tuple of batch count dimensions
:rtype: tuple
"""
A_batch = A.shape[:-2] if len(A.shape) > 2 else tuple()
B_batch = B.shape[:-2] if len(B.shape) > 2 else tuple()
C_batch = C.shape[:-2] if len(C.shape) > 2 else tuple()
D_batch = D.shape[:-2] if len(D.shape) > 2 else tuple()
if len(D_batch) > 0 and D_batch not in [A_batch, B_batch, C_batch]:
raise Exception(f"Batch count in D must be present in one of operands A, B, and C. "
f"Batch counts are: A={A_batch}, B={B_batch}, C={C_batch}, D={D_batch}")
for batch_shape in [A_batch, B_batch, C_batch]:
if len(batch_shape) > 0 and batch_shape != D_batch:
raise Exception(f"Batch count for all other operands must either match that of D or be zero."
f"Received batch shape of {batch_shape}, which does not match that of D of {D_batch}.")
return D_batch
def _get_batch_stride(self, tensor) -> int:
"""
Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0.
:param tensor: tensor object to process
:type tensor: numpy/cupy/torch array/tensor object
:return: stride between each matrix in the batch
:rtype: int
"""
if len(tensor.shape) > 2:
return tensor.shape[-2] * tensor.shape[-1]
else:
return 0
def _get_problem_args(self, A, B, C, D) -> tuple:
"""
Returns the problem size and GEMM universal mode to use for the
given operands.
:param A: tensor A
:type A: numpy/cupy/torch array/tensor object
:param B: tensor B
:type B: numpy/cupy/torch array/tensor object
:param C: tensor C
:type C: numpy/cupy/torch array/tensor object
:param D: tensor D
:type D: numpy/cupy/torch array/tensor object
:return: tuple containing the problem size (cutlass_bindings.gemm.GemmCoord), the GEMM mode (cutlass_bindings.gemm.Mode), and the batch count (int)
:rtype: tuple
"""
M, K = A.shape[-2:]
N = B.shape[-1]
mode = cutlass_bindings.gemm.Mode.Gemm
batch_count = self._get_batch_count(A, B, C, D)
returned_batch_count = prod(batch_count) if len(batch_count) > 0 else 1
# If we are running a batched GEMM in which there is a nonzero batch stride
# only for A, then we can fold the batched dimension of A into the M dimension
# (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A
# and C are row major. A similar operation can be performed if only B has a nonzero
# batch dimension
if len(batch_count) > 0:
A_row = self._layout_a == cutlass.LayoutType.RowMajor
B_row = self._layout_b == cutlass.LayoutType.RowMajor
C_row = self._layout_c == cutlass.LayoutType.RowMajor
batched = lambda x : len(x.shape) == 2 + len(batch_count)
if batched(A) and not batched(B) and batched(C) and A_row and C_row:
M *= prod(batch_count)
returned_batch_count = 1
elif not batched(A) and batched(B) and batched(C) and not B_row and not C_row:
N *= prod(batch_count)
returned_batch_count = 1
else:
mode = cutlass_bindings.gemm.Mode.Batched
return cutlass_bindings.gemm.GemmCoord(M, N, K), mode, returned_batch_count
    def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
        """
        Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
@@ -461,8 +570,7 @@ class Gemm(OperationBase):
                            f'layout of ({ref_type}, {ref_layout}).')

    def run(self, A=None, B=None, C=None, D=None,
-           alpha=None, beta=None, batch_count: int = 1,
-           sync: bool = True, print_module: bool = False) -> GemmArguments:
            alpha=None, beta=None, sync: bool = True, print_module: bool = False) -> GemmArguments:
        """
        Runs the kernel currently specified. If it has not already been, the kernel is emitted and
        compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@@ -481,8 +589,6 @@ class Gemm(OperationBase):
        :param D: tensor representing data type and layout of operand D
        :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
        :param beta: scalar parameter beta from GEMM operation that scales operand C
-       :param batch_count: number of GEMMs in the batch
-       :type batch_count: int
        :param sync: whether the call should wait for the kernel to complete before returning
        :type sync: bool
        :param print_module: whether to print the emitted C++ code
@@ -491,9 +597,6 @@ class Gemm(OperationBase):
        :return: arguments passed in to the kernel
        :rtype: cutlass.backend.GemmArguments
        """
-       if batch_count < 1:
-           raise Exception(f"Invalid batch count {batch_count}. Value must be an integer >= 1.")
        A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
        B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
        C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
@@ -501,20 +604,31 @@ class Gemm(OperationBase):
        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

        self._verify_rank(A)
        self._verify_rank(B)
        self._verify_rank(C)
        self._verify_rank(D)

        alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a)
        alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b)
        alignment_c = self.possible_operations.find_alignment(C.shape, self._layout_c)
        self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                     alignment_C=alignment_c, print_module=print_module)

-       problem_size = cutlass_bindings.gemm.GemmCoord(A.shape[0], B.shape[1], A.shape[1])
        problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)

-       if batch_count == 1:
-           mode = cutlass_bindings.gemm.Mode.Gemm
        if mode == cutlass_bindings.gemm.Mode.Gemm or batch_count == 1:
            kwargs = {'split_k_slices': 1}
        else:
-           mode = cutlass_bindings.gemm.Mode.Batched
-           kwargs = {'batch': batch_count}
            kwargs = {
                'batch': batch_count,
                'batch_strides': {
                    'A': self._get_batch_stride(A),
                    'B': self._get_batch_stride(B),
                    'C': self._get_batch_stride(C),
                    'D': self._get_batch_stride(D)
                }
            }

        arguments = GemmArguments(
            operation=self.operation, problem_size=problem_size,
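For the batched path, the batch strides passed above are presumably the number of elements separating consecutive batch items of a packed tensor, with 0 for an operand that is shared (broadcast) across the batch. The helper below is an illustrative sketch of that convention only, not the actual _get_batch_stride implementation:

import numpy as np

def batch_stride(tensor):
    # Packed layout: one batch item spans rows*cols elements; 0 reuses a 2-D operand for every item.
    return tensor[0].size if tensor.ndim > 2 else 0

A = np.zeros((4, 8, 32))   # batched: stride 8*32 = 256 elements
B = np.zeros((32, 16))     # shared across the batch: stride 0
print(batch_stride(A), batch_stride(B))  # 256 0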
@@ -0,0 +1,139 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
High-level tests for running batched GEMMs
"""
from functools import partial
from math import prod
import cutlass
import logging
import torch
import unittest
from cutlass.backend.test.utils import LayoutCombination, add_test_gemm
from cutlass.backend.utils.device import device_cc
cutlass.set_log_level(logging.WARNING)
torch.manual_seed(2023)
def pytorch_reference(A, B, C, alpha, beta):
# Get the batch count. Assume that any of A, B, and C
# with a batch dimension have a matching batch count. Thus,
# we break out of the loop once we have found the first
# tensor containing a batch dimension.
batch_count = (1,)
for tensor in [A, B, C]:
if len(tensor.shape) > 2:
batch_count = tensor.shape[:-2]
break
int_batch_count = prod(batch_count)
def add_batch(tensor):
if len(tensor.shape) == 2:
return tensor.unsqueeze(0).repeat(int_batch_count, 1, 1)
else:
return tensor.reshape(-1, tensor.size(-2), tensor.size(-1))
# Reshape tensors to have batch dimension
A = add_batch(A)
B = add_batch(B)
C = add_batch(C)
ret = (torch.bmm(A, B) * alpha) + (C * beta)
reshape_vals = batch_count + C.shape[-2:]
return ret.reshape(*reshape_vals)
def initialize(rows, cols, batch):
tensor = torch.randint(-3, 3, size=(rows*cols*prod(batch),), device='cuda').half()
if len(batch) > 0 and prod(batch) > 1:
reshape_vals = batch + (rows, cols)
return tensor.reshape(*reshape_vals)
else:
return tensor.reshape(rows, cols)
class GemmF16Batched(unittest.TestCase):
def run_batched(self, batch_count: tuple, batch_A: bool, batch_B: bool, batch_C: bool):
M = 512
N = 256
K = 128
alpha = 1.
beta = 2.
A = initialize(M, K, batch_count if batch_A else (1,))
B = initialize(K, N, batch_count if batch_B else (1,))
C = initialize(M, N, batch_count if batch_C else (1,))
D = initialize(M, N, batch_count)
plan = cutlass.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass.DataType.f32)
plan.run(A, B, C, D, alpha, beta)
reference = pytorch_reference(A, B, C, alpha, beta)
assert reference.equal(D)
def test_batched_ABC(self):
self.run_batched((3,), True, True, True)
self.run_batched((2, 3), True, True, True)
def test_batched_AB(self):
self.run_batched((3,), True, True, False)
self.run_batched((2, 3), True, True, False)
def test_batched_AC(self):
self.run_batched((3,), True, False, True)
self.run_batched((2, 3), True, False, True)
def test_batched_BC(self):
self.run_batched((3,), False, True, True)
self.run_batched((2, 3), False, True, True)
def test_batched_A(self):
self.run_batched((3,), True, False, False)
self.run_batched((2, 3), True, False, False)
def test_batched_B(self):
self.run_batched((3,), False, True, False)
self.run_batched((2, 3), False, True, False)
def test_batched_C(self):
self.run_batched((3,), False, False, True)
self.run_batched((2, 3), False, False, True)
if __name__ == '__main__':
unittest.main()
@@ -61,7 +61,7 @@ __global__ void convert(
/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Destination, typename Source, int Count>
-void run_test() {
void run_test(const char dest_name[], const char source_name[]) {
  const int kN = Count;
  dim3 grid(1, 1);
@@ -84,7 +84,10 @@ void run_test() {
  destination.sync_host();

  for (int i = 0; i < kN; ++i) {
-   EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]));
    EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]))
      << "Destination type: " << dest_name
      << ", Source type: " << source_name
      << ", Count: " << Count;
  }
}
@@ -97,15 +100,19 @@ void run_test() {
TEST(NumericConversion, f32_to_f16_rn) {
  int const kN = 1;
  using Source = float;
  const char source_name[] = "float";
  using Destination = cutlass::half_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "half_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f32x8_to_f16x8_rn) {
  int const kN = 8;
  using Source = float;
  const char source_name[] = "float";
  using Destination = cutlass::half_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "half_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -113,15 +120,19 @@ TEST(NumericConversion, f32x8_to_f16x8_rn) {
TEST(NumericConversion, f16_to_f32_rn) {
  int const kN = 1;
  using Source = cutlass::half_t;
  const char source_name[] = "half_t";
  using Destination = float;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f16x8_to_f32x8_rn) {
  int const kN = 8;
  using Source = cutlass::half_t;
  const char source_name[] = "half_t";
  using Destination = float;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -129,86 +140,109 @@ TEST(NumericConversion, f16x8_to_f32x8_rn) {
TEST(NumericConversion, f32_to_fe4m3_rn) {
  int const kN = 1;
  using Source = float;
  const char source_name[] = "float";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f32_to_fe4m3_rn_array) {
  int const kN = 27;
  using Source = float;
  const char source_name[] = "float";
  using Destination = cutlass::float_e4m3_t;
  const char dest_name[] = "float_e4m3_t";
- test::core::kernel::run_test<Destination, Source, kN>();
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f32_to_fe5m2_rn) {
  int const kN = 1;
  using Source = float;
  const char source_name[] = "float";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f32_to_fe5m2_rn_array) {
  int const kN = 27;
  using Source = float;
  const char source_name[] = "float";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f16_to_fe4m3_rn) {
  int const kN = 1;
  using Source = cutlass::half_t;
  const char source_name[] = "half_t";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f16_to_fe4m3_rn_array) {
  int const kN = 27;
  using Source = cutlass::half_t;
  const char source_name[] = "half_t";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f16_to_fe5m2_rn) {
  int const kN = 1;
  using Source = cutlass::half_t;
  const char source_name[] = "half_t";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f16_to_fe5m2_rn_array) {
  int const kN = 27;
  using Source = cutlass::half_t;
  const char source_name[] = "half_t";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, bf16_to_fe4m3_rn) {
  int const kN = 1;
  using Source = cutlass::bfloat16_t;
  const char source_name[] = "bfloat16_t";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, bf16_to_fe4m3_rn_array) {
  int const kN = 27;
  using Source = cutlass::bfloat16_t;
  const char source_name[] = "bfloat16_t";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, bf16_to_fe5m2_rn) {
  int const kN = 1;
  using Source = cutlass::bfloat16_t;
  const char source_name[] = "bfloat16_t";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, bf16_to_fe5m2_rn_array) {
  int const kN = 27;
  using Source = cutlass::bfloat16_t;
  const char source_name[] = "bfloat16_t";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -216,36 +250,46 @@ TEST(NumericConversion, bf16_to_fe5m2_rn_array) {
TEST(NumericConversion, fe4m3_to_fe5m2_rn) {
  int const kN = 1;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_fe5m2_array) {
  int const kN = 27;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e5m2_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_fe4m3_rn) {
  int const kN = 1;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_fe4m3_array) {
  int const kN = 27;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float_e4m3_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_f32_rn) {
  int const kN = 1;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = float;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -254,78 +298,100 @@ TEST(NumericConversion, f32x8_to_s8x8_rn) {
  int const kN = 8;
  using Source = float;
  const char source_name[] = "float";
  using Destination = int8_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "int8_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_f32_array) {
  int const kN = 27;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = float;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_f32_array) {
  int const kN = 27;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = float;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "float";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_f16_rn) {
  int const kN = 1;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::half_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "half_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_f16_array) {
  int const kN = 27;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::half_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "half_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_f16_rn) {
  int const kN = 1;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::half_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "half_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_f16_array) {
  int const kN = 27;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::half_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "half_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_bf16_rn) {
  int const kN = 1;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::bfloat16_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "bfloat16_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_bf16_array) {
  int const kN = 27;
  using Source = cutlass::float_e4m3_t;
  const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::bfloat16_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "bfloat16_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_bf16_rn) {
  int const kN = 1;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::bfloat16_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "bfloat16_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_bf16_array) {
  int const kN = 27;
  using Source = cutlass::float_e5m2_t;
  const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::bfloat16_t;
- test::core::kernel::run_test<Destination, Source, kN>();
  const char dest_name[] = "bfloat16_t";
  test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -36,6 +36,7 @@ cutlass_test_unit_add_executable(
  compare.cpp
  complement.cpp
  composition.cpp
  constant_arithmetic.cpp
  core_unit.cpp
  inverse_left.cpp
  inverse_right.cpp
@@ -0,0 +1,106 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include "cutlass_unit_test.h"
#include <cutlass/trace.h>
#include <cute/swizzle.hpp>
TEST(CuTe_core, ConstantArithmetic) {
using namespace cute;
constexpr cute::integral_constant<uint32_t, 0> uzero{};
// This extra test exists historically as part of the diagnosis
// of a possible Clang 14 bug. However, it's a nice test for
// cute::integral_constant's arithmetic operators, so it's saved here.
// It also demonstrates how to work with cute::integral_constant
// and lambda captures. Microsoft Visual Studio ("MSVC") tends to
// disagree with other compilers about the meaning of decltype
// for variables captured by reference. MSVC and GCC 8.3.0
// also tend to disagree with other compilers (and other GCC versions)
// about whether expressions involving such variables
// are constant expressions.
//
// A typical CuTe idiom is to do lambda captures by reference [&].
// This test changes them to capture by value, except for
// the innermost lambda's capture of S1, which is by reference.
// The point is to show that MSVC and GCC 8 have issues with this
// that other compilers do not. For example,
//
// 1. MSVC needs remove_cvref_t around decltype(S1)
// in order to access decltype(S1)::value, and
// 2. MSVC and GCC 8.3.0 both report a build error with S1()
// (that is, calling operator() on S1, which returns the
// same thing as S1.value).
//
// The reason for (2) is that neither compiler thinks
// that S1() is a constant expression.
//
// This leaves S1.value as the most concise portable expression
// for the "value" member of a cute::integral_constant.
for_each(make_integer_sequence<uint32_t, 8>{}, [uzero](auto S0) {
for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0](auto F0) {
for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0,F0](auto S1) {
for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0,F0,&S1](auto F1) {
static_assert((decltype(S0)::value & decltype(F0)::value) == decltype(S0 & F0)::value);
// Using S1.value means you don't have to use remove_cvref_t
// with a captured-by-reference variable.
static_assert((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) == decltype(S1 & F1)::value);
static_assert((S1.value & decltype(F1)::value) == decltype(S1 & F1)::value);
// S1() _should_ work, but does not with Visual Studio 2022,
// which emits C2131 ("expression did not evaluate to a constant").
// It also does not with GCC 8.3.0, which emits an error with messages
// "non-constant condition for static assertion" and
// "'this' is not a constant expression."
//
//static_assert((S1() & decltype(F1)::value) == decltype(S1 & F1)::value);
static_assert(decltype((S0 & F0) != uzero)::value == ((decltype(S0)::value & decltype(F0)::value) != 0));
static_assert(decltype((S1 & F1) != uzero)::value == ((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) != 0));
static_assert(decltype((S1 & F1) != uzero)::value == ((S1.value & decltype(F1)::value) != 0));
constexpr bool left = decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value;
constexpr bool right =
((decltype(S0)::value & decltype(F0)::value) != 0) ||
((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) != 0);
constexpr bool right2 =
((decltype(S0)::value & decltype(F0)::value) != 0) ||
((S1.value & decltype(F1)::value) != 0);
static_assert(right == right2);
static_assert(left == right);
constexpr bool left2 = decltype((S0 & F0) != uzero)::value || decltype((S1 & F1) != uzero)::value;
static_assert(left == left2);
});
});
});
});
}
@@ -31,9 +31,58 @@
#include "cutlass_unit_test.h"

// C<uint32_t(something)>::value_type is not uint32_t for GCC 7.5.0.
// This test is thus disabled for GCC < 8.
#if defined(__GNUC__) && (__GNUC__ < 8)

#include <cutlass/trace.h>
#include <cute/swizzle.hpp>
namespace { // (anonymous)
// This function exists to work around a Clang 14 issue, in which
// the compiler tries to instantiate code that lives inside the
// "else" branch of an "if constexpr," even when the "else" branch
// is false. That triggers a spurious static_assert in MixedBits.
// The work-around is to make the body of the "else" branch a
// function, rather than leaving it in line.
//
// Some compilers strangely deduce the first two terms of
// make_integer_sequence<uint32_t, 8> as C<false> and C<true>, and
// the remaining terms as C<2>, C<3>, etc. Making this function take
// cute::integral_constant<uint32_t, S0_value>, etc. doesn't work
// with those compilers.
template<class S0_type, S0_type S0_value,
class F0_type, F0_type F0_value,
class S1_type, S1_type S1_value,
class F1_type, F1_type F1_value>
void clang14_workaround(cute::integral_constant<S0_type, S0_value>,
cute::integral_constant<F0_type, F0_value>,
cute::integral_constant<S1_type, S1_value>,
cute::integral_constant<F1_type, F1_value>)
{
constexpr cute::C<static_cast<uint32_t>(S0_value)> S0{};
constexpr cute::C<static_cast<uint32_t>(F0_value)> F0{};
constexpr cute::C<static_cast<uint32_t>(S1_value)> S1{};
constexpr cute::C<static_cast<uint32_t>(F1_value)> F1{};
for (uint32_t d0 = 0; d0 < 8; ++d0) {
if ((d0 & F0) != d0) { continue; } // Skip repeats
for (uint32_t d1 = 0; d1 < 8; ++d1) {
if ((d1 & F1) != d1) { continue; } // Skip repeats
auto m0 = make_mixed_bits(S0, d0, F0);
auto m1 = make_mixed_bits(S1, d1, F1);
//print(m0); print(" & "); print(m1); print(" = "); print(m0 & m1); print("\n");
EXPECT_EQ(uint32_t(m0 & m1), uint32_t(m0) & uint32_t(m1));
//print(m0); print(" | "); print(m1); print(" = "); print(m0 | m1); print("\n");
EXPECT_EQ(uint32_t(m0 | m1), uint32_t(m0) | uint32_t(m1));
//print(m0); print(" ^ "); print(m1); print(" = "); print(m0 ^ m1); print("\n");
EXPECT_EQ(uint32_t(m0 ^ m1), uint32_t(m0) ^ uint32_t(m1));
}
}
}
} // namespace (anonymous)
TEST(CuTe_core, MixedBits) {
  using namespace cute;
@@ -48,23 +97,21 @@ TEST(CuTe_core, MixedBits) {
        } else if constexpr (decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value) {
          return;
        } else {
-         for (uint32_t d0 = 0; d0 < 8; ++d0) {
-           if ((d0 & F0) != d0) { continue; } // Skip repeats
-           for (uint32_t d1 = 0; d1 < 8; ++d1) {
-             if ((d1 & F1) != d1) { continue; } // Skip repeats
-             auto m0 = make_mixed_bits(S0, d0, F0);
-             auto m1 = make_mixed_bits(S1, d1, F1);
-             //print(m0); print(" & "); print(m1); print(" = "); print(m0 & m1); print("\n");
-             EXPECT_EQ(uint32_t(m0 & m1), uint32_t(m0) & uint32_t(m1));
-             //print(m0); print(" | "); print(m1); print(" = "); print(m0 | m1); print("\n");
-             EXPECT_EQ(uint32_t(m0 | m1), uint32_t(m0) | uint32_t(m1));
-             //print(m0); print(" ^ "); print(m1); print(" = "); print(m0 ^ m1); print("\n");
-             EXPECT_EQ(uint32_t(m0 ^ m1), uint32_t(m0) ^ uint32_t(m1));
-           }
-         }
          clang14_workaround(S0, F0, S1, F1);
        }
      });
    });
  });
});
}
TEST(CuTe_core, MakeIntegerSequence) {
cute::for_each(cute::make_integer_sequence<uint32_t, 8>{}, [](auto c) {
using c_type = decltype(c);
constexpr auto c_value = c_type::value;
using expected_type = cute::integral_constant<uint32_t, c_value>;
static_assert(cute::is_same_v<c_type, expected_type>);
});
}
#endif // defined(__GNUC__) && (__GNUC__ < 8)
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
@@ -32,6 +32,7 @@
    \brief Tests that the stream-K scheduler covers the entire problem space.
*/

#include "cutlass/cluster_launch.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp"
#include "cutlass/util/device_memory.h"
@@ -39,6 +40,10 @@
#include "../../common/cutlass_unit_test.h"

// Grids are launched with clusters enabled in these tests,
// so the CTK version must support cluster launching.
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)

using namespace cute;
using ProblemShape_MNKL = Shape<int, int, int, int>;
@@ -60,7 +65,7 @@ run_scheduler(int* visit_counters, typename Scheduler::Params params, TileShape
  while (work_tile_info.is_valid_tile) {
    // Increment counters to indicate coverage
    auto tile_idx = Scheduler::output_tile_index(params, work_tile_info);
-   auto offset = tile_idx * params.k_iter_per_tile_ + work_tile_info.K_idx;
    auto offset = tile_idx * params.k_tiles_per_output_tile_ + work_tile_info.K_idx;
    for (auto i = 0; i < work_tile_info.k_tile_count; ++i) {
      // Use atomicAdd because the visit counters are shared by multiple thread blocks.
      // While having more than one block increment the same counter indicates failure,
@@ -103,7 +108,7 @@ test_scheduler(
  // Allocate counters indicating the number of times each k iteration of each output tile has been visited
  auto [blk_m, blk_n, blk_l] = Scheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
- auto total_counters = blk_m * blk_n * blk_l * params.k_iter_per_tile_;
  auto total_counters = blk_m * blk_n * blk_l * params.k_tiles_per_output_tile_;
  cutlass::DeviceAllocation<int> visit_counters(total_counters);

  // Initialize counters to zero
@@ -118,12 +123,55 @@ test_scheduler(
  // Set up the grid for the problem
  dim3 grid = Scheduler::get_grid_shape(problem_shape_mnkl, tile_shape, cluster_shape, hw_info, args);
// Set up cluster and cluster launch. This is needed even for this simple kernel because
// the SM90 scheduler needs to be able to query the CTA id within a cluster, which requires
// explicitly launching with clusters.
dim3 cluster{
static_cast<uint32_t>(cute::get<0>(ClusterShape{})),
static_cast<uint32_t>(cute::get<1>(ClusterShape{})),
static_cast<uint32_t>(cute::get<2>(ClusterShape{}))
};
cudaLaunchConfig_t launch_config;
launch_config.gridDim = grid;
launch_config.blockDim = {1, 1, 1};
launch_config.dynamicSmemBytes = 0;
launch_config.stream = NULL;
cudaLaunchAttribute launch_attribute[1];
launch_attribute[0].id = cudaLaunchAttributeClusterDimension;
launch_attribute[0].val.clusterDim.x = cluster.x;
launch_attribute[0].val.clusterDim.y = cluster.y;
launch_attribute[0].val.clusterDim.z = cluster.z;
launch_config.attrs = launch_attribute;
launch_config.numAttrs = 1;
void const* kernel = (void const*) run_scheduler<Scheduler, TileShape, ClusterShape>;
int* counters_ptr = visit_counters.get();
void* kernel_params[] = {
&counters_ptr,
&params,
&tile_shape,
&cluster_shape,
&problem_shape_mnkl
};
  // Run the scheduler to completion and log visits to each k iteration
- run_scheduler<Scheduler, TileShape, ClusterShape><<<grid, 1>>>(
-   visit_counters.get(), params, tile_shape, cluster_shape, problem_shape_mnkl);
  err = cudaLaunchKernelExC(&launch_config, kernel, kernel_params);
if (err != cudaSuccess) {
std::cerr << __FILE__ << ":" << __LINE__
<< " cudaLaunchKernelExC failed with error: "
<< cudaGetErrorString(err) << std::endl;
return false;
}
  err = cudaDeviceSynchronize();
  if (err != cudaSuccess) {
-   std::cerr << __FILE__ << ":" << __LINE__ << " scheduler kernel failed with error: " << cudaGetErrorString(err) << std::endl;
    std::cerr << __FILE__ << ":" << __LINE__
              << " scheduler kernel failed with error: "
              << cudaGetErrorString(err) << std::endl;
    return false;
  }
@@ -143,11 +191,11 @@ test_scheduler(
                << " and grid size " << grid.x << "x"
                << grid.y << "x" << grid.z
                << " splits=" << params.splits_
-               << " k_iter=" << params.k_iter_per_tile_
                << " k_iter=" << params.k_tiles_per_output_tile_
                << " big_units=" << params.big_units_
                << " sk_tiles=" << params.sk_tiles_
                << " sk_units=" << params.sk_units_
-               << " k_iter_per_sk_unit=" << params.k_iter_per_sk_unit_ << std::endl;
                << " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_ << std::endl;
      std::cout << "Error at idx: " << i << ". Got count " << host_visit_counts[i] << std::endl;
      return false;
    }
@@ -274,4 +322,6 @@ TEST(SM90_Device_Gemm_stream_k_scheduler, 128x128x64_2x1x1) {
  EXPECT_TRUE(test_scheduler({128, 512, 2048, 1}, tile_shape, cluster_shape, 114));
}

#endif // defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -179,7 +179,7 @@ class GemmOperation:
    ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
    opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
    if self.arch >= 90:
-     kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}{t}"
      kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{t}{k}{e}"
      return kernel_name_template.format(
          p = self.prefix,
          ar = self.arch,
@@ -194,9 +194,9 @@ class GemmOperation:
          l = self.tile_description.stages,
          s = self.layout_name_3x(),
          al = str(max(self.A.alignment, self.B.alignment)),
          t = TileSchedulerSuffixes[self.tile_scheduler],
          k = self.kernel_schedule_name_3x(),
-         e = self.epilogue_schedule_name_3x(),
-         t = TileSchedulerSuffixes[self.tile_scheduler])
          e = self.epilogue_schedule_name_3x())
    else:
      threadblock = self.tile_description.procedural_name()
      return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format(
@@ -661,8 +661,7 @@ using ${operation_name}_mainloop =
    ${element_accumulator},
    cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
    cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
-   cutlass::gemm::collective::StageCountAutoCarveout<
-     sizeof(typename ${operation_name}_epilogue::SharedStorage)>,
    ${stages},
    ${kernel_schedule}
  >::CollectiveOp;
@@ -697,7 +696,7 @@ ${compile_guard_end}
  if operation.tile_description.stages > 0:
    stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>"
  else:
-   stage_count_string = "cutlass::gemm::collective::StageCountAuto"
    stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage)>"

  warp_shape = [tile_shape[idx] // warp_count[idx] for idx in range(3)]
  instance_layout_A, instance_layout_B, instance_layout_C , instance_layout_D = \
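For intuition, the naming template above is plain str.format with named fields; the change moves the tile-scheduler suffix {t} ahead of the kernel- and epilogue-schedule suffixes {k}{e}. A toy rendering with made-up field values (these are not actual CUTLASS suffixes):

old = "align{al}{k}{e}{t}"
new = "align{al}{t}{k}{e}"
fields = dict(al="8", k="_kernel_sched", e="_epi_sched", t="_stream_k")
print(old.format(**fields))  # align8_kernel_sched_epi_sched_stream_k
print(new.format(**fields))  # align8_stream_k_kernel_sched_epi_sched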
@@ -4218,6 +4218,8 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
      layout[2][1] = 8

    CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, schedules)
    CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK])

  # persistent kernels with TMA epilogues
  if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
    CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed,