Updates for 3.2 release (#1065)

parent 27de343535
commit a88c41cf8d
@@ -2,10 +2,14 @@
 
 ## 2023
 
-- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, and Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
+- ["FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"](https://arxiv.org/abs/2307.08691). Tri Dao. _Technical Report_, July 2023.
 
 - ["ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs"](https://arxiv.org/abs/2210.03052). Yujia Zhai, Chengquan Jiang, Leyuan Wang, Xiaoying Jia, Shang Zhang, Zizhong Chen, Xin Liu, Yibo Zhu. _Proceedings of the 37th IEEE International Parallel & Distributed Processing Symposium (Best Paper)_, May 2023.
 
+- ["A Framework for Fine-Grained Synchronization of Dependent GPU Kernels"](https://arxiv.org/abs/2305.13450). Abhinav Jangda, Saeed Maleki, Maryam Mehri Dehnavi, Madan Musuvathi, Olli Saarikivi. _Computing Research Repository_, May 2023.
+
+- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023.
+
 - ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023.
 
 ## 2022
@@ -71,7 +71,7 @@ struct Options {
   /// Prints the usage statement.
   std::ostream & print_usage(std::ostream &out) const {
 
-    out << "52_fp8_hopper_warp_specialized_gemm\n\n"
+    out << "54_fp8_hopper_warp_specialized_gemm\n\n"
       << " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n"
       << "Options:\n\n"
      << " --help If specified, displays this usage statement\n\n"
@@ -93,7 +93,7 @@ struct Options {
 
     out
       << "\n\nExamples:\n\n"
-      << "$ " << "52_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
+      << "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
 
     return out;
   }
@@ -54,6 +54,13 @@ struct SyncthreadsSync {
   }
 };
 
+struct SyncwarpSync {
+  CUTLASS_DEVICE
+  static void sync() {
+    __syncwarp();
+  }
+};
+
 template <
   int ThreadCount,
   int BarrierId
@@ -311,6 +318,60 @@ private:
   }
 };
 
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/** Structure for synchronizing via contiguous barriers (e.g., __syncwarp, __syncthreads)
+  * via an API that mirrors that of NamedBarrierManager
+  *
+  * @param Synchronizer Synchronization helper exposing a `sync()` method to perform synchronization
+  **/
+template <
+  class Synchronizer,
+  uint32_t ThreadCount_
+>
+struct SyncManager {
+
+  // Number of threads participating in the barrier
+  static constexpr uint32_t ThreadCount = ThreadCount_;
+
+  using BarrierSync = cutlass::GenericBarrier<Synchronizer>;
+
+  // Underlying type used by all barriers for synchronization.
+  using T = typename BarrierSync::T;
+
+  CUTLASS_DEVICE
+  static
+  void wait_lt(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int count) {
+    BarrierSync::wait_lt_helper(lock_ptr, thread_idx, flag_idx, count);
+  }
+
+  CUTLASS_DEVICE
+  static void
+  wait_eq(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
+    BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val);
+  }
+
+  CUTLASS_DEVICE
+  static void
+  wait_eq_reset(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) {
+    BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val);
+  }
+
+  CUTLASS_DEVICE
+  static void
+  arrive_inc(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) {
+    BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val);
+  }
+
+  CUTLASS_DEVICE
+  static void
+  arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) {
+    BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val);
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
 } // namespace cutlass
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
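The `SyncwarpSync` and `SyncManager` additions above follow a small policy pattern: the barrier wrapper only requires its `Synchronizer` parameter to expose a static `sync()`. Below is a minimal, host-only sketch of that pattern for illustration; the names (`NoOpSyncPolicy`, `BarrierSketch`) are placeholders and the device intrinsics are stubbed out, so this is not the CUTLASS implementation itself.

```cpp
#include <cstdio>

// Illustrative stand-in for a synchronizer policy such as SyncwarpSync:
// the only contract is a static sync() (on device this would be __syncwarp()).
struct NoOpSyncPolicy {
  static void sync() { /* __syncwarp() / __syncthreads() on device */ }
};

// Generic wrapper parameterized on the policy, echoing how SyncManager
// records a thread count and forwards to a barrier built on the policy.
template <class Synchronizer, unsigned ThreadCount_>
struct BarrierSketch {
  static constexpr unsigned ThreadCount = ThreadCount_;

  static void arrive_and_wait() {
    // A real barrier would also update flags in shared or global memory here.
    Synchronizer::sync();
  }
};

int main() {
  BarrierSketch<NoOpSyncPolicy, 32>::arrive_and_wait();
  std::printf("ThreadCount = %u\n", BarrierSketch<NoOpSyncPolicy, 32>::ThreadCount);
  return 0;
}
```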
@@ -67,7 +67,7 @@ sm90_get_tma_dispatch_policy() {
 
   constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{}));
   constexpr int FragmentSize = size(EpilogueTileMN{}) / (detail::sm90_is_cooperative_v<Schedule> ? 256 : 128);
-  constexpr int ReuseSmemC = sizeof_bits_v<ElementC> == sizeof_bits_v<ElementD>;
+  constexpr int ReuseSmemC = (sizeof_bits_v<ElementC> == sizeof_bits_v<ElementD>) && (sizeof_bits_v<ElementD> > 8);
   constexpr int StagesD = 2;
   constexpr int StagesC = ReuseSmemC ? cute::max(EpiTiles, StagesD + 1) : EpiTiles;
 
@@ -98,7 +98,7 @@ sm90_get_epilogue_smem_swizzle_layout_atom() {
 }
 
 // Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one.
-template <class Element, class EpilogueTileType, class Schedule>
+template <class ElementD, class EpilogueTileType, class Schedule>
 constexpr auto
 sm90_compute_tile_shape_or_override() {
   if constexpr (cute::is_same_v<EpilogueTileType, EpilogueTileAuto>) {
@@ -107,7 +107,12 @@ sm90_compute_tile_shape_or_override() {
       return Shape<_128,_32>{};
     }
     else if constexpr (detail::sm90_is_warp_specialized_v<Schedule>) {
-      return Shape<_64,_32>{};
+      if constexpr (sizeof_bits_v<ElementD> == 8) {
+        return Shape<_64,_64>{};
+      }
+      else {
+        return Shape<_64,_32>{};
+      }
     }
     else {
       static_assert(cutlass::detail::dependent_false<Schedule>, "Unsupported schedule.");
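The warp-specialized branch above now picks a wider epilogue tile when the output element `ElementD` is 8 bits wide. A self-contained sketch of that compile-time selection is shown below; `Tile` and `bits_of` are stand-ins for `cute::Shape` and `sizeof_bits_v`, not CUTLASS types.

```cpp
#include <cstdio>
#include <cstdint>

struct Tile { int m, n; };  // stand-in for cute::Shape<_M,_N>

template <class T>
constexpr int bits_of = int(sizeof(T)) * 8;  // stand-in for sizeof_bits_v<T>

// Mirror of the added branch: 8-bit outputs get a 64x64 epilogue tile,
// wider outputs keep the previous 64x32 tile.
template <class ElementD>
constexpr Tile pick_warp_specialized_epilogue_tile() {
  if constexpr (bits_of<ElementD> == 8) {
    return {64, 64};
  }
  else {
    return {64, 32};
  }
}

int main() {
  constexpr Tile t8  = pick_warp_specialized_epilogue_tile<int8_t>();   // 8-bit stand-in for an FP8 type
  constexpr Tile t16 = pick_warp_specialized_epilogue_tile<uint16_t>(); // 16-bit stand-in
  std::printf("8-bit: %dx%d, 16-bit: %dx%d\n", t8.m, t8.n, t16.m, t16.n);
  return 0;
}
```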
@@ -34,6 +34,7 @@
 #include "cutlass/kernel_hardware_info.hpp"
 #include "cute/layout.hpp"
 #include "cute/tensor.hpp"
+#include "cute/arch/cluster_sm90.hpp"
 
 namespace cutlass::gemm::kernel::detail {
 
@@ -205,18 +206,14 @@ public:
 
     uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0;
     divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim);
-    // MSVC requires protecting use of CUDA-specific nonstandard syntax,
-    // like blockIdx and gridDim, with __CUDA_ARCH__.
-#if defined(__CUDA_ARCH__)
+    auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
+
     if (raster_order == RasterOrder::AlongN) {
-      cluster_minor_offset = blockIdx.x;
+      cluster_minor_offset = cta_m_in_cluster;
     }
     else {
-      cluster_minor_offset = blockIdx.y;
+      cluster_minor_offset = cta_n_in_cluster;
     }
-#else
-    CUTLASS_ASSERT(false && "This line should never be reached");
-#endif
 
     uint64_t cluster_idx_minor, cluster_idx_major;
 
@@ -141,7 +141,7 @@ public:
     uint32_t splits_ = 1;
 
     // Number of tiled k iterations required to compute a single output tile.
-    uint32_t k_iter_per_tile_ = 0;
+    uint32_t k_tiles_per_output_tile_ = 0;
 
     // Number of stream-K or split-K work units that compute an extra k iteration.
     // This is done to handle residuals in dividing up the k iteration space.
@@ -160,7 +160,7 @@ public:
 
     // Number of tiled k iterations computed by each stream-K work unit. This
     // can potentially cover more than one output tile.
-    uint32_t k_iter_per_sk_unit_ = 0;
+    uint32_t k_tiles_per_sk_unit_ = 0;
   };
 
   // Sink scheduler params as a member
@@ -189,9 +189,9 @@ public:
 
     uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l;
 
-    // Number of k iterations each tile computes (this is just the number of k iterations
-    // in the problem's K dimension)
-    uint32_t k_iter_per_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(tile_shape) - 1) / cute::size<2>(tile_shape);
+    // Number of k tile iterations in each output tile
+    uint32_t k_tiles_per_output_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(tile_shape) - 1) /
+        cute::size<2>(tile_shape);
 
     UnderlyingArguments underlying_args;
     underlying_args.max_swizzle_size = 1;
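The renamed `k_tiles_per_output_tile` is just a ceiling division of the problem's K extent by the tile's K extent. A tiny sketch of that arithmetic with made-up numbers:

```cpp
#include <cstdio>
#include <cstdint>

// Ceiling division, as used above to turn the problem's K extent into a K-tile count.
constexpr uint32_t k_tiles_per_output_tile(uint32_t problem_k, uint32_t tile_k) {
  return (problem_k + tile_k - 1) / tile_k;
}

int main() {
  // K = 4096 with a 64-deep K tile -> 64 K tiles; K = 4100 -> 65 (one residual tile).
  std::printf("%u %u\n", k_tiles_per_output_tile(4096, 64), k_tiles_per_output_tile(4100, 64));
  return 0;
}
```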
@@ -216,11 +216,11 @@ public:
     // splits is almost certainly nonnegative here (e.g., hw_info.sm_count,
     // despite being an int, is a count), so it can safely be converted to unsigned
     // in the comparison to avoid a signed-unsigned comparison warning-as-error.
-    splits = static_cast<decltype(k_iter_per_tile)>(splits) > k_iter_per_tile ? k_iter_per_tile : splits;
+    splits = static_cast<decltype(k_tiles_per_output_tile)>(splits) > k_tiles_per_output_tile ? k_tiles_per_output_tile : splits;
 
     return get_params_basic(
       underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape,
-      splits, k_iter_per_tile, reduction_workspace);
+      splits, k_tiles_per_output_tile, reduction_workspace);
   }
 
   // Calculate the maximum number of blocks from clusters of shape cluster_shape that we
@@ -229,7 +229,7 @@ public:
     uint64_t ctas_per_wave = grid.x * grid.y;
 
     // The number of output tiles to be computed in stream-K and data-parallel fashion, respectively.
-    uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave);
+    uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, k_tiles_per_output_tile);
     uint64_t dp_tiles = output_tiles - sk_tiles;
 
     // Calculate the number of work units covering the data-parallel and stream-K tiles.
@@ -243,7 +243,7 @@ public:
     uint64_t dp_units = dp_tiles;
 
     // Number of k iterations computed by the stream-K units as a whole
-    uint64_t k_iter_sk_total = k_iter_per_tile * sk_tiles;
+    uint64_t k_tiles_sk_total = k_tiles_per_output_tile * sk_tiles;
 
     // If there are stream-K tiles to compute and a sufficiently large number of k iterations
     // across them, they will be covered by a single wave of persistent threadblocks. Thus, there
@@ -255,7 +255,7 @@ public:
 
     // Calculate the number of stream-K units that would be needed if each stream-K unit
     // computed the minimum allowable k iterations. Truncate this to be in units of clusters.
-    uint64_t min_sized_sk_units = (k_iter_sk_total / min_iters_per_sk_unit_);
+    uint64_t min_sized_sk_units = (k_tiles_sk_total / min_iters_per_sk_unit_);
     min_sized_sk_units = (min_sized_sk_units / cute::size(cluster_shape)) * cute::size(cluster_shape);
 
     uint64_t sk_units = min(ctas_per_wave, min_sized_sk_units);
@@ -264,7 +264,7 @@ public:
       // Short circuit to basic data-parallel decomposition
       return get_params_basic(
         underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape,
-        1, k_iter_per_tile, reduction_workspace);
+        1, k_tiles_per_output_tile, reduction_workspace);
     }
 
     // If the number of stream-K units is a multiple of the number of stream-K tiles, then
@@ -274,24 +274,24 @@ public:
       uint32_t sk_splits = static_cast<uint32_t>(sk_units / sk_tiles);
       return get_params_basic(
         underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape,
-        sk_splits, k_iter_per_tile, reduction_workspace);
+        sk_splits, k_tiles_per_output_tile, reduction_workspace);
     }
 
     // Number of k iterations computed per stream-K units
-    uint64_t k_iter_per_sk_unit = k_iter_sk_total / sk_units;
+    uint64_t k_tiles_per_sk_unit = k_tiles_sk_total / sk_units;
 
     // Number of stream-K units that need to compute extra iterations in order to cover
     // the residual k iterations. This assumes that each such unit computes one additional
     // iteration.
-    uint64_t sk_big_units = k_iter_sk_total - (k_iter_per_sk_unit * sk_units);
+    uint64_t sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units);
 
     // The division below is guaranteed to be exact because sk_big_units is guaranteed
     // to be a multiple of cluster_size (cute::size(cluster_shape)). This is useful because
     // it allows us to use a block's linearized cluster ID to determine whether it is
     // a big block. The reasoning behind this guarnatee is explained as follows:
-    // sk_big_units = k_iter_sk_total - (k_iter_per_sk_unit * sk_units);
+    // sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units);
     //
-    // - k_iter_sk_total is a multiple of cluster_size because it is the product
+    // - k_tiles_sk_total is a multiple of cluster_size because it is the product
     //   of number of tail tiles and the number of k iterations per tile. Because
     //   both the number of output tiles and number of available SMs are rounded
     //   to be multiples of cluster shape, the number of tail tiles
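The `k_tiles_sk_total` / `k_tiles_per_sk_unit` / `sk_big_units` arithmetic above divides the stream-K K-tile workload evenly across units and marks the remainder as "big" units that do one extra K tile. A standalone sketch with made-up sizes:

```cpp
#include <cstdio>
#include <cstdint>

int main() {
  // Made-up example: 3 stream-K output tiles of 32 K tiles each.
  uint64_t k_tiles_per_output_tile = 32;
  uint64_t sk_tiles = 3;
  uint64_t k_tiles_sk_total = k_tiles_per_output_tile * sk_tiles;  // 96 K tiles to cover

  // Covered by 5 stream-K units: 96 / 5 = 19 per unit, with 1 "big" unit doing 20.
  uint64_t sk_units = 5;
  uint64_t k_tiles_per_sk_unit = k_tiles_sk_total / sk_units;                   // 19
  uint64_t sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units);  // 1

  std::printf("total=%llu per_unit=%llu big_units=%llu\n",
              (unsigned long long)k_tiles_sk_total,
              (unsigned long long)k_tiles_per_sk_unit,
              (unsigned long long)sk_big_units);
  return 0;
}
```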
@@ -313,12 +313,12 @@ public:
       underlying_params.raster_order_,
       cluster_shape,
       1, // Static k-splitting factor. Unused for stream-K.
-      k_iter_per_tile,
+      k_tiles_per_output_tile,
       static_cast<uint32_t>(sk_big_units_per_cluster),
       reduction_workspace,
       sk_tiles,
       static_cast<uint32_t>(sk_units),
-      static_cast<uint32_t>(k_iter_per_sk_unit)
+      static_cast<uint32_t>(k_tiles_per_sk_unit)
     };
   }
 
@@ -338,105 +338,32 @@ public:
   CUTLASS_DEVICE
   WorkTileInfo
   get_current_work() const {
-    return get_current_work_for_linear_idx(current_work_linear_idx_);
+    return get_current_work_for_linear_idx(current_work_linear_idx_, scheduler_params);
   }
 
   CUTLASS_DEVICE
-  WorkTileInfo
-  get_current_work_for_linear_idx(uint64_t linear_idx) const {
-    if (linear_idx >= scheduler_params.units_per_problem_) {
+  static WorkTileInfo
+  get_current_work_for_linear_idx(uint64_t linear_idx, Params const& params) {
+    if (linear_idx >= params.units_per_problem_) {
       // Invalid work. Return an empty result.
       return {0, 0, 0, 0, false, 0};
     }
 
     // Determine whether this work unit is a data-parallel or stream-K work unit
-    bool is_stream_k_unit = linear_idx < scheduler_params.sk_units_;
+    bool is_stream_k_unit = linear_idx < params.sk_units_;
 
-    bool is_split_k = scheduler_params.splits_ > 1;
+    bool is_split_k = params.splits_ > 1;
 
-    // Bypass the stream-K scheduling logic for basic data-parallel or split-K work
     if (is_split_k || !is_stream_k_unit) {
-      // The linearized ID space is in terms of work units, rather than tiles. However,
-      // to compute the correct block offset for a data-parallel tile, we must convert
-      // the current ID to the data-parallel tile it corresponds to. Each data-parallel
-      // unit maps to a single data-parallel tile, but each stream-K unit can map to more
-      // than one tile. Thus, we must offset the work-unit ID among the data-parallel units
-      // by the total number of output tiles that will be computed by stream-K units.
-      //
-      // The logic below also works for the split-K case, in which sk_units_ and sk_tiles_
-      // are each 0.
-      uint64_t linear_work_idx = linear_idx - scheduler_params.sk_units_ + scheduler_params.sk_tiles_;
-
-      // Map worker's linear index into the CTA-tiled problem shape to the corresponding MNL indices
-      uint64_t work_idx_l, remainder;
-      scheduler_params.divmod_batch_(work_idx_l, remainder, linear_work_idx);
-
-      uint64_t work_idx_k = 0;
-      if (is_split_k) {
-        scheduler_params.divmod_k_(work_idx_k, remainder, remainder);
-      }
-
-      uint64_t cta_per_grid_dim, dontcare;
-      scheduler_params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
-
-      auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n(
-        cta_per_grid_dim,
-        scheduler_params.divmod_cluster_shape_major_,
-        scheduler_params.divmod_cluster_shape_minor_,
-        scheduler_params.divmod_cluster_blk_major_,
-        scheduler_params.log_swizzle_size_,
-        scheduler_params.raster_order_);
-
-      bool is_final_split = (work_idx_k == scheduler_params.splits_ - 1);
-
-      uint32_t k_iter = scheduler_params.k_iter_per_tile_;
-      if (is_split_k) {
-        // Determine the number of iterations and starting iteration of this split.
-        // Doing so requires accounting for residual iterations, which are handled
-        // by the first big_units_ splits (with big_units_ = tiles % sm_count).
-
-        // Offsets for "normal" units. No additional k iterations are performed,
-        // and big_units_ "big" units preceded us, each of which performed one
-        // additional iteration. Thus, we must increase our split starting offset
-        // by big_units_.
-        int additional_k_iter = 0;
-        int split_start_offset = scheduler_params.big_units_;
-
-        if (work_idx_k < scheduler_params.big_units_) {
-          // Offsets for "big" units. One additional k iteration is performed,
-          // and each split preceding us was a big unit, so we must increase
-          // our split starting offset by our split ID (work_idx_k).
-          additional_k_iter = 1;
-          split_start_offset = work_idx_k;
-        }
-
-        // Set up k iteration count and split starting iteration assuming the
-        // iteration space is evenly split.
-        k_iter /= scheduler_params.splits_;
-        work_idx_k *= k_iter;
-
-        // Apply any fixup needed to handle residuals
-        work_idx_k += split_start_offset;
-        k_iter += additional_k_iter;
-      }
-
-      return {
-        work_idx_m,
-        work_idx_n,
-        static_cast<int32_t>(work_idx_k),
-        static_cast<int32_t>(work_idx_l),
-        true,
-        scheduler_params.k_iter_per_tile_,
-        k_iter,
-        k_iter, // remaining iterations
-        is_final_split
-      };
+      // Bypass the stream-K scheduling logic for basic data-parallel or split-K work
+      return set_non_stream_k_work(linear_idx, params, is_split_k);
     }
-
-    // This is a stream-K work unit
-    WorkTileInfo work_tile_info;
-    set_stream_k_work(linear_idx, work_tile_info, /*new_unit = */ true);
-    return work_tile_info;
+    else {
+      // This is a stream-K work unit
+      WorkTileInfo work_tile_info;
+      set_stream_k_work(params, linear_idx, work_tile_info, /*new_unit = */ true);
+      return work_tile_info;
+    }
   }
 
   // Returns whether the current work_tile_info passed in should continue to be used. This
@@ -446,13 +373,24 @@ public:
   CUTLASS_DEVICE
   bool
   continue_current_work(WorkTileInfo& work_tile_info) const {
+    return continue_current_work_for_linear_idx(
+      current_work_linear_idx_, work_tile_info, scheduler_params);
+  }
+
+  CUTLASS_DEVICE static
+  bool
+  continue_current_work_for_linear_idx(
+    uint64_t linear_idx,
+    WorkTileInfo& work_tile_info,
+    Params const& params) {
+
     work_tile_info.k_tile_remaining -= work_tile_info.k_tile_count;
 
     if (work_tile_info.k_tile_remaining == 0) {
       return false;
     }
 
-    set_stream_k_work(current_work_linear_idx_, work_tile_info, /* new_unit = */ false);
+    set_stream_k_work(params, linear_idx, work_tile_info, /* new_unit = */ false);
     return true;
   }
 
@@ -495,6 +433,14 @@ public:
       /*truncate_by_problem_size=*/false);
   }
 
+  // Returns whether fixup is needed for `work_tile_info`.
+  CUTLASS_HOST_DEVICE
+  static bool
+  requires_fixup(Params const& params, WorkTileInfo const& work_tile_info) {
+    // Fixup is not needed for data-parallel tiles
+    return work_tile_info.k_tile_count != params.k_tiles_per_output_tile_;
+  }
+
   // Performs the reduction across splits for a given output tile.
   template <class FrgTensorC>
   CUTLASS_DEVICE
@@ -505,13 +451,25 @@ public:
     FrgTensorC& accumulators,
     uint32_t num_barriers,
     uint32_t barrier_idx) {
+    using BarrierManager = NamedBarrierManager<NumThreadsPerWarpGroup, 2>;
+    return fixup_helper<FrgTensorC, BarrierManager>(
+      params, work_tile_info, accumulators, num_barriers, barrier_idx);
+  }
+
+  // Helper for performing the reduction across splits for a given output tile.
+  template <class FrgTensorC, class BarrierManager>
+  CUTLASS_DEVICE
+  static void
+  fixup_helper(
+    Params const& params,
+    WorkTileInfo const& work_tile_info,
+    FrgTensorC& accumulators,
+    uint32_t num_barriers,
+    uint32_t barrier_idx) {
 
     using ElementAccumulator = typename FrgTensorC::value_type;
 
-    using BarrierManager = NamedBarrierManager<NumThreadsPerWarpGroup, 2>;
-
-    if (work_tile_info.k_tile_count == params.k_iter_per_tile_) {
-      // Fixup is not needed for data-parallel tiles
+    if (!requires_fixup(params, work_tile_info)) {
       return;
     }
 
@@ -619,21 +577,23 @@ public:
       }
     }
     else {
+      auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
+
       uint64_t cta_per_grid_dim;
       uint64_t cluster_dim_idx;
       if (params.raster_order_ == RasterOrder::AlongN) {
-        uint64_t block_idx_m = (work_tile_info.M_idx - blockIdx.x) / gridDim.x;
+        uint64_t block_idx_m = (work_tile_info.M_idx - cta_m_in_cluster) / cute::size<0>(params.cluster_shape_);
         uint64_t block_idx_n = work_tile_info.N_idx;
         cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor *
           params.divmod_cluster_blk_major_.divisor * block_idx_m) + block_idx_n;
-        cluster_dim_idx = blockIdx.x;
+        cluster_dim_idx = cta_m_in_cluster;
       }
       else {
         uint64_t block_idx_m = work_tile_info.M_idx;
-        uint64_t block_idx_n = (work_tile_info.N_idx - blockIdx.y) / gridDim.y;
+        uint64_t block_idx_n = (work_tile_info.N_idx - cta_n_in_cluster) / cute::size<1>(params.cluster_shape_);
         cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor *
           params.divmod_cluster_blk_major_.divisor * block_idx_n) + block_idx_m;
-        cluster_dim_idx = blockIdx.y;
+        cluster_dim_idx = cta_n_in_cluster;
       }
 
       uint64_t tile_in_batch = params.divmod_cluster_shape_minor_.divisor * cta_per_grid_dim;
@@ -646,7 +606,7 @@ public:
   get_workspace_size(
     Arguments const& args,
     ProblemShape problem_shape,
     KernelHardwareInfo const& hw_info,
     uint32_t mma_warp_groups) {
 
     int barrier_workspace_size = 0;
@@ -715,7 +675,7 @@ private:
 
     // Construct a layout for the indexed tensor. The main purpose of this new layout is to
     // override the k extent to support cases in which the split computes a number of iterations
-    // not equal to total_tile_k_iter / splits. A common example of this is in stream-K is when a
+    // not equal to total_k_tiles / splits. A common example of this is in stream-K is when a
     // unit computes the final 20 of the total 32 k iterations of the output tile. In this case,
     // set splits = 32 and the split index (K_idx) to 11. The zipped divide above results in each
     // of the splits computing only one k iteration.
@@ -728,12 +688,13 @@ private:
   // Returns the number of stream-K tiles that will be computed amongst `output_tiles` total
   // output tiles on a device with `ctas_per_wave` CTAs in each wave.
   static uint32_t
-  get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave) {
+  get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave, uint32_t k_tiles_per_output_tile) {
     uint32_t full_waves = static_cast<uint32_t>(output_tiles / ctas_per_wave);
     uint32_t total_waves = static_cast<uint32_t>((output_tiles + ctas_per_wave - 1) / ctas_per_wave);
 
-    if (full_waves == total_waves) {
-      // No quantization. All tiles will be data-parallel tiles.
+    if (full_waves == total_waves || k_tiles_per_output_tile == 1) {
+      // All tiles will be data-parallel tiles if there is either no quantization
+      // or if there is no work to be split.
       return 0;
     }
 
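The new `k_tiles_per_output_tile == 1` early-out above means stream-K is skipped when there is nothing to split along K, in addition to the existing no-partial-wave case. A simplified sketch of the check (the tile count returned here is an illustration, not the scheduler's actual heuristic):

```cpp
#include <cstdio>
#include <cstdint>

// Sketch: stream-K only kicks in when the last wave is partially full AND
// each output tile has more than one K tile to split.
uint32_t num_sk_tiles_sketch(uint64_t output_tiles, uint64_t ctas_per_wave, uint32_t k_tiles_per_output_tile) {
  uint32_t full_waves  = static_cast<uint32_t>(output_tiles / ctas_per_wave);
  uint32_t total_waves = static_cast<uint32_t>((output_tiles + ctas_per_wave - 1) / ctas_per_wave);
  if (full_waves == total_waves || k_tiles_per_output_tile == 1) {
    return 0;  // no partial wave, or nothing to split: everything stays data-parallel
  }
  // Illustration only: hand the partial wave's tiles to stream-K.
  return static_cast<uint32_t>(output_tiles - full_waves * ctas_per_wave);
}

int main() {
  std::printf("%u\n", num_sk_tiles_sketch(150, 132, 32));  // 18 tiles spill into a partial wave
  std::printf("%u\n", num_sk_tiles_sketch(150, 132, 1));   // 0: only one K tile per output tile
  std::printf("%u\n", num_sk_tiles_sketch(264, 132, 32));  // 0: exactly two full waves
  return 0;
}
```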
@@ -811,9 +772,12 @@ private:
       sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
     }
 
+    uint32_t k_tiles_per_output_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(TileShape{}) - 1) /
+        cute::size<2>(TileShape{});
+
     dim3 grid = get_grid_shape(problem_shape_mnkl, TileShape{}, cluster_shape, {0, sm_count}, args);
     uint64_t ctas_per_wave = grid.x * grid.y;
-    uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave);
+    uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, k_tiles_per_output_tile);
 
     barrier_workspace_size = get_barrier_workspace_size(sk_tiles, mma_warp_groups);
     reduction_workspace_size = get_reduction_workspace_size<ElementAccumulator>(sk_tiles);
@@ -829,10 +793,10 @@ private:
     uint32_t blocks_l,
     ClusterShape cluster_shape,
     uint32_t splits,
-    uint32_t k_iter_per_tile,
+    uint32_t k_tiles_per_output_tile,
     void* reduction_workspace) {
 
-    uint32_t big_units = k_iter_per_tile % splits;
+    uint32_t big_units = k_tiles_per_output_tile % splits;
 
     return {
       underlying_params.divmod_cluster_shape_major_,
@@ -845,7 +809,7 @@ private:
       underlying_params.raster_order_,
       cluster_shape,
       splits,
-      k_iter_per_tile,
+      k_tiles_per_output_tile,
       big_units,
       reduction_workspace
     };
@@ -855,8 +819,12 @@ private:
   // is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining
   // iterations) is used to find the next tile in the current work unit.
   CUTLASS_DEVICE
-  void
-  set_stream_k_work(uint64_t linear_idx, WorkTileInfo& work_tile_info, bool new_unit) const {
+  static void
+  set_stream_k_work(
+    Params const& params,
+    uint64_t linear_idx,
+    WorkTileInfo& work_tile_info,
+    bool new_unit) {
     // In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K
     // threadblock individually. For the most part, the set of K iterations corresponding to stream-K
     // work was divided amongst stream-K threadblocks, and a threadblock determined which tile
@@ -872,15 +840,15 @@ private:
     //
     // To do so, we divide up the linearized stream-K units into clusters and share the same K
     // offsets for work within clusters.
-    auto cluster_linear_work_idx = linear_idx / size(scheduler_params.cluster_shape_);
+    auto cluster_linear_work_idx = linear_idx / size(params.cluster_shape_);
 
     // Determine the starting k iteration computed by this stream-K work unit
-    uint32_t unit_iter_start = scheduler_params.k_iter_per_sk_unit_ * cluster_linear_work_idx;
+    uint32_t unit_iter_start = params.k_tiles_per_sk_unit_ * cluster_linear_work_idx;
 
     // Adjust the starting position and number of k iterations for "big units," which
     // compute one extra iteration. These are the first big_units_ units in the
     // linearized ID space.
-    bool is_big_unit = cluster_linear_work_idx < scheduler_params.big_units_;
+    bool is_big_unit = cluster_linear_work_idx < params.big_units_;
     if (is_big_unit) {
       // Since the "big units" are the first units in the linearized ID space, each
       // of the units preceding this big unit computed one extra iteration. Thus,
@@ -889,16 +857,16 @@ private:
       unit_iter_start += cluster_linear_work_idx;
     } else {
       // Increment by one for each of the big clusters (since all big units precede this unit)
-      unit_iter_start += scheduler_params.big_units_;
+      unit_iter_start += params.big_units_;
     }
 
     uint32_t unit_iters;
     if (new_unit) {
-      unit_iters = scheduler_params.k_iter_per_sk_unit_;
+      unit_iters = params.k_tiles_per_sk_unit_;
 
       // Only adjust iteration count for big unit if we are initializing this
       // work unit. For existing work units, the extra iteration for big units
-      // has already been accounted for in k_iter_reamaining
+      // has already been accounted for in k_tiles_reamaining
       if (is_big_unit) {
         ++unit_iters;
       }
@@ -917,22 +885,21 @@ private:
     // for them to be computed later, so as to reduce the likelihood of blocking
     // on other work.
     uint32_t unit_iter_end = unit_iter_start + unit_iters - 1;
-    uint32_t true_tile_id = unit_iter_end / scheduler_params.k_iter_per_tile_;
-    uint32_t true_tile_iter_start = true_tile_id * scheduler_params.k_iter_per_tile_;
-    uint32_t true_tile_iter_end = true_tile_iter_start + scheduler_params.k_iter_per_tile_;
+    uint32_t true_tile_id = unit_iter_end / params.k_tiles_per_output_tile_;
+    uint32_t true_tile_iter_start = true_tile_id * params.k_tiles_per_output_tile_;
+    uint32_t true_tile_iter_end = true_tile_iter_start + params.k_tiles_per_output_tile_;
 
     // Bring the linearized tile ID back into the space of tiles, rather than clusters
-    true_tile_id *= size(scheduler_params.cluster_shape_);
+    true_tile_id *= size(params.cluster_shape_);
 
-    auto cluster_dim0 = cute::size<0>(scheduler_params.cluster_shape_);
-    auto cluster_dim1 = cute::size<1>(scheduler_params.cluster_shape_);
+    auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
 
     // The final linearized tile ID is in units of the cluster dimension over which we rasterize.
-    if (scheduler_params.raster_order_ == RasterOrder::AlongN) {
-      true_tile_id += (blockIdx.y % cluster_dim1) * cluster_dim0;
+    if (params.raster_order_ == RasterOrder::AlongN) {
+      true_tile_id += cta_n_in_cluster * cute::size<0>(params.cluster_shape_);
     }
     else {
-      true_tile_id += (blockIdx.x % cluster_dim0) * cluster_dim1;
+      true_tile_id += cta_m_in_cluster * cute::size<1>(params.cluster_shape_);
     }
 
     // The unit's starting k iteration in the current tile is either the starting
@@ -948,19 +915,18 @@ private:
     uint32_t tile_iters = tile_iter_end - tile_iter_start;
 
     uint64_t work_idx_l, remainder;
-    scheduler_params.divmod_batch_(work_idx_l, remainder, true_tile_id);
+    params.divmod_batch_(work_idx_l, remainder, true_tile_id);
 
     uint64_t cta_per_grid_dim, dontcare;
-    scheduler_params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
+    params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
 
-
     auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n(
       cta_per_grid_dim,
-      scheduler_params.divmod_cluster_shape_major_,
-      scheduler_params.divmod_cluster_shape_minor_,
-      scheduler_params.divmod_cluster_blk_major_,
-      scheduler_params.log_swizzle_size_,
-      scheduler_params.raster_order_);
+      params.divmod_cluster_shape_major_,
+      params.divmod_cluster_shape_minor_,
+      params.divmod_cluster_blk_major_,
+      params.log_swizzle_size_,
+      params.raster_order_);
 
     //
     // Update the work_tile_info
@@ -971,11 +937,11 @@ private:
     work_tile_info.N_idx = work_idx_n;
     work_tile_info.L_idx = static_cast<int32_t>(work_idx_l);
 
-    // Set the k offset to be the starting k iteration for this tile
+    // Set the k offset to be the starting k tile for this output tile
     work_tile_info.K_idx = static_cast<int32_t>(tile_iter_start - true_tile_iter_start);
 
-    // Set the split count to be the number of k iterations in the tile
-    work_tile_info.splits = scheduler_params.k_iter_per_tile_;
+    // Set the split count to be the number of k tiles in the output tile
+    work_tile_info.splits = params.k_tiles_per_output_tile_;
 
     // Any checks for invalid work units should be done prior to this call
     work_tile_info.is_valid_tile = true;
@@ -987,6 +953,89 @@ private:
     // the output tile in question
     work_tile_info.is_final_split = (tile_iter_end == true_tile_iter_end);
   }
+
+  // Returns a WorkTileInfo to be computed for either the data-parallel or split-K
+  // work unit identified by the provided linear ID.
+  CUTLASS_DEVICE
+  static WorkTileInfo
+  set_non_stream_k_work(uint64_t linear_idx, Params const& params, bool is_split_k) {
+
+    // The linearized ID space is in terms of work units, rather than tiles. However,
+    // to compute the correct block offset for a data-parallel tile, we must convert
+    // the current ID to the data-parallel tile it corresponds to. Each data-parallel
+    // unit maps to a single data-parallel tile, but each stream-K unit can map to more
+    // than one tile. Thus, we must offset the work-unit ID among the data-parallel units
+    // by the total number of output tiles that will be computed by stream-K units.
+    //
+    // The logic below also works for the split-K case, in which sk_units_ and sk_tiles_
+    // are each 0.
+    uint64_t linear_work_idx = linear_idx - params.sk_units_ + params.sk_tiles_;
+
+    // Map worker's linear index into the CTA-tiled problem shape to the corresponding MNL indices
+    uint64_t work_idx_l, remainder;
+    params.divmod_batch_(work_idx_l, remainder, linear_work_idx);
+
+    uint64_t work_idx_k = 0;
+    if (is_split_k) {
+      params.divmod_k_(work_idx_k, remainder, remainder);
+    }
+
+    uint64_t cta_per_grid_dim, dontcare;
+    params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder);
+
+    auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n(
+      cta_per_grid_dim,
+      params.divmod_cluster_shape_major_,
+      params.divmod_cluster_shape_minor_,
+      params.divmod_cluster_blk_major_,
+      params.log_swizzle_size_,
+      params.raster_order_);
+
+    bool is_final_split = (work_idx_k == params.splits_ - 1);
+
+    uint32_t k_tiles = params.k_tiles_per_output_tile_;
+    if (is_split_k) {
+      // Determine the number of iterations and starting iteration of this split.
+      // Doing so requires accounting for residual iterations, which are handled
+      // by the first big_units_ splits (with big_units_ = tiles % sm_count).
+
+      // Offsets for "normal" units. No additional k iterations are performed,
+      // and big_units_ "big" units preceded us, each of which performed one
+      // additional iteration. Thus, we must increase our split starting offset
+      // by big_units_.
+      int additional_k_tiles = 0;
+      int split_start_offset = params.big_units_;
+
+      if (work_idx_k < params.big_units_) {
+        // Offsets for "big" units. One additional k iteration is performed,
+        // and each split preceding us was a big unit, so we must increase
+        // our split starting offset by our split ID (work_idx_k).
+        additional_k_tiles = 1;
+        split_start_offset = work_idx_k;
+      }
+
+      // Set up k iteration count and split starting iteration assuming the
+      // iteration space is evenly split.
+      k_tiles /= params.splits_;
+      work_idx_k *= k_tiles;
+
+      // Apply any fixup needed to handle residuals
+      work_idx_k += split_start_offset;
+      k_tiles += additional_k_tiles;
+    }
+
+    return {
+      work_idx_m,
+      work_idx_n,
+      static_cast<int32_t>(work_idx_k),
+      static_cast<int32_t>(work_idx_l),
+      true,
+      params.k_tiles_per_output_tile_,
+      k_tiles,
+      k_tiles, // remaining iterations
+      is_final_split
+    };
+  }
 };
 
 } // namespace cutlass::gemm::kernel::detail
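The split-K path in `set_non_stream_k_work` above hands each split `k_tiles / splits` K tiles and distributes the remainder to the first `big_units_` splits, shifting the later splits' starting offsets accordingly. A standalone sketch of that bookkeeping with made-up numbers:

```cpp
#include <cstdio>
#include <cstdint>

int main() {
  // Made-up example: 10 K tiles split 3 ways -> big_units = 10 % 3 = 1,
  // so split 0 covers 4 K tiles and splits 1 and 2 cover 3 each.
  uint32_t k_tiles_per_output_tile = 10;
  uint32_t splits = 3;
  uint32_t big_units = k_tiles_per_output_tile % splits;  // 1

  for (uint32_t split = 0; split < splits; ++split) {
    uint32_t k_tiles = k_tiles_per_output_tile / splits;  // even share before residual fixup
    uint32_t start = split * k_tiles;

    if (split < big_units) {
      // "Big" splits: one extra K tile, start shifted by their own split index.
      start += split;
      k_tiles += 1;
    }
    else {
      // "Normal" splits: start shifted past all of the big splits.
      start += big_units;
    }
    std::printf("split %u: start=%u count=%u\n", split, start, k_tiles);
  }
  return 0;
}
```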
@@ -28,7 +28,7 @@
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  **************************************************************************************************/
 /*!
   \file
   \brief Boost-like numeric conversion operator for CUTLASS numeric types
 */
@@ -55,7 +55,7 @@ enum class FloatRoundStyle {
   round_indeterminate,          ///< rounding mode unknown
   round_toward_zero,            ///< round toward zero
   round_to_nearest,             ///< round to nearest even
   round_to_nearest_satfinite,   ///< round to nearest even, capping value to min and max of destination type
   round_toward_infinity,        ///< round toward infinity
   round_toward_neg_infinity,    ///< round toward negative infinity
   round_half_ulp_truncate,      ///< add 0.5ulp to integer representation then round toward zero
@@ -561,7 +561,7 @@ struct NumericConverter<tfloat32_t, float, FloatRoundStyle::round_to_nearest> {
 
     // Note, the following is intentionally commented out. TF32
     // does not define the low order bits, so they may be left in
     // an undefined state.
     //
     // By not truncating these bit explicitly, we avoid an extra logical
     // operation.
@@ -657,7 +657,7 @@ template <
 struct NumericConverterFastF32 {
 
   // result_type holds big tfloat32_t at idx(0) and small tfloat32_t at idx(1)
   using result_type = Array<tfloat32_t, 2>;
 
   // source data type
   using source_type = float;
@@ -708,7 +708,7 @@ struct NumericConverterClamp {
     NumericConverter<result_type, source_type> convert_op;
     result_type const kClamp_max = platform::numeric_limits<result_type>::max();
     result_type const kClamp_min = platform::numeric_limits<result_type>::lowest();
     if (s < (source_type)kClamp_min)
       return kClamp_min;
     if (s > (source_type)kClamp_max)
       return kClamp_max;
@@ -848,7 +848,7 @@ struct NumericArrayConverter<half_t, float, 2, FloatRoundStyle::round_to_nearest
     result[0] = convert_(source[0]);
     result[1] = convert_(source[1]);
 #endif
 
     return result;
   }
 
@@ -878,7 +878,7 @@ struct NumericArrayConverter<float, half_t, 2, Round> {
     result[0] = convert_(source[0]);
     result[1] = convert_(source[1]);
 #endif
 
     return result;
   }
 
@@ -1044,7 +1044,7 @@ struct NumericArrayConverter<bfloat16_t, float, N, Round> {
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 // Conditional guards to enable partial specialization for packed integers
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \
     ((__CUDACC_VER_MAJOR__ > 10) || \
     ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
@@ -1066,7 +1066,7 @@ struct NumericArrayConverter<int8_t, int, 1, Round> {
     result_type result;
 
     result[0] = convert_element_(source[0]);
 
     return result;
   }
 
@@ -1189,7 +1189,7 @@ struct NumericArrayConverter<uint8_t, int, 1, Round> {
     result_type result;
 
     result[0] = convert_element_(source[0]);
 
     return result;
   }
 
@ -215,10 +215,17 @@ class GemmArguments2x(ArgumentBase):
        else:
            self.batch_count = 1

-       self.batched_stride_A = self.problem_size.m() * self.problem_size.k()
-       self.batched_stride_B = self.problem_size.n() * self.problem_size.k()
-       self.batched_stride_C = self.problem_size.m() * self.problem_size.n()
-       self.batched_stride_D = self.problem_size.m() * self.problem_size.n()
+       if "batch_strides" in kwargs:
+           self.batched_stride_A = kwargs["batch_strides"]["A"]
+           self.batched_stride_B = kwargs["batch_strides"]["B"]
+           self.batched_stride_C = kwargs["batch_strides"]["C"]
+           self.batched_stride_D = kwargs["batch_strides"]["D"]
+       else:
+           self.batched_stride_A = self.problem_size.m() * self.problem_size.k()
+           self.batched_stride_B = self.problem_size.n() * self.problem_size.k()
+           self.batched_stride_C = self.problem_size.m() * self.problem_size.n()
+           self.batched_stride_D = self.problem_size.m() * self.problem_size.n()

        if self.bias:
            self.batched_stride_C = self.problem_size.n()

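The new `batch_strides` path lets a caller override the densely packed defaults computed from the problem size. A minimal sketch of how such a stride dictionary can be derived from operand shapes (the helper name and shapes below are hypothetical, not part of the change):

    # Hypothetical helper illustrating the stride convention used above: for a
    # densely packed (batch, rows, cols) tensor, consecutive matrices in the
    # batch are rows * cols elements apart; rank-2 tensors get stride 0.
    def batch_strides_from_shapes(a_shape, b_shape, c_shape):
        stride = lambda shape: shape[-2] * shape[-1] if len(shape) > 2 else 0
        return {"A": stride(a_shape), "B": stride(b_shape),
                "C": stride(c_shape), "D": stride(c_shape)}

    # Example: a batch of 3 GEMMs with M=128, N=64, K=32
    print(batch_strides_from_shapes((3, 128, 32), (3, 32, 64), (3, 128, 64)))
    # {'A': 4096, 'B': 2048, 'C': 8192, 'D': 8192}
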
@ -132,9 +132,9 @@ class KernelsForDataType:
        """
        # Determine the leading dimension of the shape
        if layout == cutlass.LayoutType.ColumnMajor:
-           ld = shape[0]
+           ld = shape[-2]
        elif layout == cutlass.LayoutType.RowMajor:
-           ld = shape[1]
+           ld = shape[-1]
        elif layout == cutlass.LayoutType.TensorNHWC:
            ld = shape[-1]
        else:
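The switch to negative indices makes the leading-dimension lookup work for batched (rank-3 and higher) shapes as well as plain matrices. A quick illustration, with a hypothetical shape:

    # For a (batch, rows, cols) shape, the leading dimension must come from the
    # matrix part of the shape; shape[-2] / shape[-1] select it regardless of
    # how many batch dimensions precede it, whereas shape[0] would pick up the
    # batch size for a rank-3 tensor.
    shape = (4, 128, 64)          # (batch, rows, cols)
    ld_column_major = shape[-2]   # 128
    ld_row_major = shape[-1]      # 64
    print(ld_column_major, ld_row_major)
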
@ -114,6 +114,8 @@
    args.sync()
"""

+from math import prod
+
import cutlass_bindings

import cutlass
@ -442,6 +444,113 @@ class Gemm(OperationBase):
            compiler.add_module([self.operation,])
        return self.operation

+   def _verify_rank(self, tensor):
+       """
+       Verifies that ``tensor`` has rank greater than 1
+
+       :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
+       :type tensor: numpy/cupy/torch array/tensor object
+       """
+       if len(tensor.shape) < 2:
+           raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}")
+
+   def _get_batch_count(self, A, B, C, D) -> int:
+       """
+       Returns the batch count specified by the tensors A, B, C, and D and verifies that these
+       tensors match in batch size. Presence of a batch dimension is detected by one of the
+       tensors being rank 3. If a batch dimension is present, it must be present in one of
+       operands A, B, or C (but need not be in all), and must be present in D.
+
+       :param A: tensor A
+       :type A: numpy/cupy/torch array/tensor object
+       :param B: tensor B
+       :type B: numpy/cupy/torch array/tensor object
+       :param C: tensor C
+       :type C: numpy/cupy/torch array/tensor object
+       :param D: tensor D
+       :type D: numpy/cupy/torch array/tensor object
+
+       :return: tuple of batch count dimensions
+       :rtype: tuple
+       """
+       A_batch = A.shape[:-2] if len(A.shape) > 2 else tuple()
+       B_batch = B.shape[:-2] if len(B.shape) > 2 else tuple()
+       C_batch = C.shape[:-2] if len(C.shape) > 2 else tuple()
+       D_batch = D.shape[:-2] if len(D.shape) > 2 else tuple()
+
+       if len(D_batch) > 0 and D_batch not in [A_batch, B_batch, C_batch]:
+           raise Exception(f"Batch count in D must be present in one of operands A, B, and C. "
+                           f"Batch counts are: A={A_batch}, B={B_batch}, C={C_batch}, D={D_batch}")
+
+       for batch_shape in [A_batch, B_batch, C_batch]:
+           if len(batch_shape) > 0 and batch_shape != D_batch:
+               raise Exception(f"Batch count for all other operands must either match that of D or be zero."
+                               f"Received batch shape of {batch_shape}, which does not match that of D of {D_batch}.")
+
+       return D_batch
+
+   def _get_batch_stride(self, tensor) -> int:
+       """
+       Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0.
+
+       :param tensor: tensor object to process
+       :type tensor: numpy/cupy/torch array/tensor object
+
+       :return: stride between each matrix in the batch
+       :rtype: int
+       """
+       if len(tensor.shape) > 2:
+           return tensor.shape[-2] * tensor.shape[-1]
+       else:
+           return 0
+
+   def _get_problem_args(self, A, B, C, D) -> tuple:
+       """
+       Returns the problem size and GEMM universal mode to use for the
+       given operands.
+
+       :param A: tensor A
+       :type A: numpy/cupy/torch array/tensor object
+       :param B: tensor B
+       :type B: numpy/cupy/torch array/tensor object
+       :param C: tensor C
+       :type C: numpy/cupy/torch array/tensor object
+       :param D: tensor D
+       :type D: numpy/cupy/torch array/tensor object
+
+       :return: tuple containing the problem size (cutlass_bindings.gemm.GemmCoord), the GEMM mode (cutlass_bindings.gemm.Mode), and the batch count (int)
+       :rtype: tuple
+       """
+       M, K = A.shape[-2:]
+       N = B.shape[-1]
+       mode = cutlass_bindings.gemm.Mode.Gemm
+
+       batch_count = self._get_batch_count(A, B, C, D)
+       returned_batch_count = prod(batch_count) if len(batch_count) > 0 else 1
+
+       # If we are running a batched GEMM in which there is a nonzero batch stride
+       # only for A, then we can fold the batched dimension of A into the M dimension
+       # (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A
+       # and C are row major. A similar operation can be performed if only B has a nonzero
+       # batch dimension
+       if len(batch_count) > 0:
+           A_row = self._layout_a == cutlass.LayoutType.RowMajor
+           B_row = self._layout_b == cutlass.LayoutType.RowMajor
+           C_row = self._layout_c == cutlass.LayoutType.RowMajor
+
+           batched = lambda x : len(x.shape) == 2 + len(batch_count)
+
+           if batched(A) and not batched(B) and batched(C) and A_row and C_row:
+               M *= prod(batch_count)
+               returned_batch_count = 1
+           elif not batched(A) and batched(B) and batched(C) and not B_row and not C_row:
+               N *= prod(batch_count)
+               returned_batch_count = 1
+           else:
+               mode = cutlass_bindings.gemm.Mode.Batched
+
+       return cutlass_bindings.gemm.GemmCoord(M, N, K), mode, returned_batch_count
+
    def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
        """
        Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception

@ -461,8 +570,7 @@ class Gemm(OperationBase):
                            f'layout of ({ref_type}, {ref_layout}).')

    def run(self, A=None, B=None, C=None, D=None,
-           alpha=None, beta=None, batch_count: int = 1,
-           sync: bool = True, print_module: bool = False) -> GemmArguments:
+           alpha=None, beta=None, sync: bool = True, print_module: bool = False) -> GemmArguments:
        """
        Runs the kernel currently specified. If it has not already been, the kernel is emitted and
        compiled. Tensors holding operands and outputs of the kernel are sourced either from the

@ -481,8 +589,6 @@ class Gemm(OperationBase):
        :param D: tensor representing data type and layout of operand D
        :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
        :param beta: scalar parameter beta from GEMM operation that scales operand C
-       :param batch_count: number of GEMMs in the batch
-       :type batch_count: int
        :param sync: whether the call should wait for the kernel to complete before returning
        :type sync: bool
        :param print_module: whether to print the emitted C++ code

@ -491,9 +597,6 @@ class Gemm(OperationBase):
        :return: arguments passed in to the kernel
        :rtype: cutlass.backend.GemmArguments
        """
-       if batch_count < 1:
-           raise Exception(f"Invalid batch count {batch_count}. Value must be an integer >= 1.")
-
        A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
        B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
        C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")

@ -501,20 +604,31 @@ class Gemm(OperationBase):
        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

+       self._verify_rank(A)
+       self._verify_rank(B)
+       self._verify_rank(C)
+       self._verify_rank(D)
+
        alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a)
        alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b)
        alignment_c = self.possible_operations.find_alignment(C.shape, self._layout_c)
        self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                     alignment_C=alignment_c, print_module=print_module)

-       problem_size = cutlass_bindings.gemm.GemmCoord(A.shape[0], B.shape[1], A.shape[1])
+       problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)

-       if batch_count == 1:
-           mode = cutlass_bindings.gemm.Mode.Gemm
+       if mode == cutlass_bindings.gemm.Mode.Gemm or batch_count == 1:
            kwargs = {'split_k_slices': 1}
        else:
-           mode = cutlass_bindings.gemm.Mode.Batched
-           kwargs = {'batch': batch_count}
+           kwargs = {
+               'batch': batch_count,
+               'batch_strides': {
+                   'A': self._get_batch_stride(A),
+                   'B': self._get_batch_stride(B),
+                   'C': self._get_batch_stride(C),
+                   'D': self._get_batch_stride(D)
+               }
+           }

        arguments = GemmArguments(
            operation=self.operation, problem_size=problem_size,

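The folding rule described in _get_problem_args can be checked independently of CUTLASS. The NumPy sketch below (illustrative only) confirms that, when only A carries the batch dimension, a batched GEMM equals a single GEMM with the batch folded into M:

    import numpy as np

    b, m, n, k = 2, 4, 3, 5
    A = np.random.rand(b, m, k)   # only A is batched
    B = np.random.rand(k, n)

    batched = np.einsum('bmk,kn->bmn', A, B)             # what Mode.Batched would compute
    folded = (A.reshape(b * m, k) @ B).reshape(b, m, n)  # one GEMM with M' = b * m
    assert np.allclose(batched, folded)
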
test/python/gemm/gemm_batched.py (new file, 139 lines)
@ -0,0 +1,139 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# (standard BSD-3-Clause license text)
#
#################################################################################################

"""
High-level tests for running batched GEMMs
"""

from functools import partial
from math import prod

import cutlass
import logging
import torch
import unittest

from cutlass.backend.test.utils import LayoutCombination, add_test_gemm
from cutlass.backend.utils.device import device_cc

cutlass.set_log_level(logging.WARNING)

torch.manual_seed(2023)


def pytorch_reference(A, B, C, alpha, beta):
    # Get the batch count. Assume that any of A, B, and C
    # with a batch dimension have matching batch count. Thus,
    # we break out of the loop once we have found the first
    # tensor containing a batch dimension.
    batch_count = (1,)
    for tensor in [A, B, C]:
        if len(tensor.shape) > 2:
            batch_count = tensor.shape[:-2]
            break

    int_batch_count = prod(batch_count)

    def add_batch(tensor):
        if len(tensor.shape) == 2:
            return tensor.unsqueeze(0).repeat(int_batch_count, 1, 1)
        else:
            return tensor.reshape(-1, tensor.size(-2), tensor.size(-1))

    # Reshape tensors to have batch dimension
    A = add_batch(A)
    B = add_batch(B)
    C = add_batch(C)

    ret = (torch.bmm(A, B) * alpha) + (C * beta)
    reshape_vals = batch_count + C.shape[-2:]
    return ret.reshape(*reshape_vals)


def initialize(rows, cols, batch):
    tensor = torch.randint(-3, 3, size=(rows*cols*prod(batch),), device='cuda').half()
    if len(batch) > 0 and prod(batch) > 1:
        reshape_vals = batch + (rows, cols)
        return tensor.reshape(*reshape_vals)
    else:
        return tensor.reshape(rows, cols)


class GemmF16Batched(unittest.TestCase):
    def run_batched(self, batch_count: tuple, batch_A: bool, batch_B: bool, batch_C: bool):
        M = 512
        N = 256
        K = 128
        alpha = 1.
        beta = 2.

        A = initialize(M, K, batch_count if batch_A else (1,))
        B = initialize(K, N, batch_count if batch_B else (1,))
        C = initialize(M, N, batch_count if batch_C else (1,))
        D = initialize(M, N, batch_count)

        plan = cutlass.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass.DataType.f32)
        plan.run(A, B, C, D, alpha, beta)
        reference = pytorch_reference(A, B, C, alpha, beta)
        assert reference.equal(D)

    def test_batched_ABC(self):
        self.run_batched((3,), True, True, True)
        self.run_batched((2, 3), True, True, True)

    def test_batched_AB(self):
        self.run_batched((3,), True, True, False)
        self.run_batched((2, 3), True, True, False)

    def test_batched_AC(self):
        self.run_batched((3,), True, False, True)
        self.run_batched((2, 3), True, False, True)

    def test_batched_BC(self):
        self.run_batched((3,), False, True, True)
        self.run_batched((2, 3), False, True, True)

    def test_batched_A(self):
        self.run_batched((3,), True, False, False)
        self.run_batched((2, 3), True, False, False)

    def test_batched_B(self):
        self.run_batched((3,), False, True, False)
        self.run_batched((2, 3), False, True, False)

    def test_batched_C(self):
        self.run_batched((3,), False, False, True)
        self.run_batched((2, 3), False, False, True)


if __name__ == '__main__':
    unittest.main()

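A minimal interactive sketch of what these tests exercise, outside the unittest harness. Shapes and dtypes are illustrative only, and the snippet assumes a CUDA device with the Python package built from this release:

    import cutlass
    import torch

    A = torch.rand((2, 3, 64, 32), dtype=torch.float16, device='cuda')
    B = torch.rand((2, 3, 32, 48), dtype=torch.float16, device='cuda')
    C = torch.zeros((2, 3, 64, 48), dtype=torch.float16, device='cuda')
    D = torch.zeros_like(C)

    # The (2, 3) batch dimensions are detected from the operand ranks; no
    # batch_count argument is passed to run() anymore.
    plan = cutlass.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass.DataType.f32)
    plan.run(A, B, C, D, alpha=1.0, beta=0.0)
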
@ -61,7 +61,7 @@ __global__ void convert(
/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Destination, typename Source, int Count>
- void run_test() {
+ void run_test(const char dest_name[], const char source_name[]) {
  const int kN = Count;

  dim3 grid(1, 1);

@ -84,7 +84,10 @@ void run_test() {
  destination.sync_host();

  for (int i = 0; i < kN; ++i) {
-   EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]));
+   EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]))
+     << "Destination type: " << dest_name
+     << ", Source type: " << source_name
+     << ", Count: " << Count;
  }
}

@ -97,15 +100,19 @@ void run_test() {
TEST(NumericConversion, f32_to_f16_rn) {
  int const kN = 1;
  using Source = float;
+ const char source_name[] = "float";
  using Destination = cutlass::half_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "half_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f32x8_to_f16x8_rn) {
  int const kN = 8;
  using Source = float;
+ const char source_name[] = "float";
  using Destination = cutlass::half_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "half_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -113,15 +120,19 @@ TEST(NumericConversion, f32x8_to_f16x8_rn) {
TEST(NumericConversion, f16_to_f32_rn) {
  int const kN = 1;
  using Source = cutlass::half_t;
+ const char source_name[] = "half_t";
  using Destination = float;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f16x8_to_f32x8_rn) {
  int const kN = 8;
  using Source = cutlass::half_t;
+ const char source_name[] = "half_t";
  using Destination = float;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -129,86 +140,109 @@ TEST(NumericConversion, f16x8_to_f32x8_rn) {
TEST(NumericConversion, f32_to_fe4m3_rn) {
  int const kN = 1;
  using Source = float;
+ const char source_name[] = "float";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e4m3_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f32_to_fe4m3_rn_array) {
  int const kN = 27;
  using Source = float;
+ const char source_name[] = "float";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e4m3_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f32_to_fe5m2_rn) {
  int const kN = 1;
  using Source = float;
+ const char source_name[] = "float";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e5m2_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f32_to_fe5m2_rn_array) {
  int const kN = 27;
  using Source = float;
+ const char source_name[] = "float";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e5m2_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f16_to_fe4m3_rn) {
  int const kN = 1;
  using Source = cutlass::half_t;
+ const char source_name[] = "half_t";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e4m3_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f16_to_fe4m3_rn_array) {
  int const kN = 27;
  using Source = cutlass::half_t;
+ const char source_name[] = "half_t";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e4m3_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f16_to_fe5m2_rn) {
  int const kN = 1;
  using Source = cutlass::half_t;
+ const char source_name[] = "half_t";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e5m2_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, f16_to_fe5m2_rn_array) {
  int const kN = 27;
  using Source = cutlass::half_t;
+ const char source_name[] = "half_t";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e5m2_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, bf16_to_fe4m3_rn) {
  int const kN = 1;
  using Source = cutlass::bfloat16_t;
+ const char source_name[] = "bfloat16_t";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e4m3_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, bf16_to_fe4m3_rn_array) {
  int const kN = 27;
  using Source = cutlass::bfloat16_t;
+ const char source_name[] = "bfloat16_t";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e4m3_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, bf16_to_fe5m2_rn) {
  int const kN = 1;
  using Source = cutlass::bfloat16_t;
+ const char source_name[] = "bfloat16_t";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e5m2_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, bf16_to_fe5m2_rn_array) {
  int const kN = 27;
  using Source = cutlass::bfloat16_t;
+ const char source_name[] = "bfloat16_t";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e5m2_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -216,36 +250,46 @@ TEST(NumericConversion, bf16_to_fe5m2_rn_array) {
TEST(NumericConversion, fe4m3_to_fe5m2_rn) {
  int const kN = 1;
  using Source = cutlass::float_e4m3_t;
+ const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e5m2_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_fe5m2_array) {
  int const kN = 27;
  using Source = cutlass::float_e4m3_t;
+ const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::float_e5m2_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e5m2_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_fe4m3_rn) {
  int const kN = 1;
  using Source = cutlass::float_e5m2_t;
+ const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e4m3_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_fe4m3_array) {
  int const kN = 27;
  using Source = cutlass::float_e5m2_t;
+ const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::float_e4m3_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float_e4m3_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_f32_rn) {
  int const kN = 1;
  using Source = cutlass::float_e4m3_t;
+ const char source_name[] = "float_e4m3_t";
  using Destination = float;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -254,78 +298,100 @@ TEST(NumericConversion, f32x8_to_s8x8_rn) {

  int const kN = 8;
  using Source = float;
+ const char source_name[] = "float";
  using Destination = int8_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "int8_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_f32_array) {
  int const kN = 27;
  using Source = cutlass::float_e4m3_t;
+ const char source_name[] = "float_e4m3_t";
  using Destination = float;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_f32_array) {
  int const kN = 27;
  using Source = cutlass::float_e5m2_t;
+ const char source_name[] = "float_e5m2_t";
  using Destination = float;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "float";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_f16_rn) {
  int const kN = 1;
  using Source = cutlass::float_e4m3_t;
+ const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::half_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "half_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_f16_array) {
  int const kN = 27;
  using Source = cutlass::float_e4m3_t;
+ const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::half_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "half_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_f16_rn) {
  int const kN = 1;
  using Source = cutlass::float_e5m2_t;
+ const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::half_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "half_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_f16_array) {
  int const kN = 27;
  using Source = cutlass::float_e5m2_t;
+ const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::half_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "half_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_bf16_rn) {
  int const kN = 1;
  using Source = cutlass::float_e4m3_t;
+ const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::bfloat16_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "bfloat16_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe4m3_to_bf16_array) {
  int const kN = 27;
  using Source = cutlass::float_e4m3_t;
+ const char source_name[] = "float_e4m3_t";
  using Destination = cutlass::bfloat16_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "bfloat16_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_bf16_rn) {
  int const kN = 1;
  using Source = cutlass::float_e5m2_t;
+ const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::bfloat16_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "bfloat16_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

TEST(NumericConversion, fe5m2_to_bf16_array) {
  int const kN = 27;
  using Source = cutlass::float_e5m2_t;
+ const char source_name[] = "float_e5m2_t";
  using Destination = cutlass::bfloat16_t;
- test::core::kernel::run_test<Destination, Source, kN>();
+ const char dest_name[] = "bfloat16_t";
+ test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

@ -36,6 +36,7 @@ cutlass_test_unit_add_executable(
  compare.cpp
  complement.cpp
  composition.cpp
+ constant_arithmetic.cpp
  core_unit.cpp
  inverse_left.cpp
  inverse_right.cpp

test/unit/cute/core/constant_arithmetic.cpp (new file, 106 lines)
@ -0,0 +1,106 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * (standard BSD-3-Clause license text)
 **************************************************************************************************/

#include "cutlass_unit_test.h"
#include <cutlass/trace.h>
#include <cute/swizzle.hpp>

TEST(CuTe_core, ConstantArithmetic) {
  using namespace cute;

  constexpr cute::integral_constant<uint32_t, 0> uzero{};

  // This extra test exists historically as part of the diagnosis
  // of a possible Clang 14 bug. However, it's a nice test for
  // cute::integral_constant's arithmetic operators, so it's saved here.
  // It also demonstrates how to work with cute::integral_constant
  // and lambda captures. Microsoft Visual Studio ("MSVC") tends to
  // disagree with other compilers about the meaning of decltype
  // for variables captured by reference. MSVC and GCC 8.3.0
  // also tend to disagree with other compilers (and other GCC versions)
  // about whether expressions involving such variables
  // are constant expressions.
  //
  // A typical CuTe idiom is to do lambda captures by reference [&].
  // This test changes them to capture by value, except for
  // the innermost lambda's capture of S1, which is by reference.
  // The point is to show that MSVC and GCC 8 have issues with this
  // that other compilers do not. For example,
  //
  // 1. MSVC needs remove_cvref_t around decltype(S1)
  //    in order to access decltype(S1)::value, and
  // 2. MSVC and GCC 8.3.0 both report a build error with S1()
  //    (that is, calling operator() on S1, which returns the
  //    same thing as S1.value).
  //
  // The reason for (2) is that neither compiler thinks
  // that S1() is a constant expression.
  //
  // This leaves S1.value as the most concise portable expression
  // for the "value" member of a cute::integral_constant.
  for_each(make_integer_sequence<uint32_t, 8>{}, [uzero](auto S0) {
    for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0](auto F0) {
      for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0,F0](auto S1) {
        for_each(make_integer_sequence<uint32_t, 8>{}, [uzero,S0,F0,&S1](auto F1) {
          static_assert((decltype(S0)::value & decltype(F0)::value) == decltype(S0 & F0)::value);

          // Using S1.value means you don't have to use remove_cvref_t
          // with a captured-by-reference variable.
          static_assert((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) == decltype(S1 & F1)::value);
          static_assert((S1.value & decltype(F1)::value) == decltype(S1 & F1)::value);
          // S1() _should_ work, but does not with Visual Studio 2022,
          // which emits C2131 ("expression did not evaluate to a constant").
          // It also does not with GCC 8.3.0, which emits an error with messages
          // "non-constant condition for static assertion" and
          // "'this' is not a constant expression."
          //
          //static_assert((S1() & decltype(F1)::value) == decltype(S1 & F1)::value);
          static_assert(decltype((S0 & F0) != uzero)::value == ((decltype(S0)::value & decltype(F0)::value) != 0));

          static_assert(decltype((S1 & F1) != uzero)::value == ((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) != 0));
          static_assert(decltype((S1 & F1) != uzero)::value == ((S1.value & decltype(F1)::value) != 0));

          constexpr bool left = decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value;
          constexpr bool right =
            ((decltype(S0)::value & decltype(F0)::value) != 0) ||
            ((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) != 0);
          constexpr bool right2 =
            ((decltype(S0)::value & decltype(F0)::value) != 0) ||
            ((S1.value & decltype(F1)::value) != 0);
          static_assert(right == right2);
          static_assert(left == right);
          constexpr bool left2 = decltype((S0 & F0) != uzero)::value || decltype((S1 & F1) != uzero)::value;
          static_assert(left == left2);
        });
      });
    });
  });
}

@ -31,9 +31,58 @@

#include "cutlass_unit_test.h"

+// C<uint32_t(something)>::value_type is not uint32_t for GCC 7.5.0.
+// This test is thus disabled for GCC < 8.
+#if !defined(__GNUC__) || (__GNUC__ >= 8)
+
#include <cutlass/trace.h>
#include <cute/swizzle.hpp>

+namespace { // (anonymous)
+
+// This function exists to work around a Clang 14 issue, in which
+// the compiler tries to instantiate code that lives inside the
+// "else" branch of an "if constexpr," even when the "else" branch
+// is false. That triggers a spurious static_assert in MixedBits.
+// The work-around is to make the body of the "else" branch a
+// function, rather than leaving it in line.
+//
+// Some compilers strangely deduce the first two terms of
+// make_integer_sequence<uint32_t, 8> as C<false> and C<true>, and
+// the remaining terms as C<2>, C<3>, etc. Making this function take
+// cute::integral_constant<uint32_t, S0_value>, etc. doesn't work
+// with those compilers.
+template<class S0_type, S0_type S0_value,
+         class F0_type, F0_type F0_value,
+         class S1_type, S1_type S1_value,
+         class F1_type, F1_type F1_value>
+void clang14_workaround(cute::integral_constant<S0_type, S0_value>,
+                        cute::integral_constant<F0_type, F0_value>,
+                        cute::integral_constant<S1_type, S1_value>,
+                        cute::integral_constant<F1_type, F1_value>)
+{
+  constexpr cute::C<static_cast<uint32_t>(S0_value)> S0{};
+  constexpr cute::C<static_cast<uint32_t>(F0_value)> F0{};
+  constexpr cute::C<static_cast<uint32_t>(S1_value)> S1{};
+  constexpr cute::C<static_cast<uint32_t>(F1_value)> F1{};
+
+  for (uint32_t d0 = 0; d0 < 8; ++d0) {
+    if ((d0 & F0) != d0) { continue; }    // Skip repeats
+    for (uint32_t d1 = 0; d1 < 8; ++d1) {
+      if ((d1 & F1) != d1) { continue; }  // Skip repeats
+      auto m0 = make_mixed_bits(S0, d0, F0);
+      auto m1 = make_mixed_bits(S1, d1, F1);
+      //print(m0); print(" & "); print(m1); print(" = "); print(m0 & m1); print("\n");
+      EXPECT_EQ(uint32_t(m0 & m1), uint32_t(m0) & uint32_t(m1));
+      //print(m0); print(" | "); print(m1); print(" = "); print(m0 | m1); print("\n");
+      EXPECT_EQ(uint32_t(m0 | m1), uint32_t(m0) | uint32_t(m1));
+      //print(m0); print(" ^ "); print(m1); print(" = "); print(m0 ^ m1); print("\n");
+      EXPECT_EQ(uint32_t(m0 ^ m1), uint32_t(m0) ^ uint32_t(m1));
+    }
+  }
+}
+
+} // namespace (anonymous)

TEST(CuTe_core, MixedBits) {
  using namespace cute;

@ -48,23 +97,21 @@ TEST(CuTe_core, MixedBits) {
          } else if constexpr (decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value) {
            return;
          } else {
-           for (uint32_t d0 = 0; d0 < 8; ++d0) {
-             if ((d0 & F0) != d0) { continue; }    // Skip repeats
-             for (uint32_t d1 = 0; d1 < 8; ++d1) {
-               if ((d1 & F1) != d1) { continue; }  // Skip repeats
-               auto m0 = make_mixed_bits(S0, d0, F0);
-               auto m1 = make_mixed_bits(S1, d1, F1);
-               //print(m0); print(" & "); print(m1); print(" = "); print(m0 & m1); print("\n");
-               EXPECT_EQ(uint32_t(m0 & m1), uint32_t(m0) & uint32_t(m1));
-               //print(m0); print(" | "); print(m1); print(" = "); print(m0 | m1); print("\n");
-               EXPECT_EQ(uint32_t(m0 | m1), uint32_t(m0) | uint32_t(m1));
-               //print(m0); print(" ^ "); print(m1); print(" = "); print(m0 ^ m1); print("\n");
-               EXPECT_EQ(uint32_t(m0 ^ m1), uint32_t(m0) ^ uint32_t(m1));
-             }
-           }
+           clang14_workaround(S0, F0, S1, F1);
          }
        });
      });
    });
  });
}
+
+TEST(CuTe_core, MakeIntegerSequence) {
+  cute::for_each(cute::make_integer_sequence<uint32_t, 8>{}, [](auto c) {
+    using c_type = decltype(c);
+    constexpr auto c_value = c_type::value;
+    using expected_type = cute::integral_constant<uint32_t, c_value>;
+    static_assert(cute::is_same_v<c_type, expected_type>);
+  });
+}
+
+#endif // !defined(__GNUC__) || (__GNUC__ >= 8)

@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without

@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without

@ -32,6 +32,7 @@
|
|||||||
\brief Tests that the stream-K scheduler covers the entire problem space.
|
\brief Tests that the stream-K scheduler covers the entire problem space.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cluster_launch.hpp"
|
||||||
#include "cutlass/kernel_hardware_info.hpp"
|
#include "cutlass/kernel_hardware_info.hpp"
|
||||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp"
|
#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp"
|
||||||
#include "cutlass/util/device_memory.h"
|
#include "cutlass/util/device_memory.h"
|
||||||
@ -39,6 +40,10 @@
|
|||||||
|
|
||||||
#include "../../common/cutlass_unit_test.h"
|
#include "../../common/cutlass_unit_test.h"
|
||||||
|
|
||||||
|
// Grids are launched with clusters enabled in these tests,
|
||||||
|
// so the CTK version must support cluster launching.
|
||||||
|
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
|
||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
using ProblemShape_MNKL = Shape<int, int, int, int>;
|
using ProblemShape_MNKL = Shape<int, int, int, int>;
|
||||||
|
|
||||||
@ -60,7 +65,7 @@ run_scheduler(int* visit_counters, typename Scheduler::Params params, TileShape
|
|||||||
while (work_tile_info.is_valid_tile) {
|
while (work_tile_info.is_valid_tile) {
|
||||||
// Increment counters to indicate coverage
|
// Increment counters to indicate coverage
|
||||||
auto tile_idx = Scheduler::output_tile_index(params, work_tile_info);
|
auto tile_idx = Scheduler::output_tile_index(params, work_tile_info);
|
||||||
auto offset = tile_idx * params.k_iter_per_tile_ + work_tile_info.K_idx;
|
auto offset = tile_idx * params.k_tiles_per_output_tile_ + work_tile_info.K_idx;
|
||||||
for (auto i = 0; i < work_tile_info.k_tile_count; ++i) {
|
for (auto i = 0; i < work_tile_info.k_tile_count; ++i) {
|
||||||
// Use atomicAdd because the visit counters are shared by multiple thread blocks.
|
// Use atomicAdd because the visit counters are shared by multiple thread blocks.
|
||||||
// While having more than one block increment the same counter indicates failure,
|
// While having more than one block increment the same counter indicates failure,
|
||||||
@ -103,7 +108,7 @@ test_scheduler(
|
|||||||
|
|
||||||
// Allocate counters indicating the number of times each k iteration of each output tile has been visited
|
// Allocate counters indicating the number of times each k iteration of each output tile has been visited
|
||||||
auto [blk_m, blk_n, blk_l] = Scheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
|
auto [blk_m, blk_n, blk_l] = Scheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
|
||||||
auto total_counters = blk_m * blk_n * blk_l * params.k_iter_per_tile_;
|
auto total_counters = blk_m * blk_n * blk_l * params.k_tiles_per_output_tile_;
|
||||||
cutlass::DeviceAllocation<int> visit_counters(total_counters);
|
cutlass::DeviceAllocation<int> visit_counters(total_counters);
|
||||||
|
|
||||||
// Initialize counters to zero
|
// Initialize counters to zero
|
||||||
@@ -118,12 +123,55 @@ test_scheduler(
 // Set up the grid for the problem
 dim3 grid = Scheduler::get_grid_shape(problem_shape_mnkl, tile_shape, cluster_shape, hw_info, args);
 
+// Set up cluster and cluster launch. This is needed even for this simple kernel because
+// the SM90 scheduler needs to be able to query the CTA id within a cluster, which requires
+// explicitly launching with clusters.
+dim3 cluster{
+static_cast<uint32_t>(cute::get<0>(ClusterShape{})),
+static_cast<uint32_t>(cute::get<1>(ClusterShape{})),
+static_cast<uint32_t>(cute::get<2>(ClusterShape{}))
+};
+
+cudaLaunchConfig_t launch_config;
+launch_config.gridDim = grid;
+launch_config.blockDim = {1, 1, 1};
+launch_config.dynamicSmemBytes = 0;
+launch_config.stream = NULL;
+
+cudaLaunchAttribute launch_attribute[1];
+launch_attribute[0].id = cudaLaunchAttributeClusterDimension;
+launch_attribute[0].val.clusterDim.x = cluster.x;
+launch_attribute[0].val.clusterDim.y = cluster.y;
+launch_attribute[0].val.clusterDim.z = cluster.z;
+
+launch_config.attrs = launch_attribute;
+launch_config.numAttrs = 1;
+
+void const* kernel = (void const*) run_scheduler<Scheduler, TileShape, ClusterShape>;
+int* counters_ptr = visit_counters.get();
+void* kernel_params[] = {
+&counters_ptr,
+&params,
+&tile_shape,
+&cluster_shape,
+&problem_shape_mnkl
+};
+
 // Run the scheduler to completion and log visits to each k iteration
-run_scheduler<Scheduler, TileShape, ClusterShape><<<grid, 1>>>(
-visit_counters.get(), params, tile_shape, cluster_shape, problem_shape_mnkl);
+err = cudaLaunchKernelExC(&launch_config, kernel, kernel_params);
+if (err != cudaSuccess) {
+std::cerr << __FILE__ << ":" << __LINE__
+<< " cudaLaunchKernelExC failed with error: "
+<< cudaGetErrorString(err) << std::endl;
+return false;
+}
+
 err = cudaDeviceSynchronize();
 if (err != cudaSuccess) {
-std::cerr << __FILE__ << ":" << __LINE__ << " scheduler kernel failed with error: " << cudaGetErrorString(err) << std::endl;
+std::cerr << __FILE__ << ":" << __LINE__
+<< " scheduler kernel failed with error: "
+<< cudaGetErrorString(err) << std::endl;
 return false;
 }
 
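The hunk above replaces the plain `<<<grid, 1>>>` launch with an explicit cluster launch, since the SM90 scheduler queries the CTA's rank within its cluster. A minimal standalone sketch of the same `cudaLaunchKernelExC` pattern follows; it assumes CUDA 12.x, compilation for sm_90, and illustrative grid/cluster shapes, and the kernel itself is a placeholder.

#include <cstdio>
#include <cuda_runtime.h>

// Trivial kernel: the point of this sketch is the host-side cluster launch, not the device code.
__global__ void noop_kernel(int* flag) {
  if (blockIdx.x == 0 && threadIdx.x == 0) { *flag = 1; }
}

int main() {
  int* flag = nullptr;
  cudaMalloc(&flag, sizeof(int));
  cudaMemset(flag, 0, sizeof(int));

  // Grid of 8 CTAs, grouped into clusters of 2x1x1 (shapes chosen for illustration).
  dim3 grid{8, 1, 1};
  dim3 cluster{2, 1, 1};

  cudaLaunchConfig_t config{};
  config.gridDim = grid;
  config.blockDim = {1, 1, 1};
  config.dynamicSmemBytes = 0;
  config.stream = nullptr;

  // Attach the cluster dimension as a launch attribute.
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeClusterDimension;
  attrs[0].val.clusterDim.x = cluster.x;
  attrs[0].val.clusterDim.y = cluster.y;
  attrs[0].val.clusterDim.z = cluster.z;
  config.attrs = attrs;
  config.numAttrs = 1;

  void* args[] = { &flag };
  cudaError_t err = cudaLaunchKernelExC(&config, (void const*) noop_kernel, args);
  if (err != cudaSuccess) {
    std::printf("launch failed: %s\n", cudaGetErrorString(err));
    return 1;
  }
  cudaDeviceSynchronize();
  cudaFree(flag);
  return 0;
}

Note that each cluster dimension must evenly divide the corresponding grid dimension for the launch to succeed.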
@@ -143,11 +191,11 @@ test_scheduler(
 << " and grid size " << grid.x << "x"
 << grid.y << "x" << grid.z
 << " splits=" << params.splits_
-<< " k_iter=" << params.k_iter_per_tile_
+<< " k_iter=" << params.k_tiles_per_output_tile_
 << " big_units=" << params.big_units_
 << " sk_tiles=" << params.sk_tiles_
 << " sk_units=" << params.sk_units_
-<< " k_iter_per_sk_unit=" << params.k_iter_per_sk_unit_ << std::endl;
+<< " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_ << std::endl;
 std::cout << "Error at idx: " << i << ". Got count " << host_visit_counts[i] << std::endl;
 return false;
 }
@@ -274,4 +322,6 @@ TEST(SM90_Device_Gemm_stream_k_scheduler, 128x128x64_2x1x1) {
 EXPECT_TRUE(test_scheduler({128, 512, 2048, 1}, tile_shape, cluster_shape, 114));
 }
 
+#endif // defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
+
 /////////////////////////////////////////////////////////////////////////////////////////////////

@@ -179,7 +179,7 @@ class GemmOperation:
 ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
 opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
 if self.arch >= 90:
-kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}{t}"
+kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{t}{k}{e}"
 return kernel_name_template.format(
 p = self.prefix,
 ar = self.arch,
@@ -194,9 +194,9 @@ class GemmOperation:
 l = self.tile_description.stages,
 s = self.layout_name_3x(),
 al = str(max(self.A.alignment, self.B.alignment)),
+t = TileSchedulerSuffixes[self.tile_scheduler],
 k = self.kernel_schedule_name_3x(),
-e = self.epilogue_schedule_name_3x(),
-t = TileSchedulerSuffixes[self.tile_scheduler])
+e = self.epilogue_schedule_name_3x())
 else:
 threadblock = self.tile_description.procedural_name()
 return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format(
@@ -661,8 +661,7 @@ using ${operation_name}_mainloop =
 ${element_accumulator},
 cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
 cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
-cutlass::gemm::collective::StageCountAutoCarveout<
-sizeof(typename ${operation_name}_epilogue::SharedStorage)>,
+${stages},
 ${kernel_schedule}
 >::CollectiveOp;
 
@@ -697,7 +696,7 @@ ${compile_guard_end}
 if operation.tile_description.stages > 0:
 stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>"
 else:
-stage_count_string = "cutlass::gemm::collective::StageCountAuto"
+stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage)>"
 warp_shape = [tile_shape[idx] // warp_count[idx] for idx in range(3)]
 
 instance_layout_A, instance_layout_B, instance_layout_C , instance_layout_D = \
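Taken together, the two generator changes above mean the emitted mainloop template now receives `${stages}`, and when the stage count is automatic that substitution becomes `StageCountAutoCarveout<sizeof(typename ..._epilogue::SharedStorage)>` instead of `StageCountAuto`, so the builder reserves the epilogue's shared memory before sizing the mainloop pipeline. The compile-only sketch below shows the pattern the generated C++ follows; it mirrors CUTLASS's SM90 builder examples rather than any specific generated instance, and the element types, alignments, and tile/cluster shapes are placeholders.

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

// Placeholder types and alignments, chosen only for illustration.
using ElementA = cutlass::half_t;  using LayoutA = cutlass::layout::RowMajor;
using ElementB = cutlass::half_t;  using LayoutB = cutlass::layout::ColumnMajor;
using ElementC = float;            using LayoutC = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;
using TileShape    = cute::Shape<cute::_128, cute::_128, cute::_64>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;

// Build the epilogue first so its SharedStorage size is known.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, LayoutC, 4,
    ElementC, LayoutC, 4,
    cutlass::epilogue::collective::EpilogueScheduleAuto
  >::CollectiveOp;

// The mainloop carves the epilogue's shared memory out of its stage budget; this is
// the expression the generator now emits for stage_count_string in the automatic case.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    ElementA, LayoutA, 8,
    ElementB, LayoutB, 8,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    cutlass::gemm::collective::KernelScheduleAuto
  >::CollectiveOp;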
@@ -4218,6 +4218,8 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
 layout[2][1] = 8
 
 CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, schedules)
+CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK])
+
 # persistent kernels with TMA epilogues
 if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
 CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed,
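The added `CreateGemmUniversal3xOperator(..., stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK])` call registers stream-K variants of the same tile descriptions. On the C++ side this corresponds to choosing the stream-K tile scheduler at the kernel level; the short sketch below continues the collective types from the previous example and assumes `cutlass::gemm::StreamKScheduler` as the scheduler tag used by CUTLASS 3.x kernels.

#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

// Continuing the CollectiveMainloop / CollectiveEpilogue sketch above.
// The fourth template argument requests the SM90 stream-K tile scheduler.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int, int, int, int>,   // problem shape (M, N, K, L)
    CollectiveMainloop,
    CollectiveEpilogue,
    cutlass::gemm::StreamKScheduler>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;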