Updates for 3.2 release (#1065)
This commit is contained in:
parent
27de343535
commit
a88c41cf8d
@ -2,10 +2,14 @@
|
||||
|
||||
## 2023
|
||||
|
||||
- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, and Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
|
||||
- ["FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"](https://arxiv.org/abs/2307.08691). Tri Dao. _Technical Report_, July 2023.
|
||||
|
||||
- ["ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs"](https://arxiv.org/abs/2210.03052). Yujia Zhai, Chengquan Jiang, Leyuan Wang, Xiaoying Jia, Shang Zhang, Zizhong Chen, Xin Liu, Yibo Zhu. _Proceedings of the 37th IEEE International Parallel & Distributed Processing Symposium (Best Paper)_, May 2023.
|
||||
|
||||
- ["A Framework for Fine-Grained Synchronization of Dependent GPU Kernels"](https://arxiv.org/abs/2305.13450). Abhinav Jangda, Saeed Maleki, Maryam Mehri Dehnavi, Madan Musuvathi, Olli Saarikivi. _Computing Research Repository_, May 2023.
|
||||
|
||||
- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
|
||||
|
||||
- ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023.
|
||||
|
||||
## 2022
|
||||
|
||||
@ -71,7 +71,7 @@ struct Options {
|
||||
/// Prints the usage statement.
|
||||
std::ostream & print_usage(std::ostream &out) const {
|
||||
|
||||
out << "52_fp8_hopper_warp_specialized_gemm\n\n"
|
||||
out << "54_fp8_hopper_warp_specialized_gemm\n\n"
|
||||
<< " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n"
|
||||
<< "Options:\n\n"
|
||||
<< " --help If specified, displays this usage statement\n\n"
|
||||
@ -93,7 +93,7 @@ struct Options {
|
||||
|
||||
out
|
||||
<< "\n\nExamples:\n\n"
|
||||
<< "$ " << "52_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
<< "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -54,6 +54,13 @@ struct SyncthreadsSync {
|
||||
}
|
||||
};
|
||||
|
||||
struct SyncwarpSync {
|
||||
CUTLASS_DEVICE
|
||||
static void sync() {
|
||||
__syncwarp();
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
int ThreadCount,
|
||||
int BarrierId
|
||||
@ -311,6 +318,60 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** Structure for synchronizing via contiguous barriers (e.g., __syncwarp, __syncthreads)
|
||||
* via an API that mirrors that of NamedBarrierManager
|
||||
*
|
||||
* @param Synchronizer Synchronization helper exposing a `sync()` method to perform synchronization
|
||||
**/
|
||||
template <
|
||||
class Synchronizer,
|
||||
uint32_t ThreadCount_
|
||||
>
|
||||
struct SyncManager {
|
||||
|
||||
// Number of threads participating in the barrier
|
||||
static constexpr uint32_t ThreadCount = ThreadCount_;
|
||||
|
||||
using BarrierSync = cutlass::GenericBarrier<Synchronizer>;
|
||||
|
||||
// Underlying type used by all barriers for synchronization.
|
||||
using T = typename BarrierSync::T;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static
|
||||
void wait_lt(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int count) {
|
||||
BarrierSync::wait_lt_helper(lock_ptr, thread_idx, flag_idx, count);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void
|
||||
wait_eq(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
|
||||
BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void
|
||||
wait_eq_reset(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
|
||||
BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void
|
||||
arrive_inc(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) {
|
||||
BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void
|
||||
arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) {
|
||||
BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -67,7 +67,7 @@ sm90_get_tma_dispatch_policy() {
|
||||
|
||||
constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{}));
|
||||
constexpr int FragmentSize = size(EpilogueTileMN{}) / (detail::sm90_is_cooperative_v<Schedule> ? 256 : 128);
|
||||
constexpr int ReuseSmemC = sizeof_bits_v<ElementC> == sizeof_bits_v<ElementD>;
|
||||
constexpr int ReuseSmemC = (sizeof_bits_v<ElementC> == sizeof_bits_v<ElementD>) && (sizeof_bits_v<ElementD> > 8);
|
||||
constexpr int StagesD = 2;
|
||||
constexpr int StagesC = ReuseSmemC ? cute::max(EpiTiles, StagesD + 1) : EpiTiles;
|
||||
|
||||
@ -98,7 +98,7 @@ sm90_get_epilogue_smem_swizzle_layout_atom() {
|
||||
}
|
||||
|
||||
// Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one.
|
||||
template <class Element, class EpilogueTileType, class Schedule>
|
||||
template <class ElementD, class EpilogueTileType, class Schedule>
|
||||
constexpr auto
|
||||
sm90_compute_tile_shape_or_override() {
|
||||
if constexpr (cute::is_same_v<EpilogueTileType, EpilogueTileAuto>) {
|
||||
@ -107,8 +107,13 @@ sm90_compute_tile_shape_or_override() {
|
||||
return Shape<_128,_32>{};
|
||||
}
|
||||
else if constexpr (detail::sm90_is_warp_specialized_v<Schedule>) {
|
||||
if constexpr (sizeof_bits_v<ElementD> == 8) {
|
||||
return Shape<_64,_64>{};
|
||||
}
|
||||
else {
|
||||
return Shape<_64,_32>{};
|
||||
}
|
||||
}
|
||||
else {
|
||||
static_assert(cutlass::detail::dependent_false<Schedule>, "Unsupported schedule.");
|
||||
}
|
||||
|
||||
@ -34,6 +34,7 @@
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cute/layout.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
|
||||
namespace cutlass::gemm::kernel::detail {
|
||||
|
||||
@ -205,18 +206,14 @@ public:
|
||||
|
||||
uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0;
|
||||
divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim);
|
||||
// MSVC requires protecting use of CUDA-specific nonstandard syntax,
|
||||
// like blockIdx and gridDim, with __CUDA_ARCH__.
|
||||
#if defined(__CUDA_ARCH__)
|
||||
|
||||
auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
|
||||
if (raster_order == RasterOrder::AlongN) {
|
||||
cluster_minor_offset = blockIdx.x;
|
||||
cluster_minor_offset = cta_m_in_cluster;
|
||||
}
|
||||
else {
|
||||
cluster_minor_offset = blockIdx.y;
|
||||
cluster_minor_offset = cta_n_in_cluster;
|
||||
}
|
||||
#else
|
||||
CUTLASS_ASSERT(false && "This line should never be reached");
|
||||
#endif
|
||||
|
||||
uint64_t cluster_idx_minor, cluster_idx_major;
|
||||
|
||||
|
||||
@ -141,7 +141,7 @@ public:
|
||||
uint32_t splits_ = 1;
|
||||
|
||||
// Number of tiled k iterations required to compute a single output tile.
|
||||
uint32_t k_iter_per_tile_ = 0;
|
||||
uint32_t k_tiles_per_output_tile_ = 0;
|
||||
|
||||
// Number of stream-K or split-K work units that compute an extra k iteration.
|
||||
// This is done to handle residuals in dividing up the k iteration space.
|
||||
@ -160,7 +160,7 @@ public:
|
||||
|
||||
// Number of tiled k iterations computed by each stream-K work unit. This
|
||||
// can potentially cover more than one output tile.
|
||||
uint32_t k_iter_per_sk_unit_ = 0;
|
||||
uint32_t k_tiles_per_sk_unit_ = 0;
|
||||
};
|
||||
|
||||
// Sink scheduler params as a member
|
||||
@ -189,9 +189,9 @@ public:
|
||||
|
||||
uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l;
|
||||
|
||||
// Number of k iterations each tile computes (this is just the number of k iterations
|
||||
// in the problem's K dimension)
|
||||
uint32_t k_iter_per_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(tile_shape) - 1) / cute::size<2>(tile_shape);
|
||||
// Number of k tile iterations in each output tile
|
||||
uint32_t k_tiles_per_output_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(tile_shape) - 1) /
|
||||
cute::size<2>(tile_shape);
|
||||
|
||||
UnderlyingArguments underlying_args;
|
||||
underlying_args.max_swizzle_size = 1;
|
||||
@ -216,11 +216,11 @@ public:
|
||||
// splits is almost certainly nonnegative here (e.g., hw_info.sm_count,
|
||||
// despite being an int, is a count), so it can safely be converted to unsigned
|
||||
// in the comparison to avoid a signed-unsigned comparison warning-as-error.
|
||||
splits = static_cast<decltype(k_iter_per_tile)>(splits) > k_iter_per_tile ? k_iter_per_tile : splits;
|
||||
splits = static_cast<decltype(k_tiles_per_output_tile)>(splits) > k_tiles_per_output_tile ? k_tiles_per_output_tile : splits;
|
||||
|
||||
return get_params_basic(
|
||||
underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape,
|
||||
splits, k_iter_per_tile, reduction_workspace);
|
||||
splits, k_tiles_per_output_tile, reduction_workspace);
|
||||
}
|
||||
|
||||
// Calculate the maximum number of blocks from clusters of shape cluster_shape that we
|
||||
@ -229,7 +229,7 @@ public:
|
||||
uint64_t ctas_per_wave = grid.x * grid.y;
|
||||
|
||||
// The number of output tiles to be computed in stream-K and data-parallel fashion, respectively.
|
||||
uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave);
|
||||
uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, k_tiles_per_output_tile);
|
||||
uint64_t dp_tiles = output_tiles - sk_tiles;
|
||||
|
||||
// Calculate the number of work units covering the data-parallel and stream-K tiles.
|
||||
@ -243,7 +243,7 @@ public:
|
||||
uint64_t dp_units = dp_tiles;
|
||||
|
||||
// Number of k iterations computed by the stream-K units as a whole
|
||||
uint64_t k_iter_sk_total = k_iter_per_tile * sk_tiles;
|
||||
uint64_t k_tiles_sk_total = k_tiles_per_output_tile * sk_tiles;
|
||||
|
||||
// If there are stream-K tiles to compute and a sufficiently large number of k iterations
|
||||
// across them, they will be covered by a single wave of persistent threadblocks. Thus, there
|
||||
@ -255,7 +255,7 @@ public:
|
||||
|
||||
// Calculate the number of stream-K units that would be needed if each stream-K unit
|
||||
// computed the minimum allowable k iterations. Truncate this to be in units of clusters.
|
||||
uint64_t min_sized_sk_units = (k_iter_sk_total / min_iters_per_sk_unit_);
|
||||
uint64_t min_sized_sk_units = (k_tiles_sk_total / min_iters_per_sk_unit_);
|
||||
min_sized_sk_units = (min_sized_sk_units / cute::size(cluster_shape)) * cute::size(cluster_shape);
|
||||
|
||||
uint64_t sk_units = min(ctas_per_wave, min_sized_sk_units);
|
||||
@ -264,7 +264,7 @@ public:
|
||||
// Short circuit to basic data-parallel decomposition
|
||||
return get_params_basic(
|
||||
underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape,
|
||||
1, k_iter_per_tile, reduction_workspace);
|
||||
1, k_tiles_per_output_tile, reduction_workspace);
|
||||
}
|
||||
|
||||
// If the number of stream-K units is a multiple of the number of stream-K tiles, then
|
||||
@ -274,24 +274,24 @@ public:
|
||||
uint32_t sk_splits = static_cast<uint32_t>(sk_units / sk_tiles);
|
||||
return get_params_basic(
|
||||
underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape,
|
||||
sk_splits, k_iter_per_tile, reduction_workspace);
|
||||
sk_splits, k_tiles_per_output_tile, reduction_workspace);
|
||||
}
|
||||
|
||||
// Number of k iterations computed per stream-K units
|
||||
uint64_t k_iter_per_sk_unit = k_iter_sk_total / sk_units;
|
||||
uint64_t k_tiles_per_sk_unit = k_tiles_sk_total / sk_units;
|
||||
|
||||
// Number of stream-K units that need to compute extra iterations in order to cover
|
||||
// the residual k iterations. This assumes that each such unit computes one additional
|
||||
// iteration.
|
||||
uint64_t sk_big_units = k_iter_sk_total - (k_iter_per_sk_unit * sk_units);
|
||||
uint64_t sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units);
|
||||
|
||||
// The division below is guaranteed to be exact because sk_big_units is guaranteed
|
||||
// to be a multiple of cluster_size (cute::size(cluster_shape)). This is useful because
|
||||
// it allows us to use a block's linearized cluster ID to determine whether it is
|
||||
// a big block. The reasoning behind this guarnatee is explained as follows:
|
||||
// sk_big_units = k_iter_sk_total - (k_iter_per_sk_unit * sk_units);
|
||||
// sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units);
|
||||
//
|
||||
// - k_iter_sk_total is a multiple of cluster_size because it is the product
|
||||
// - k_tiles_sk_total is a multiple of cluster_size because it is the product
|
||||
// of number of tail tiles and the number of k iterations per tile. Because
|
||||
// both the number of output tiles and number of available SMs are rounded
|
||||
// to be multiples of cluster shape, the number of tail tiles
|
||||
@ -313,12 +313,12 @@ public:
|
||||
underlying_params.raster_order_,
|
||||
cluster_shape,
|
||||
1, // Static k-splitting factor. Unused for stream-K.
|
||||
k_iter_per_tile,
|
||||
k_tiles_per_output_tile,
|
||||
static_cast<uint32_t>(sk_big_units_per_cluster),
|
||||
reduction_workspace,
|
||||
sk_tiles,
|
||||
static_cast<uint32_t>(sk_units),
|
||||
static_cast<uint32_t>(k_iter_per_sk_unit)
|
||||
static_cast<uint32_t>(k_tiles_per_sk_unit)
|
||||
};
|
||||
}
|
||||
|
||||
@ -338,106 +338,33 @@ public:
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
get_current_work() const {
|
||||
return get_current_work_for_linear_idx(current_work_linear_idx_);
|
||||
return get_current_work_for_linear_idx(current_work_linear_idx_, scheduler_params);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
WorkTileInfo
|
||||
get_current_work_for_linear_idx(uint64_t linear_idx) const {
|
||||
if (linear_idx >= scheduler_params.units_per_problem_) {
|
||||
static WorkTileInfo
|
||||
get_current_work_for_linear_idx(uint64_t linear_idx, Params const& params) {
|
||||
if (linear_idx >= params.units_per_problem_) {
|
||||
// Invalid work. Return an empty result.
|
||||
return {0, 0, 0, 0, false, 0};
|
||||
}
|
||||
|
||||
// Determine whether this work unit is a data-parallel or stream-K work unit
|
||||
bool is_stream_k_unit = linear_idx < scheduler_params.sk_units_;
|
||||
bool is_stream_k_unit = linear_idx < params.sk_units_;
|
||||
|
||||
bool is_split_k = scheduler_params.splits_ > 1;
|
||||
bool is_split_k = params.splits_ > 1;
|
||||
|
||||
// Bypass the stream-K scheduling logic for basic data-parallel or split-K work
|
||||
if (is_split_k || !is_stream_k_unit) {
|
||||
// The linearized ID space is in terms of work units, rather than tiles. However,
|
||||
// to compute the correct block offset for a data-parallel tile, we must convert
|
||||
// the current ID to the data-parallel tile it corresponds to. Each data-parallel
|
||||
// unit maps to a single data-parallel tile, but each stream-K unit can map to more
|
||||
// than one tile. Thus, we must offset the work-unit ID among the data-parallel units
|
||||
// by the total number of output tiles that will be computed by stream-K units.
|
||||
//
|
||||
// The logic below also works for the split-K case, in which sk_units_ and sk_tiles_
|
||||
// are each 0.
|
||||
uint64_t linear_work_idx = linear_idx - scheduler_params.sk_units_ + scheduler_params.sk_tiles_;
|
||||
|
||||
// Map worker's linear index into the CTA-tiled problem shape to the corresponding MNL indices
|
||||
uint64_t work_idx_l, remainder;
|
||||
scheduler_params.divmod_batch_(work_idx_l, remainder, linear_work_idx);
|
||||
|
||||
uint64_t work_idx_k = 0;
|
||||
if (is_split_k) {
|
||||
scheduler_params.divmod_k_(work_idx_k, remainder, remainder);
|
||||
// Bypass the stream-K scheduling logic for basic data-parallel or split-K work
|
||||
return set_non_stream_k_work(linear_idx, params, is_split_k);
|
||||
}
|
||||
|
||||
uint64_t cta_per_grid_dim, dontcare;
|
||||
scheduler_params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
|
||||
|
||||
auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n(
|
||||
cta_per_grid_dim,
|
||||
scheduler_params.divmod_cluster_shape_major_,
|
||||
scheduler_params.divmod_cluster_shape_minor_,
|
||||
scheduler_params.divmod_cluster_blk_major_,
|
||||
scheduler_params.log_swizzle_size_,
|
||||
scheduler_params.raster_order_);
|
||||
|
||||
bool is_final_split = (work_idx_k == scheduler_params.splits_ - 1);
|
||||
|
||||
uint32_t k_iter = scheduler_params.k_iter_per_tile_;
|
||||
if (is_split_k) {
|
||||
// Determine the number of iterations and starting iteration of this split.
|
||||
// Doing so requires accounting for residual iterations, which are handled
|
||||
// by the first big_units_ splits (with big_units_ = tiles % sm_count).
|
||||
|
||||
// Offsets for "normal" units. No additional k iterations are performed,
|
||||
// and big_units_ "big" units preceded us, each of which performed one
|
||||
// additional iteration. Thus, we must increase our split starting offset
|
||||
// by big_units_.
|
||||
int additional_k_iter = 0;
|
||||
int split_start_offset = scheduler_params.big_units_;
|
||||
|
||||
if (work_idx_k < scheduler_params.big_units_) {
|
||||
// Offsets for "big" units. One additional k iteration is performed,
|
||||
// and each split preceding us was a big unit, so we must increase
|
||||
// our split starting offset by our split ID (work_idx_k).
|
||||
additional_k_iter = 1;
|
||||
split_start_offset = work_idx_k;
|
||||
}
|
||||
|
||||
// Set up k iteration count and split starting iteration assuming the
|
||||
// iteration space is evenly split.
|
||||
k_iter /= scheduler_params.splits_;
|
||||
work_idx_k *= k_iter;
|
||||
|
||||
// Apply any fixup needed to handle residuals
|
||||
work_idx_k += split_start_offset;
|
||||
k_iter += additional_k_iter;
|
||||
}
|
||||
|
||||
return {
|
||||
work_idx_m,
|
||||
work_idx_n,
|
||||
static_cast<int32_t>(work_idx_k),
|
||||
static_cast<int32_t>(work_idx_l),
|
||||
true,
|
||||
scheduler_params.k_iter_per_tile_,
|
||||
k_iter,
|
||||
k_iter, // remaining iterations
|
||||
is_final_split
|
||||
};
|
||||
}
|
||||
|
||||
else {
|
||||
// This is a stream-K work unit
|
||||
WorkTileInfo work_tile_info;
|
||||
set_stream_k_work(linear_idx, work_tile_info, /*new_unit = */ true);
|
||||
set_stream_k_work(params, linear_idx, work_tile_info, /*new_unit = */ true);
|
||||
return work_tile_info;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns whether the current work_tile_info passed in should continue to be used. This
|
||||
// occurs only in the stream-K decomposition with stream-K work units, which encompass
|
||||
@ -446,13 +373,24 @@ public:
|
||||
CUTLASS_DEVICE
|
||||
bool
|
||||
continue_current_work(WorkTileInfo& work_tile_info) const {
|
||||
return continue_current_work_for_linear_idx(
|
||||
current_work_linear_idx_, work_tile_info, scheduler_params);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE static
|
||||
bool
|
||||
continue_current_work_for_linear_idx(
|
||||
uint64_t linear_idx,
|
||||
WorkTileInfo& work_tile_info,
|
||||
Params const& params) {
|
||||
|
||||
work_tile_info.k_tile_remaining -= work_tile_info.k_tile_count;
|
||||
|
||||
if (work_tile_info.k_tile_remaining == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
set_stream_k_work(current_work_linear_idx_, work_tile_info, /* new_unit = */ false);
|
||||
set_stream_k_work(params, linear_idx, work_tile_info, /* new_unit = */ false);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -495,6 +433,14 @@ public:
|
||||
/*truncate_by_problem_size=*/false);
|
||||
}
|
||||
|
||||
// Returns whether fixup is needed for `work_tile_info`.
|
||||
CUTLASS_HOST_DEVICE
|
||||
static bool
|
||||
requires_fixup(Params const& params, WorkTileInfo const& work_tile_info) {
|
||||
// Fixup is not needed for data-parallel tiles
|
||||
return work_tile_info.k_tile_count != params.k_tiles_per_output_tile_;
|
||||
}
|
||||
|
||||
// Performs the reduction across splits for a given output tile.
|
||||
template <class FrgTensorC>
|
||||
CUTLASS_DEVICE
|
||||
@ -505,13 +451,25 @@ public:
|
||||
FrgTensorC& accumulators,
|
||||
uint32_t num_barriers,
|
||||
uint32_t barrier_idx) {
|
||||
using BarrierManager = NamedBarrierManager<NumThreadsPerWarpGroup, 2>;
|
||||
return fixup_helper<FrgTensorC, BarrierManager>(
|
||||
params, work_tile_info, accumulators, num_barriers, barrier_idx);
|
||||
}
|
||||
|
||||
// Helper for performing the reduction across splits for a given output tile.
|
||||
template <class FrgTensorC, class BarrierManager>
|
||||
CUTLASS_DEVICE
|
||||
static void
|
||||
fixup_helper(
|
||||
Params const& params,
|
||||
WorkTileInfo const& work_tile_info,
|
||||
FrgTensorC& accumulators,
|
||||
uint32_t num_barriers,
|
||||
uint32_t barrier_idx) {
|
||||
|
||||
using ElementAccumulator = typename FrgTensorC::value_type;
|
||||
|
||||
using BarrierManager = NamedBarrierManager<NumThreadsPerWarpGroup, 2>;
|
||||
|
||||
if (work_tile_info.k_tile_count == params.k_iter_per_tile_) {
|
||||
// Fixup is not needed for data-parallel tiles
|
||||
if (!requires_fixup(params, work_tile_info)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -619,21 +577,23 @@ public:
|
||||
}
|
||||
}
|
||||
else {
|
||||
auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
|
||||
|
||||
uint64_t cta_per_grid_dim;
|
||||
uint64_t cluster_dim_idx;
|
||||
if (params.raster_order_ == RasterOrder::AlongN) {
|
||||
uint64_t block_idx_m = (work_tile_info.M_idx - blockIdx.x) / gridDim.x;
|
||||
uint64_t block_idx_m = (work_tile_info.M_idx - cta_m_in_cluster) / cute::size<0>(params.cluster_shape_);
|
||||
uint64_t block_idx_n = work_tile_info.N_idx;
|
||||
cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor *
|
||||
params.divmod_cluster_blk_major_.divisor * block_idx_m) + block_idx_n;
|
||||
cluster_dim_idx = blockIdx.x;
|
||||
cluster_dim_idx = cta_m_in_cluster;
|
||||
}
|
||||
else {
|
||||
uint64_t block_idx_m = work_tile_info.M_idx;
|
||||
uint64_t block_idx_n = (work_tile_info.N_idx - blockIdx.y) / gridDim.y;
|
||||
uint64_t block_idx_n = (work_tile_info.N_idx - cta_n_in_cluster) / cute::size<1>(params.cluster_shape_);
|
||||
cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor *
|
||||
params.divmod_cluster_blk_major_.divisor * block_idx_n) + block_idx_m;
|
||||
cluster_dim_idx = blockIdx.y;
|
||||
cluster_dim_idx = cta_n_in_cluster;
|
||||
}
|
||||
|
||||
uint64_t tile_in_batch = params.divmod_cluster_shape_minor_.divisor * cta_per_grid_dim;
|
||||
@ -715,7 +675,7 @@ private:
|
||||
|
||||
// Construct a layout for the indexed tensor. The main purpose of this new layout is to
|
||||
// override the k extent to support cases in which the split computes a number of iterations
|
||||
// not equal to total_tile_k_iter / splits. A common example of this is in stream-K is when a
|
||||
// not equal to total_k_tiles / splits. A common example of this is in stream-K is when a
|
||||
// unit computes the final 20 of the total 32 k iterations of the output tile. In this case,
|
||||
// set splits = 32 and the split index (K_idx) to 11. The zipped divide above results in each
|
||||
// of the splits computing only one k iteration.
|
||||
@ -728,12 +688,13 @@ private:
|
||||
// Returns the number of stream-K tiles that will be computed amongst `output_tiles` total
|
||||
// output tiles on a device with `ctas_per_wave` CTAs in each wave.
|
||||
static uint32_t
|
||||
get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave) {
|
||||
get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave, uint32_t k_tiles_per_output_tile) {
|
||||
uint32_t full_waves = static_cast<uint32_t>(output_tiles / ctas_per_wave);
|
||||
uint32_t total_waves = static_cast<uint32_t>((output_tiles + ctas_per_wave - 1) / ctas_per_wave);
|
||||
|
||||
if (full_waves == total_waves) {
|
||||
// No quantization. All tiles will be data-parallel tiles.
|
||||
if (full_waves == total_waves || k_tiles_per_output_tile == 1) {
|
||||
// All tiles will be data-parallel tiles if there is either no quantization
|
||||
// or if there is no work to be split.
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -811,9 +772,12 @@ private:
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
uint32_t k_tiles_per_output_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(TileShape{}) - 1) /
|
||||
cute::size<2>(TileShape{});
|
||||
|
||||
dim3 grid = get_grid_shape(problem_shape_mnkl, TileShape{}, cluster_shape, {0, sm_count}, args);
|
||||
uint64_t ctas_per_wave = grid.x * grid.y;
|
||||
uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave);
|
||||
uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, k_tiles_per_output_tile);
|
||||
|
||||
barrier_workspace_size = get_barrier_workspace_size(sk_tiles, mma_warp_groups);
|
||||
reduction_workspace_size = get_reduction_workspace_size<ElementAccumulator>(sk_tiles);
|
||||
@ -829,10 +793,10 @@ private:
|
||||
uint32_t blocks_l,
|
||||
ClusterShape cluster_shape,
|
||||
uint32_t splits,
|
||||
uint32_t k_iter_per_tile,
|
||||
uint32_t k_tiles_per_output_tile,
|
||||
void* reduction_workspace) {
|
||||
|
||||
uint32_t big_units = k_iter_per_tile % splits;
|
||||
uint32_t big_units = k_tiles_per_output_tile % splits;
|
||||
|
||||
return {
|
||||
underlying_params.divmod_cluster_shape_major_,
|
||||
@ -845,7 +809,7 @@ private:
|
||||
underlying_params.raster_order_,
|
||||
cluster_shape,
|
||||
splits,
|
||||
k_iter_per_tile,
|
||||
k_tiles_per_output_tile,
|
||||
big_units,
|
||||
reduction_workspace
|
||||
};
|
||||
@ -855,8 +819,12 @@ private:
|
||||
// is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining
|
||||
// iterations) is used to find the next tile in the current work unit.
|
||||
CUTLASS_DEVICE
|
||||
void
|
||||
set_stream_k_work(uint64_t linear_idx, WorkTileInfo& work_tile_info, bool new_unit) const {
|
||||
static void
|
||||
set_stream_k_work(
|
||||
Params const& params,
|
||||
uint64_t linear_idx,
|
||||
WorkTileInfo& work_tile_info,
|
||||
bool new_unit) {
|
||||
// In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K
|
||||
// threadblock individually. For the most part, the set of K iterations corresponding to stream-K
|
||||
// work was divided amongst stream-K threadblocks, and a threadblock determined which tile
|
||||
@ -872,15 +840,15 @@ private:
|
||||
//
|
||||
// To do so, we divide up the linearized stream-K units into clusters and share the same K
|
||||
// offsets for work within clusters.
|
||||
auto cluster_linear_work_idx = linear_idx / size(scheduler_params.cluster_shape_);
|
||||
auto cluster_linear_work_idx = linear_idx / size(params.cluster_shape_);
|
||||
|
||||
// Determine the starting k iteration computed by this stream-K work unit
|
||||
uint32_t unit_iter_start = scheduler_params.k_iter_per_sk_unit_ * cluster_linear_work_idx;
|
||||
uint32_t unit_iter_start = params.k_tiles_per_sk_unit_ * cluster_linear_work_idx;
|
||||
|
||||
// Adjust the starting position and number of k iterations for "big units," which
|
||||
// compute one extra iteration. These are the first big_units_ units in the
|
||||
// linearized ID space.
|
||||
bool is_big_unit = cluster_linear_work_idx < scheduler_params.big_units_;
|
||||
bool is_big_unit = cluster_linear_work_idx < params.big_units_;
|
||||
if (is_big_unit) {
|
||||
// Since the "big units" are the first units in the linearized ID space, each
|
||||
// of the units preceding this big unit computed one extra iteration. Thus,
|
||||
@ -889,16 +857,16 @@ private:
|
||||
unit_iter_start += cluster_linear_work_idx;
|
||||
} else {
|
||||
// Increment by one for each of the big clusters (since all big units precede this unit)
|
||||
unit_iter_start += scheduler_params.big_units_;
|
||||
unit_iter_start += params.big_units_;
|
||||
}
|
||||
|
||||
uint32_t unit_iters;
|
||||
if (new_unit) {
|
||||
unit_iters = scheduler_params.k_iter_per_sk_unit_;
|
||||
unit_iters = params.k_tiles_per_sk_unit_;
|
||||
|
||||
// Only adjust iteration count for big unit if we are initializing this
|
||||
// work unit. For existing work units, the extra iteration for big units
|
||||
// has already been accounted for in k_iter_reamaining
|
||||
// has already been accounted for in k_tiles_reamaining
|
||||
if (is_big_unit) {
|
||||
++unit_iters;
|
||||
}
|
||||
@ -917,22 +885,21 @@ private:
|
||||
// for them to be computed later, so as to reduce the likelihood of blocking
|
||||
// on other work.
|
||||
uint32_t unit_iter_end = unit_iter_start + unit_iters - 1;
|
||||
uint32_t true_tile_id = unit_iter_end / scheduler_params.k_iter_per_tile_;
|
||||
uint32_t true_tile_iter_start = true_tile_id * scheduler_params.k_iter_per_tile_;
|
||||
uint32_t true_tile_iter_end = true_tile_iter_start + scheduler_params.k_iter_per_tile_;
|
||||
uint32_t true_tile_id = unit_iter_end / params.k_tiles_per_output_tile_;
|
||||
uint32_t true_tile_iter_start = true_tile_id * params.k_tiles_per_output_tile_;
|
||||
uint32_t true_tile_iter_end = true_tile_iter_start + params.k_tiles_per_output_tile_;
|
||||
|
||||
// Bring the linearized tile ID back into the space of tiles, rather than clusters
|
||||
true_tile_id *= size(scheduler_params.cluster_shape_);
|
||||
true_tile_id *= size(params.cluster_shape_);
|
||||
|
||||
auto cluster_dim0 = cute::size<0>(scheduler_params.cluster_shape_);
|
||||
auto cluster_dim1 = cute::size<1>(scheduler_params.cluster_shape_);
|
||||
auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
|
||||
|
||||
// The final linearized tile ID is in units of the cluster dimension over which we rasterize.
|
||||
if (scheduler_params.raster_order_ == RasterOrder::AlongN) {
|
||||
true_tile_id += (blockIdx.y % cluster_dim1) * cluster_dim0;
|
||||
if (params.raster_order_ == RasterOrder::AlongN) {
|
||||
true_tile_id += cta_n_in_cluster * cute::size<0>(params.cluster_shape_);
|
||||
}
|
||||
else {
|
||||
true_tile_id += (blockIdx.x % cluster_dim0) * cluster_dim1;
|
||||
true_tile_id += cta_m_in_cluster * cute::size<1>(params.cluster_shape_);
|
||||
}
|
||||
|
||||
// The unit's starting k iteration in the current tile is either the starting
|
||||
@ -948,19 +915,18 @@ private:
|
||||
uint32_t tile_iters = tile_iter_end - tile_iter_start;
|
||||
|
||||
uint64_t work_idx_l, remainder;
|
||||
scheduler_params.divmod_batch_(work_idx_l, remainder, true_tile_id);
|
||||
params.divmod_batch_(work_idx_l, remainder, true_tile_id);
|
||||
|
||||
uint64_t cta_per_grid_dim, dontcare;
|
||||
scheduler_params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
|
||||
|
||||
params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
|
||||
|
||||
auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n(
|
||||
cta_per_grid_dim,
|
||||
scheduler_params.divmod_cluster_shape_major_,
|
||||
scheduler_params.divmod_cluster_shape_minor_,
|
||||
scheduler_params.divmod_cluster_blk_major_,
|
||||
scheduler_params.log_swizzle_size_,
|
||||
scheduler_params.raster_order_);
|
||||
params.divmod_cluster_shape_major_,
|
||||
params.divmod_cluster_shape_minor_,
|
||||
params.divmod_cluster_blk_major_,
|
||||
params.log_swizzle_size_,
|
||||
params.raster_order_);
|
||||
|
||||
//
|
||||
// Update the work_tile_info
|
||||
@ -971,11 +937,11 @@ private:
|
||||
work_tile_info.N_idx = work_idx_n;
|
||||
work_tile_info.L_idx = static_cast<int32_t>(work_idx_l);
|
||||
|
||||
// Set the k offset to be the starting k iteration for this tile
|
||||
// Set the k offset to be the starting k tile for this output tile
|
||||
work_tile_info.K_idx = static_cast<int32_t>(tile_iter_start - true_tile_iter_start);
|
||||
|
||||
// Set the split count to be the number of k iterations in the tile
|
||||
work_tile_info.splits = scheduler_params.k_iter_per_tile_;
|
||||
// Set the split count to be the number of k tiles in the output tile
|
||||
work_tile_info.splits = params.k_tiles_per_output_tile_;
|
||||
|
||||
// Any checks for invalid work units should be done prior to this call
|
||||
work_tile_info.is_valid_tile = true;
|
||||
@ -987,6 +953,89 @@ private:
|
||||
// the output tile in question
|
||||
work_tile_info.is_final_split = (tile_iter_end == true_tile_iter_end);
|
||||
}
|
||||
|
||||
// Returns a WorkTileInfo to be computed for either the data-parallel or split-K
|
||||
// work unit identified by the provided linear ID.
|
||||
CUTLASS_DEVICE
|
||||
static WorkTileInfo
|
||||
set_non_stream_k_work(uint64_t linear_idx, Params const& params, bool is_split_k) {
|
||||
|
||||
// The linearized ID space is in terms of work units, rather than tiles. However,
|
||||
// to compute the correct block offset for a data-parallel tile, we must convert
|
||||
// the current ID to the data-parallel tile it corresponds to. Each data-parallel
|
||||
// unit maps to a single data-parallel tile, but each stream-K unit can map to more
|
||||
// than one tile. Thus, we must offset the work-unit ID among the data-parallel units
|
||||
// by the total number of output tiles that will be computed by stream-K units.
|
||||
//
|
||||
// The logic below also works for the split-K case, in which sk_units_ and sk_tiles_
|
||||
// are each 0.
|
||||
uint64_t linear_work_idx = linear_idx - params.sk_units_ + params.sk_tiles_;
|
||||
|
||||
// Map worker's linear index into the CTA-tiled problem shape to the corresponding MNL indices
|
||||
uint64_t work_idx_l, remainder;
|
||||
params.divmod_batch_(work_idx_l, remainder, linear_work_idx);
|
||||
|
||||
uint64_t work_idx_k = 0;
|
||||
if (is_split_k) {
|
||||
params.divmod_k_(work_idx_k, remainder, remainder);
|
||||
}
|
||||
|
||||
uint64_t cta_per_grid_dim, dontcare;
|
||||
params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
|
||||
|
||||
auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n(
|
||||
cta_per_grid_dim,
|
||||
params.divmod_cluster_shape_major_,
|
||||
params.divmod_cluster_shape_minor_,
|
||||
params.divmod_cluster_blk_major_,
|
||||
params.log_swizzle_size_,
|
||||
params.raster_order_);
|
||||
|
||||
bool is_final_split = (work_idx_k == params.splits_ - 1);
|
||||
|
||||
uint32_t k_tiles = params.k_tiles_per_output_tile_;
|
||||
if (is_split_k) {
|
||||
// Determine the number of iterations and starting iteration of this split.
|
||||
// Doing so requires accounting for residual iterations, which are handled
|
||||
// by the first big_units_ splits (with big_units_ = tiles % sm_count).
|
||||
|
||||
// Offsets for "normal" units. No additional k iterations are performed,
|
||||
// and big_units_ "big" units preceded us, each of which performed one
|
||||
// additional iteration. Thus, we must increase our split starting offset
|
||||
// by big_units_.
|
||||
int additional_k_tiles = 0;
|
||||
int split_start_offset = params.big_units_;
|
||||
|
||||
if (work_idx_k < params.big_units_) {
|
||||
// Offsets for "big" units. One additional k iteration is performed,
|
||||
// and each split preceding us was a big unit, so we must increase
|
||||
// our split starting offset by our split ID (work_idx_k).
|
||||
additional_k_tiles = 1;
|
||||
split_start_offset = work_idx_k;
|
||||
}
|
||||
|
||||
// Set up k iteration count and split starting iteration assuming the
|
||||
// iteration space is evenly split.
|
||||
k_tiles /= params.splits_;
|
||||
work_idx_k *= k_tiles;
|
||||
|
||||
// Apply any fixup needed to handle residuals
|
||||
work_idx_k += split_start_offset;
|
||||
k_tiles += additional_k_tiles;
|
||||
}
|
||||
|
||||
return {
|
||||
work_idx_m,
|
||||
work_idx_n,
|
||||
static_cast<int32_t>(work_idx_k),
|
||||
static_cast<int32_t>(work_idx_l),
|
||||
true,
|
||||
params.k_tiles_per_output_tile_,
|
||||
k_tiles,
|
||||
k_tiles, // remaining iterations
|
||||
is_final_split
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm::kernel::detail
|
||||
|
||||
@ -215,10 +215,17 @@ class GemmArguments2x(ArgumentBase):
|
||||
else:
|
||||
self.batch_count = 1
|
||||
|
||||
if "batch_strides" in kwargs:
|
||||
self.batched_stride_A = kwargs["batch_strides"]["A"]
|
||||
self.batched_stride_B = kwargs["batch_strides"]["B"]
|
||||
self.batched_stride_C = kwargs["batch_strides"]["C"]
|
||||
self.batched_stride_D = kwargs["batch_strides"]["D"]
|
||||
else:
|
||||
self.batched_stride_A = self.problem_size.m() * self.problem_size.k()
|
||||
self.batched_stride_B = self.problem_size.n() * self.problem_size.k()
|
||||
self.batched_stride_C = self.problem_size.m() * self.problem_size.n()
|
||||
self.batched_stride_D = self.problem_size.m() * self.problem_size.n()
|
||||
|
||||
if self.bias:
|
||||
self.batched_stride_C = self.problem_size.n()
|
||||
|
||||
|
||||
@ -132,9 +132,9 @@ class KernelsForDataType:
|
||||
"""
|
||||
# Determine the leading dimension of the shape
|
||||
if layout == cutlass.LayoutType.ColumnMajor:
|
||||
ld = shape[0]
|
||||
ld = shape[-2]
|
||||
elif layout == cutlass.LayoutType.RowMajor:
|
||||
ld = shape[1]
|
||||
ld = shape[-1]
|
||||
elif layout == cutlass.LayoutType.TensorNHWC:
|
||||
ld = shape[-1]
|
||||
else:
|
||||
|
||||
@ -114,6 +114,8 @@
|
||||
args.sync()
|
||||
"""
|
||||
|
||||
from math import prod
|
||||
|
||||
import cutlass_bindings
|
||||
|
||||
import cutlass
|
||||
@ -442,6 +444,113 @@ class Gemm(OperationBase):
|
||||
compiler.add_module([self.operation,])
|
||||
return self.operation
|
||||
|
||||
def _verify_rank(self, tensor):
|
||||
"""
|
||||
Verifies that ``tensor`` has rank greater than 1
|
||||
|
||||
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
||||
:type tensor: numpy/cupy/torch array/tensor object
|
||||
"""
|
||||
if len(tensor.shape) < 2:
|
||||
raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}")
|
||||
|
||||
def _get_batch_count(self, A, B, C, D) -> int:
|
||||
"""
|
||||
Returns the batch count specified by the tensors A, B, C, and D and verifies that these
|
||||
tensors match in batch size. Presence of a batch dimension is detected by one of the
|
||||
tensors being rank 3. If a batch dimension is present, it must be present in one of
|
||||
operands A, B, or C (but need not be in all), and must be present in D.
|
||||
|
||||
:param A: tensor A
|
||||
:type A: numpy/cupy/torch array/tensor object
|
||||
:param B: tensor B
|
||||
:type B: numpy/cupy/torch array/tensor object
|
||||
:param C: tensor C
|
||||
:type C: numpy/cupy/torch array/tensor object
|
||||
:param D: tensor D
|
||||
:type D: numpy/cupy/torch array/tensor object
|
||||
|
||||
:return: tuple of batch count dimensions
|
||||
:rtype: tuple
|
||||
"""
|
||||
A_batch = A.shape[:-2] if len(A.shape) > 2 else tuple()
|
||||
B_batch = B.shape[:-2] if len(B.shape) > 2 else tuple()
|
||||
C_batch = C.shape[:-2] if len(C.shape) > 2 else tuple()
|
||||
D_batch = D.shape[:-2] if len(D.shape) > 2 else tuple()
|
||||
|
||||
if len(D_batch) > 0 and D_batch not in [A_batch, B_batch, C_batch]:
|
||||
raise Exception(f"Batch count in D must be present in one of operands A, B, and C. "
|
||||
f"Batch counts are: A={A_batch}, B={B_batch}, C={C_batch}, D={D_batch}")
|
||||
|
||||
for batch_shape in [A_batch, B_batch, C_batch]:
|
||||
if len(batch_shape) > 0 and batch_shape != D_batch:
|
||||
raise Exception(f"Batch count for all other operands must either match that of D or be zero."
|
||||
f"Received batch shape of {batch_shape}, which does not match that of D of {D_batch}.")
|
||||
|
||||
return D_batch
|
||||
|
||||
def _get_batch_stride(self, tensor) -> int:
|
||||
"""
|
||||
Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0.
|
||||
|
||||
:param tensor: tensor object to process
|
||||
:type tensor: numpy/cupy/torch array/tensor object
|
||||
|
||||
:return: stride between each matrix in the batch
|
||||
:rtype: int
|
||||
"""
|
||||
if len(tensor.shape) > 2:
|
||||
return tensor.shape[-2] * tensor.shape[-1]
|
||||
else:
|
||||
return 0
|
||||
|
||||
def _get_problem_args(self, A, B, C, D) -> tuple:
|
||||
"""
|
||||
Returns the problem size and GEMM universal mode to use for the
|
||||
given operands.
|
||||
|
||||
:param A: tensor A
|
||||
:type A: numpy/cupy/torch array/tensor object
|
||||
:param B: tensor B
|
||||
:type B: numpy/cupy/torch array/tensor object
|
||||
:param C: tensor C
|
||||
:type C: numpy/cupy/torch array/tensor object
|
||||
:param D: tensor D
|
||||
:type D: numpy/cupy/torch array/tensor object
|
||||
|
||||
:return: tuple containing the problem size (cutlass_bindings.gemm.GemmCoord), the GEMM mode (cutlass_bindings.gemm.Mode), and the batch count (int)
|
||||
:rtype: tuple
|
||||
"""
|
||||
M, K = A.shape[-2:]
|
||||
N = B.shape[-1]
|
||||
mode = cutlass_bindings.gemm.Mode.Gemm
|
||||
|
||||
batch_count = self._get_batch_count(A, B, C, D)
|
||||
returned_batch_count = prod(batch_count) if len(batch_count) > 0 else 1
|
||||
|
||||
# If we are running a batched GEMM in which there is a nonzero batch stride
|
||||
# only for A, then we can fold the batched dimension of A into the M dimension
|
||||
# (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A
|
||||
# and C are row major. A similar operation can be performed if only B has a nonzero
|
||||
# batch dimension
|
||||
if len(batch_count) > 0:
|
||||
A_row = self._layout_a == cutlass.LayoutType.RowMajor
|
||||
B_row = self._layout_b == cutlass.LayoutType.RowMajor
|
||||
C_row = self._layout_c == cutlass.LayoutType.RowMajor
|
||||
|
||||
batched = lambda x : len(x.shape) == 2 + len(batch_count)
|
||||
|
||||
if batched(A) and not batched(B) and batched(C) and A_row and C_row:
|
||||
M *= prod(batch_count)
|
||||
returned_batch_count = 1
|
||||
elif not batched(A) and batched(B) and batched(C) and not B_row and not C_row:
|
||||
N *= prod(batch_count)
|
||||
returned_batch_count = 1
|
||||
else:
|
||||
mode = cutlass_bindings.gemm.Mode.Batched
|
||||
|
||||
return cutlass_bindings.gemm.GemmCoord(M, N, K), mode, returned_batch_count
|
||||
|
||||
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
|
||||
"""
|
||||
Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
|
||||
@ -461,8 +570,7 @@ class Gemm(OperationBase):
|
||||
f'layout of ({ref_type}, {ref_layout}).')
|
||||
|
||||
def run(self, A=None, B=None, C=None, D=None,
|
||||
alpha=None, beta=None, batch_count: int = 1,
|
||||
sync: bool = True, print_module: bool = False) -> GemmArguments:
|
||||
alpha=None, beta=None, sync: bool = True, print_module: bool = False) -> GemmArguments:
|
||||
"""
|
||||
Runs the kernel currently specified. If it has not already been, the kernel is emitted and
|
||||
compiled. Tensors holding operands and outputs of the kernel are sourced either from the
|
||||
@ -481,8 +589,6 @@ class Gemm(OperationBase):
|
||||
:param D: tensor representing data type and layout of operand D
|
||||
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
||||
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
||||
:param batch_count: number of GEMMs in the batch
|
||||
:type batch_count: int
|
||||
:param sync: whether the call should wait for the kernel to complete before returning
|
||||
:type sync: bool
|
||||
:param print_module: whether to print the emitted C++ code
|
||||
@ -491,9 +597,6 @@ class Gemm(OperationBase):
|
||||
:return: arguments passed in to the kernel
|
||||
:rtype: cutlass.backend.GemmArguments
|
||||
"""
|
||||
if batch_count < 1:
|
||||
raise Exception(f"Invalid batch count {batch_count}. Value must be an integer >= 1.")
|
||||
|
||||
A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
|
||||
B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
|
||||
C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
|
||||
@ -501,20 +604,31 @@ class Gemm(OperationBase):
|
||||
alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
|
||||
beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
|
||||
|
||||
self._verify_rank(A)
|
||||
self._verify_rank(B)
|
||||
self._verify_rank(C)
|
||||
self._verify_rank(D)
|
||||
|
||||
alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a)
|
||||
alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b)
|
||||
alignment_c = self.possible_operations.find_alignment(C.shape, self._layout_c)
|
||||
self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
|
||||
alignment_C=alignment_c, print_module=print_module)
|
||||
|
||||
problem_size = cutlass_bindings.gemm.GemmCoord(A.shape[0], B.shape[1], A.shape[1])
|
||||
problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)
|
||||
|
||||
if batch_count == 1:
|
||||
mode = cutlass_bindings.gemm.Mode.Gemm
|
||||
if mode == cutlass_bindings.gemm.Mode.Gemm or batch_count == 1:
|
||||
kwargs = {'split_k_slices': 1}
|
||||
else:
|
||||
mode = cutlass_bindings.gemm.Mode.Batched
|
||||
kwargs = {'batch': batch_count}
|
||||
kwargs = {
|
||||
'batch': batch_count,
|
||||
'batch_strides': {
|
||||
'A': self._get_batch_stride(A),
|
||||
'B': self._get_batch_stride(B),
|
||||
'C': self._get_batch_stride(C),
|
||||
'D': self._get_batch_stride(D)
|
||||
}
|
||||
}
|
||||
|
||||
arguments = GemmArguments(
|
||||
operation=self.operation, problem_size=problem_size,
|
||||
|
||||
139
test/python/gemm/gemm_batched.py
Normal file
139
test/python/gemm/gemm_batched.py
Normal file
@ -0,0 +1,139 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
High-level tests for running batched GEMMs
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from math import prod
|
||||
|
||||
import cutlass
|
||||
import logging
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
from cutlass.backend.test.utils import LayoutCombination, add_test_gemm
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
|
||||
cutlass.set_log_level(logging.WARNING)
|
||||
|
||||
torch.manual_seed(2023)
|
||||
|
||||
|
||||
def pytorch_reference(A, B, C, alpha, beta):
|
||||
# Get the batch count. Assume that any of A, B, and C
|
||||
# with a batch dimension ahve matching batch count. Thus,
|
||||
# we break out of the loop once we have found the first
|
||||
# tensor containing a batch dimension.
|
||||
batch_count = (1,)
|
||||
for tensor in [A, B, C]:
|
||||
if len(tensor.shape) > 2:
|
||||
batch_count = tensor.shape[:-2]
|
||||
break
|
||||
|
||||
int_batch_count = prod(batch_count)
|
||||
|
||||
def add_batch(tensor):
|
||||
if len(tensor.shape) == 2:
|
||||
return tensor.unsqueeze(0).repeat(int_batch_count, 1, 1)
|
||||
else:
|
||||
return tensor.reshape(-1, tensor.size(-2), tensor.size(-1))
|
||||
|
||||
# Reshape tensors to have batch dimension
|
||||
A = add_batch(A)
|
||||
B = add_batch(B)
|
||||
C = add_batch(C)
|
||||
|
||||
ret = (torch.bmm(A, B) * alpha) + (C * beta)
|
||||
reshape_vals = batch_count + C.shape[-2:]
|
||||
return ret.reshape(*reshape_vals)
|
||||
|
||||
|
||||
def initialize(rows, cols, batch):
|
||||
tensor = torch.randint(-3, 3, size=(rows*cols*prod(batch),), device='cuda').half()
|
||||
if len(batch) > 0 and prod(batch) > 1:
|
||||
reshape_vals = batch + (rows, cols)
|
||||
return tensor.reshape(*reshape_vals)
|
||||
else:
|
||||
return tensor.reshape(rows, cols)
|
||||
|
||||
|
||||
class GemmF16Batched(unittest.TestCase):
|
||||
def run_batched(self, batch_count: tuple, batch_A: bool, batch_B: bool, batch_C: bool):
|
||||
M = 512
|
||||
N = 256
|
||||
K = 128
|
||||
alpha = 1.
|
||||
beta = 2.
|
||||
|
||||
A = initialize(M, K, batch_count if batch_A else (1,))
|
||||
B = initialize(K, N, batch_count if batch_B else (1,))
|
||||
C = initialize(M, N, batch_count if batch_C else (1,))
|
||||
D = initialize(M, N, batch_count)
|
||||
|
||||
plan = cutlass.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass.DataType.f32)
|
||||
plan.run(A, B, C, D, alpha, beta)
|
||||
reference = pytorch_reference(A, B, C, alpha, beta)
|
||||
assert reference.equal(D)
|
||||
|
||||
def test_batched_ABC(self):
|
||||
self.run_batched((3,), True, True, True)
|
||||
self.run_batched((2, 3), True, True, True)
|
||||
|
||||
def test_batched_AB(self):
|
||||
self.run_batched((3,), True, True, False)
|
||||
self.run_batched((2, 3), True, True, False)
|
||||
|
||||
def test_batched_AC(self):
|
||||
self.run_batched((3,), True, False, True)
|
||||
self.run_batched((2, 3), True, False, True)
|
||||
|
||||
def test_batched_BC(self):
|
||||
self.run_batched((3,), False, True, True)
|
||||
self.run_batched((2, 3), False, True, True)
|
||||
|
||||
def test_batched_A(self):
|
||||
self.run_batched((3,), True, False, False)
|
||||
self.run_batched((2, 3), True, False, False)
|
||||
|
||||
def test_batched_B(self):
|
||||
self.run_batched((3,), False, True, False)
|
||||
self.run_batched((2, 3), False, True, False)
|
||||
|
||||
def test_batched_C(self):
|
||||
self.run_batched((3,), False, False, True)
|
||||
self.run_batched((2, 3), False, False, True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@ -61,7 +61,7 @@ __global__ void convert(
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Destination, typename Source, int Count>
|
||||
void run_test() {
|
||||
void run_test(const char dest_name[], const char source_name[]) {
|
||||
const int kN = Count;
|
||||
|
||||
dim3 grid(1, 1);
|
||||
@ -84,7 +84,10 @@ void run_test() {
|
||||
destination.sync_host();
|
||||
|
||||
for (int i = 0; i < kN; ++i) {
|
||||
EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]));
|
||||
EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]))
|
||||
<< "Destination type: " << dest_name
|
||||
<< ", Source type: " << source_name
|
||||
<< ", Count: " << Count;
|
||||
}
|
||||
}
|
||||
|
||||
@ -97,15 +100,19 @@ void run_test() {
|
||||
TEST(NumericConversion, f32_to_f16_rn) {
|
||||
int const kN = 1;
|
||||
using Source = float;
|
||||
const char source_name[] = "float";
|
||||
using Destination = cutlass::half_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "half_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, f32x8_to_f16x8_rn) {
|
||||
int const kN = 8;
|
||||
using Source = float;
|
||||
const char source_name[] = "float";
|
||||
using Destination = cutlass::half_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "half_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -113,15 +120,19 @@ TEST(NumericConversion, f32x8_to_f16x8_rn) {
|
||||
TEST(NumericConversion, f16_to_f32_rn) {
|
||||
int const kN = 1;
|
||||
using Source = cutlass::half_t;
|
||||
const char source_name[] = "half_t";
|
||||
using Destination = float;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, f16x8_to_f32x8_rn) {
|
||||
int const kN = 8;
|
||||
using Source = cutlass::half_t;
|
||||
const char source_name[] = "half_t";
|
||||
using Destination = float;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -129,86 +140,109 @@ TEST(NumericConversion, f16x8_to_f32x8_rn) {
|
||||
TEST(NumericConversion, f32_to_fe4m3_rn) {
|
||||
int const kN = 1;
|
||||
using Source = float;
|
||||
const char source_name[] = "float";
|
||||
using Destination = cutlass::float_e4m3_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e4m3_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, f32_to_fe4m3_rn_array) {
|
||||
int const kN = 27;
|
||||
using Source = float;
|
||||
const char source_name[] = "float";
|
||||
using Destination = cutlass::float_e4m3_t;
|
||||
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e4m3_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, f32_to_fe5m2_rn) {
|
||||
int const kN = 1;
|
||||
using Source = float;
|
||||
const char source_name[] = "float";
|
||||
using Destination = cutlass::float_e5m2_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e5m2_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, f32_to_fe5m2_rn_array) {
|
||||
int const kN = 27;
|
||||
using Source = float;
|
||||
const char source_name[] = "float";
|
||||
using Destination = cutlass::float_e5m2_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e5m2_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, f16_to_fe4m3_rn) {
|
||||
int const kN = 1;
|
||||
using Source = cutlass::half_t;
|
||||
const char source_name[] = "half_t";
|
||||
using Destination = cutlass::float_e4m3_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e4m3_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, f16_to_fe4m3_rn_array) {
|
||||
int const kN = 27;
|
||||
using Source = cutlass::half_t;
|
||||
const char source_name[] = "half_t";
|
||||
using Destination = cutlass::float_e4m3_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e4m3_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, f16_to_fe5m2_rn) {
|
||||
int const kN = 1;
|
||||
using Source = cutlass::half_t;
|
||||
const char source_name[] = "half_t";
|
||||
using Destination = cutlass::float_e5m2_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e5m2_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, f16_to_fe5m2_rn_array) {
|
||||
int const kN = 27;
|
||||
using Source = cutlass::half_t;
|
||||
const char source_name[] = "half_t";
|
||||
using Destination = cutlass::float_e5m2_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e5m2_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, bf16_to_fe4m3_rn) {
|
||||
int const kN = 1;
|
||||
using Source = cutlass::bfloat16_t;
|
||||
const char source_name[] = "bfloat16_t";
|
||||
using Destination = cutlass::float_e4m3_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e4m3_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, bf16_to_fe4m3_rn_array) {
|
||||
int const kN = 27;
|
||||
using Source = cutlass::bfloat16_t;
|
||||
const char source_name[] = "bfloat16_t";
|
||||
using Destination = cutlass::float_e4m3_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e4m3_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, bf16_to_fe5m2_rn) {
|
||||
int const kN = 1;
|
||||
using Source = cutlass::bfloat16_t;
|
||||
const char source_name[] = "bfloat16_t";
|
||||
using Destination = cutlass::float_e5m2_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e5m2_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, bf16_to_fe5m2_rn_array) {
|
||||
int const kN = 27;
|
||||
using Source = cutlass::bfloat16_t;
|
||||
const char source_name[] = "bfloat16_t";
|
||||
using Destination = cutlass::float_e5m2_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e5m2_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -216,36 +250,46 @@ TEST(NumericConversion, bf16_to_fe5m2_rn_array) {
|
||||
TEST(NumericConversion, fe4m3_to_fe5m2_rn) {
|
||||
int const kN = 1;
|
||||
using Source = cutlass::float_e4m3_t;
|
||||
const char source_name[] = "float_e4m3_t";
|
||||
using Destination = cutlass::float_e5m2_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e5m2_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe4m3_to_fe5m2_array) {
|
||||
int const kN = 27;
|
||||
using Source = cutlass::float_e4m3_t;
|
||||
const char source_name[] = "float_e4m3_t";
|
||||
using Destination = cutlass::float_e5m2_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e5m2_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe5m2_to_fe4m3_rn) {
|
||||
int const kN = 1;
|
||||
using Source = cutlass::float_e5m2_t;
|
||||
const char source_name[] = "float_e5m2_t";
|
||||
using Destination = cutlass::float_e4m3_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e4m3_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe5m2_to_fe4m3_array) {
|
||||
int const kN = 27;
|
||||
using Source = cutlass::float_e5m2_t;
|
||||
const char source_name[] = "float_e5m2_t";
|
||||
using Destination = cutlass::float_e4m3_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float_e4m3_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe4m3_to_f32_rn) {
|
||||
int const kN = 1;
|
||||
using Source = cutlass::float_e4m3_t;
|
||||
const char source_name[] = "float_e4m3_t";
|
||||
using Destination = float;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -254,78 +298,100 @@ TEST(NumericConversion, f32x8_to_s8x8_rn) {
|
||||
|
||||
int const kN = 8;
|
||||
using Source = float;
|
||||
const char source_name[] = "float";
|
||||
using Destination = int8_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "int8_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe4m3_to_f32_array) {
|
||||
int const kN = 27;
|
||||
using Source = cutlass::float_e4m3_t;
|
||||
const char source_name[] = "float_e4m3_t";
|
||||
using Destination = float;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe5m2_to_f32_array) {
|
||||
int const kN = 27;
|
||||
using Source = cutlass::float_e5m2_t;
|
||||
const char source_name[] = "float_e5m2_t";
|
||||
using Destination = float;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "float";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe4m3_to_f16_rn) {
|
||||
int const kN = 1;
|
||||
using Source = cutlass::float_e4m3_t;
|
||||
const char source_name[] = "float_e4m3_t";
|
||||
using Destination = cutlass::half_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "half_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe4m3_to_f16_array) {
|
||||
int const kN = 27;
|
||||
using Source = cutlass::float_e4m3_t;
|
||||
const char source_name[] = "float_e4m3_t";
|
||||
using Destination = cutlass::half_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "half_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe5m2_to_f16_rn) {
|
||||
int const kN = 1;
|
||||
using Source = cutlass::float_e5m2_t;
|
||||
const char source_name[] = "float_e5m2_t";
|
||||
using Destination = cutlass::half_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "half_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe5m2_to_f16_array) {
|
||||
int const kN = 27;
|
||||
using Source = cutlass::float_e5m2_t;
|
||||
const char source_name[] = "float_e5m2_t";
|
||||
using Destination = cutlass::half_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "half_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe4m3_to_bf16_rn) {
|
||||
int const kN = 1;
|
||||
using Source = cutlass::float_e4m3_t;
|
||||
const char source_name[] = "float_e4m3_t";
|
||||
using Destination = cutlass::bfloat16_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "bfloat16_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe4m3_to_bf16_array) {
|
||||
int const kN = 27;
|
||||
using Source = cutlass::float_e4m3_t;
|
||||
const char source_name[] = "float_e4m3_t";
|
||||
using Destination = cutlass::bfloat16_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "bfloat16_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe5m2_to_bf16_rn) {
|
||||
int const kN = 1;
|
||||
using Source = cutlass::float_e5m2_t;
|
||||
const char source_name[] = "float_e5m2_t";
|
||||
using Destination = cutlass::bfloat16_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "bfloat16_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
TEST(NumericConversion, fe5m2_to_bf16_array) {
|
||||
int const kN = 27;
|
||||
using Source = cutlass::float_e5m2_t;
|
||||
const char source_name[] = "float_e5m2_t";
|
||||
using Destination = cutlass::bfloat16_t;
|
||||
test::core::kernel::run_test<Destination, Source, kN>();
|
||||
const char dest_name[] = "bfloat16_t";
|
||||
test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -36,6 +36,7 @@ cutlass_test_unit_add_executable(
|
||||
compare.cpp
|
||||
complement.cpp
|
||||
composition.cpp
|
||||
constant_arithmetic.cpp
|
||||
core_unit.cpp
|
||||
inverse_left.cpp
|
||||
inverse_right.cpp
|
||||
|
||||
106
test/unit/cute/core/constant_arithmetic.cpp
Normal file
106
test/unit/cute/core/constant_arithmetic.cpp
Normal file
@ -0,0 +1,106 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include "cutlass_unit_test.h"
|
||||
#include <cutlass/trace.h>
|
||||
#include <cute/swizzle.hpp>
|
||||
|
||||
TEST(CuTe_core, ConstantArithmetic) {
|
||||
using namespace cute;
|
||||
|
||||
constexpr cute::integral_constant<uint32_t, 0> uzero{};
|
||||
|
||||
// This extra test exists historically as part of the diagnosis
|
||||
// of a possible Clang 14 bug. However, it's a nice test for
|
||||
// cute::integral_constant's arithmetic operators, so it's saved here.
|
||||
// It also demonstrates how to work with cute::integral_constant
|
||||
// and lambda captures. Microsoft Visual Studio ("MSVC") tends to
|
||||
// disagree with other compilers about the meaning of decltype
|
||||
// for variables captured by reference. MSVC and GCC 8.3.0
|
||||
// also tend to disagree with other compilers (and other GCC versions)
|
||||
// about whether expressions involving such variables
|
||||
// are constant expressions.
|
||||
//
|
||||
// A typical CuTe idiom is to do lambda captures by reference [&].
|
||||
// This test changes them to capture by value, except for
|
||||
// the innermost lambda's capture of S1, which is by reference.
|
||||
// The point is to show that MSVC and GCC 8 have issues with this
|
||||
// that other compilers do not. For example,
|
||||
//
|
||||
// 1. MSVC needs remove_cvref_t around decltype(S1)
|
||||
// in order to access decltype(S1)::value, and
|
||||
// 2. MSVC and GCC 8.3.0 both report a build error with S1()
|
||||
// (that is, calling operator() on S1, which returns the
|
||||
// same thing as S1.value).
|
||||
//
|
||||
// The reason for (2) is that neither compiler thinks
|
||||
// that S1() is a constant expression.
|
||||
//
|
||||
// This leaves S1.value as the most concise portable expression
|
||||
// for the "value" member of a cute::integral_constant.
|
||||
for_each(make_integer_sequence<uint32_t, 8>{}, [uzero](auto S0) {
|
||||
for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0](auto F0) {
|
||||
for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0,F0](auto S1) {
|
||||
for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0,F0,&S1](auto F1) {
|
||||
static_assert((decltype(S0)::value & decltype(F0)::value) == decltype(S0 & F0)::value);
|
||||
|
||||
// Using S1.value means you don't have to use remove_cvref_t
|
||||
// with a captured-by-reference variable.
|
||||
static_assert((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) == decltype(S1 & F1)::value);
|
||||
static_assert((S1.value & decltype(F1)::value) == decltype(S1 & F1)::value);
|
||||
// S1() _should_ work, but does not with Visual Studio 2022,
|
||||
// which emits C2131 ("expression did not evaluate to a constant").
|
||||
// It also does not with GCC 8.3.0, which emits an error with messages
|
||||
// "non-constant condition for static assertion" and
|
||||
// "'this' is not a constant expression."
|
||||
//
|
||||
//static_assert((S1() & decltype(F1)::value) == decltype(S1 & F1)::value);
|
||||
static_assert(decltype((S0 & F0) != uzero)::value == ((decltype(S0)::value & decltype(F0)::value) != 0));
|
||||
|
||||
static_assert(decltype((S1 & F1) != uzero)::value == ((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) != 0));
|
||||
static_assert(decltype((S1 & F1) != uzero)::value == ((S1.value & decltype(F1)::value) != 0));
|
||||
|
||||
constexpr bool left = decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value;
|
||||
constexpr bool right =
|
||||
((decltype(S0)::value & decltype(F0)::value) != 0) ||
|
||||
((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) != 0);
|
||||
constexpr bool right2 =
|
||||
((decltype(S0)::value & decltype(F0)::value) != 0) ||
|
||||
((S1.value & decltype(F1)::value) != 0);
|
||||
static_assert(right == right2);
|
||||
static_assert(left == right);
|
||||
constexpr bool left2 = decltype((S0 & F0) != uzero)::value || decltype((S1 & F1) != uzero)::value;
|
||||
static_assert(left == left2);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -31,23 +31,41 @@
|
||||
|
||||
#include "cutlass_unit_test.h"
|
||||
|
||||
// C<uint32_t(something)>::value_type is not uint32_t for GCC 7.5.0.
|
||||
// This test is thus disabled for GCC < 8.
|
||||
#if defined(__GNUC__) && (__GNUC__ < 8)
|
||||
|
||||
#include <cutlass/trace.h>
|
||||
#include <cute/swizzle.hpp>
|
||||
|
||||
TEST(CuTe_core, MixedBits) {
|
||||
using namespace cute;
|
||||
namespace { // (anonymous)
|
||||
|
||||
auto uzero = cute::integral_constant<uint32_t, 0>{};
|
||||
// This function exists to work around a Clang 14 issue, in which
|
||||
// the compiler tries to instantiate code that lives inside the
|
||||
// "else" branch of an "if constexpr," even when the "else" branch
|
||||
// is false. That triggers a spurious static_assert in MixedBits.
|
||||
// The work-around is to make the body of the "else" branch a
|
||||
// function, rather than leaving it in line.
|
||||
//
|
||||
// Some compilers strangely deduce the first two terms of
|
||||
// make_integer_sequence<uint32_t, 8> as C<false> and C<true>, and
|
||||
// the remaining terms as C<2>, C<3>, etc. Making this function take
|
||||
// cute::integral_constant<uint32_t, S0_value>, etc. doesn't work
|
||||
// with those compilers.
|
||||
template<class S0_type, S0_type S0_value,
|
||||
class F0_type, F0_type F0_value,
|
||||
class S1_type, S1_type S1_value,
|
||||
class F1_type, F1_type F1_value>
|
||||
void clang14_workaround(cute::integral_constant<S0_type, S0_value>,
|
||||
cute::integral_constant<F0_type, F0_value>,
|
||||
cute::integral_constant<S1_type, S1_value>,
|
||||
cute::integral_constant<F1_type, F1_value>)
|
||||
{
|
||||
constexpr cute::C<static_cast<uint32_t>(S0_value)> S0{};
|
||||
constexpr cute::C<static_cast<uint32_t>(F0_value)> F0{};
|
||||
constexpr cute::C<static_cast<uint32_t>(S1_value)> S1{};
|
||||
constexpr cute::C<static_cast<uint32_t>(F1_value)> F1{};
|
||||
|
||||
for_each(make_integer_sequence<uint32_t, 8>{}, [&](auto S0) {
|
||||
for_each(make_integer_sequence<uint32_t, 8>{}, [&](auto F0) {
|
||||
for_each(make_integer_sequence<uint32_t, 8>{}, [&](auto S1) {
|
||||
for_each(make_integer_sequence<uint32_t, 8>{}, [&](auto F1) {
|
||||
if constexpr (decltype(S0 == uzero || S1 == uzero)::value) {
|
||||
return;
|
||||
} else if constexpr (decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value) {
|
||||
return;
|
||||
} else {
|
||||
for (uint32_t d0 = 0; d0 < 8; ++d0) {
|
||||
if ((d0 & F0) != d0) { continue; } // Skip repeats
|
||||
for (uint32_t d1 = 0; d1 < 8; ++d1) {
|
||||
@ -63,8 +81,37 @@ TEST(CuTe_core, MixedBits) {
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace (anonymous)
|
||||
|
||||
TEST(CuTe_core, MixedBits) {
|
||||
using namespace cute;
|
||||
|
||||
auto uzero = cute::integral_constant<uint32_t, 0>{};
|
||||
|
||||
for_each(make_integer_sequence<uint32_t, 8>{}, [&](auto S0) {
|
||||
for_each(make_integer_sequence<uint32_t, 8>{}, [&](auto F0) {
|
||||
for_each(make_integer_sequence<uint32_t, 8>{}, [&](auto S1) {
|
||||
for_each(make_integer_sequence<uint32_t, 8>{}, [&](auto F1) {
|
||||
if constexpr (decltype(S0 == uzero || S1 == uzero)::value) {
|
||||
return;
|
||||
} else if constexpr (decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value) {
|
||||
return;
|
||||
} else {
|
||||
clang14_workaround(S0, F0, S1, F1);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
TEST(CuTe_core, MakeIntegerSequence) {
|
||||
cute::for_each(cute::make_integer_sequence<uint32_t, 8>{}, [](auto c) {
|
||||
using c_type = decltype(c);
|
||||
constexpr auto c_value = c_type::value;
|
||||
using expected_type = cute::integral_constant<uint32_t, c_value>;
|
||||
static_assert(cute::is_same_v<c_type, expected_type>);
|
||||
});
|
||||
}
|
||||
|
||||
#endif // defined(__GNUC__) && (__GNUC__ < 8)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
||||
@ -32,6 +32,7 @@
|
||||
\brief Tests that the stream-K scheduler covers the entire problem space.
|
||||
*/
|
||||
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp"
|
||||
#include "cutlass/util/device_memory.h"
|
||||
@ -39,6 +40,10 @@
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
// Grids are launched with clusters enabled in these tests,
|
||||
// so the CTK version must support cluster launching.
|
||||
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
|
||||
|
||||
using namespace cute;
|
||||
using ProblemShape_MNKL = Shape<int, int, int, int>;
|
||||
|
||||
@ -60,7 +65,7 @@ run_scheduler(int* visit_counters, typename Scheduler::Params params, TileShape
|
||||
while (work_tile_info.is_valid_tile) {
|
||||
// Increment counters to indicate coverage
|
||||
auto tile_idx = Scheduler::output_tile_index(params, work_tile_info);
|
||||
auto offset = tile_idx * params.k_iter_per_tile_ + work_tile_info.K_idx;
|
||||
auto offset = tile_idx * params.k_tiles_per_output_tile_ + work_tile_info.K_idx;
|
||||
for (auto i = 0; i < work_tile_info.k_tile_count; ++i) {
|
||||
// Use atomicAdd because the visit counters are shared by multiple thread blocks.
|
||||
// While having more than one block increment the same counter indicates failure,
|
||||
@ -103,7 +108,7 @@ test_scheduler(
|
||||
|
||||
// Allocate counters indicating the number of times each k iteration of each output tile has been visited
|
||||
auto [blk_m, blk_n, blk_l] = Scheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
|
||||
auto total_counters = blk_m * blk_n * blk_l * params.k_iter_per_tile_;
|
||||
auto total_counters = blk_m * blk_n * blk_l * params.k_tiles_per_output_tile_;
|
||||
cutlass::DeviceAllocation<int> visit_counters(total_counters);
|
||||
|
||||
// Initialize counters to zero
|
||||
@ -118,12 +123,55 @@ test_scheduler(
|
||||
// Set up the grid for the problem
|
||||
dim3 grid = Scheduler::get_grid_shape(problem_shape_mnkl, tile_shape, cluster_shape, hw_info, args);
|
||||
|
||||
// Set up cluster and cluster launch. This is needed even for this simple kernel because
|
||||
// the SM90 scheduler needs to be able to query the CTA id within a cluster, which requires
|
||||
// explicitly launching with clusters.
|
||||
dim3 cluster{
|
||||
static_cast<uint32_t>(cute::get<0>(ClusterShape{})),
|
||||
static_cast<uint32_t>(cute::get<1>(ClusterShape{})),
|
||||
static_cast<uint32_t>(cute::get<2>(ClusterShape{}))
|
||||
};
|
||||
|
||||
cudaLaunchConfig_t launch_config;
|
||||
launch_config.gridDim = grid;
|
||||
launch_config.blockDim = {1, 1, 1};
|
||||
launch_config.dynamicSmemBytes = 0;
|
||||
launch_config.stream = NULL;
|
||||
|
||||
cudaLaunchAttribute launch_attribute[1];
|
||||
launch_attribute[0].id = cudaLaunchAttributeClusterDimension;
|
||||
launch_attribute[0].val.clusterDim.x = cluster.x;
|
||||
launch_attribute[0].val.clusterDim.y = cluster.y;
|
||||
launch_attribute[0].val.clusterDim.z = cluster.z;
|
||||
|
||||
launch_config.attrs = launch_attribute;
|
||||
launch_config.numAttrs = 1;
|
||||
|
||||
void const* kernel = (void const*) run_scheduler<Scheduler, TileShape, ClusterShape>;
|
||||
int* counters_ptr = visit_counters.get();
|
||||
void* kernel_params[] = {
|
||||
&counters_ptr,
|
||||
¶ms,
|
||||
&tile_shape,
|
||||
&cluster_shape,
|
||||
&problem_shape_mnkl
|
||||
};
|
||||
|
||||
// Run the scheduler to completion and log visits to each k iteration
|
||||
run_scheduler<Scheduler, TileShape, ClusterShape><<<grid, 1>>>(
|
||||
visit_counters.get(), params, tile_shape, cluster_shape, problem_shape_mnkl);
|
||||
err = cudaLaunchKernelExC(&launch_config, kernel, kernel_params);
|
||||
|
||||
if (err != cudaSuccess) {
|
||||
std::cerr << __FILE__ << ":" << __LINE__
|
||||
<< " cudaLaunchKernelExC failed with error: "
|
||||
<< cudaGetErrorString(err) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
std::cerr << __FILE__ << ":" << __LINE__ << " scheduler kernel failed with error: " << cudaGetErrorString(err) << std::endl;
|
||||
std::cerr << __FILE__ << ":" << __LINE__
|
||||
<< " scheduler kernel failed with error: "
|
||||
<< cudaGetErrorString(err) << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -143,11 +191,11 @@ test_scheduler(
|
||||
<< " and grid size " << grid.x << "x"
|
||||
<< grid.y << "x" << grid.z
|
||||
<< " splits=" << params.splits_
|
||||
<< " k_iter=" << params.k_iter_per_tile_
|
||||
<< " k_iter=" << params.k_tiles_per_output_tile_
|
||||
<< " big_units=" << params.big_units_
|
||||
<< " sk_tiles=" << params.sk_tiles_
|
||||
<< " sk_units=" << params.sk_units_
|
||||
<< " k_iter_per_sk_unit=" << params.k_iter_per_sk_unit_ << std::endl;
|
||||
<< " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_ << std::endl;
|
||||
std::cout << "Error at idx: " << i << ". Got count " << host_visit_counts[i] << std::endl;
|
||||
return false;
|
||||
}
|
||||
@ -274,4 +322,6 @@ TEST(SM90_Device_Gemm_stream_k_scheduler, 128x128x64_2x1x1) {
|
||||
EXPECT_TRUE(test_scheduler({128, 512, 2048, 1}, tile_shape, cluster_shape, 114));
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -179,7 +179,7 @@ class GemmOperation:
|
||||
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
||||
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
||||
if self.arch >= 90:
|
||||
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}{t}"
|
||||
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{t}{k}{e}"
|
||||
return kernel_name_template.format(
|
||||
p = self.prefix,
|
||||
ar = self.arch,
|
||||
@ -194,9 +194,9 @@ class GemmOperation:
|
||||
l = self.tile_description.stages,
|
||||
s = self.layout_name_3x(),
|
||||
al = str(max(self.A.alignment, self.B.alignment)),
|
||||
t = TileSchedulerSuffixes[self.tile_scheduler],
|
||||
k = self.kernel_schedule_name_3x(),
|
||||
e = self.epilogue_schedule_name_3x(),
|
||||
t = TileSchedulerSuffixes[self.tile_scheduler])
|
||||
e = self.epilogue_schedule_name_3x())
|
||||
else:
|
||||
threadblock = self.tile_description.procedural_name()
|
||||
return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format(
|
||||
@ -661,8 +661,7 @@ using ${operation_name}_mainloop =
|
||||
${element_accumulator},
|
||||
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
|
||||
cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
sizeof(typename ${operation_name}_epilogue::SharedStorage)>,
|
||||
${stages},
|
||||
${kernel_schedule}
|
||||
>::CollectiveOp;
|
||||
|
||||
@ -697,7 +696,7 @@ ${compile_guard_end}
|
||||
if operation.tile_description.stages > 0:
|
||||
stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>"
|
||||
else:
|
||||
stage_count_string = "cutlass::gemm::collective::StageCountAuto"
|
||||
stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage)>"
|
||||
warp_shape = [tile_shape[idx] // warp_count[idx] for idx in range(3)]
|
||||
|
||||
instance_layout_A, instance_layout_B, instance_layout_C , instance_layout_D = \
|
||||
|
||||
@ -4218,6 +4218,8 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
|
||||
layout[2][1] = 8
|
||||
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, schedules)
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK])
|
||||
|
||||
# persistent kernels with TMA epilogues
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user