diff --git a/PUBLICATIONS.md b/PUBLICATIONS.md index 7fc2e359..07324a0a 100644 --- a/PUBLICATIONS.md +++ b/PUBLICATIONS.md @@ -2,10 +2,14 @@ ## 2023 -- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, and Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023. +- ["FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"](https://arxiv.org/abs/2307.08691). Tri Dao. _Technical Report_, July 2023. - ["ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs"](https://arxiv.org/abs/2210.03052). Yujia Zhai, Chengquan Jiang, Leyuan Wang, Xiaoying Jia, Shang Zhang, Zizhong Chen, Xin Liu, Yibo Zhu. _Proceedings of the 37th IEEE International Parallel & Distributed Processing Symposium (Best Paper)_, May 2023. +- ["A Framework for Fine-Grained Synchronization of Dependent GPU Kernels"](https://arxiv.org/abs/2305.13450). Abhinav Jangda, Saeed Maleki, Maryam Mehri Dehnavi, Madan Musuvathi, Olli Saarikivi. _Computing Research Repository_, May 2023. + +- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023. + - ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023. ## 2022 diff --git a/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp b/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp index e465d43f..d338a31f 100644 --- a/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp +++ b/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp @@ -71,7 +71,7 @@ struct Options { /// Prints the usage statement. 
std::ostream & print_usage(std::ostream &out) const { - out << "52_fp8_hopper_warp_specialized_gemm\n\n" + out << "54_fp8_hopper_warp_specialized_gemm\n\n" << " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement\n\n" @@ -93,7 +93,7 @@ struct Options { out << "\n\nExamples:\n\n" - << "$ " << "52_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + << "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; return out; } diff --git a/include/cutlass/barrier.h b/include/cutlass/barrier.h index a8a26c50..52500482 100644 --- a/include/cutlass/barrier.h +++ b/include/cutlass/barrier.h @@ -54,6 +54,13 @@ struct SyncthreadsSync { } }; +struct SyncwarpSync { + CUTLASS_DEVICE + static void sync() { + __syncwarp(); + } +}; + template < int ThreadCount, int BarrierId @@ -311,6 +318,60 @@ private: } }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/** Structure for synchronizing via contiguous barriers (e.g., __syncwarp, __syncthreads) + * via an API that mirrors that of NamedBarrierManager + * + * @param Synchronizer Synchronization helper exposing a `sync()` method to perform synchronization +**/ +template < + class Synchronizer, + uint32_t ThreadCount_ +> +struct SyncManager { + + // Number of threads participating in the barrier + static constexpr uint32_t ThreadCount = ThreadCount_; + + using BarrierSync = cutlass::GenericBarrier; + + // Underlying type used by all barriers for synchronization. + using T = typename BarrierSync::T; + + CUTLASS_DEVICE + static + void wait_lt(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int count) { + BarrierSync::wait_lt_helper(lock_ptr, thread_idx, flag_idx, count); + } + + CUTLASS_DEVICE + static void + wait_eq(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val); + } + + CUTLASS_DEVICE + static void + wait_eq_reset(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val); + } + + CUTLASS_DEVICE + static void + arrive_inc(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) { + BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val); + } + + CUTLASS_DEVICE + static void + arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) { + BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index e94a025a..072da899 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -67,7 +67,7 @@ sm90_get_tma_dispatch_policy() { constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{})); constexpr int FragmentSize = size(EpilogueTileMN{}) / (detail::sm90_is_cooperative_v ? 
256 : 128); - constexpr int ReuseSmemC = sizeof_bits_v == sizeof_bits_v; + constexpr int ReuseSmemC = (sizeof_bits_v == sizeof_bits_v) && (sizeof_bits_v > 8); constexpr int StagesD = 2; constexpr int StagesC = ReuseSmemC ? cute::max(EpiTiles, StagesD + 1) : EpiTiles; @@ -98,7 +98,7 @@ sm90_get_epilogue_smem_swizzle_layout_atom() { } // Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one. -template +template constexpr auto sm90_compute_tile_shape_or_override() { if constexpr (cute::is_same_v) { @@ -107,7 +107,12 @@ sm90_compute_tile_shape_or_override() { return Shape<_128,_32>{}; } else if constexpr (detail::sm90_is_warp_specialized_v) { - return Shape<_64,_32>{}; + if constexpr (sizeof_bits_v == 8) { + return Shape<_64,_64>{}; + } + else { + return Shape<_64,_32>{}; + } } else { static_assert(cutlass::detail::dependent_false, "Unsupported schedule."); diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp index 3b78f15b..d8149692 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp @@ -34,6 +34,7 @@ #include "cutlass/kernel_hardware_info.hpp" #include "cute/layout.hpp" #include "cute/tensor.hpp" +#include "cute/arch/cluster_sm90.hpp" namespace cutlass::gemm::kernel::detail { @@ -205,18 +206,14 @@ public: uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0; divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim); - // MSVC requires protecting use of CUDA-specific nonstandard syntax, - // like blockIdx and gridDim, with __CUDA_ARCH__. -#if defined(__CUDA_ARCH__) + + auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); if (raster_order == RasterOrder::AlongN) { - cluster_minor_offset = blockIdx.x; + cluster_minor_offset = cta_m_in_cluster; } else { - cluster_minor_offset = blockIdx.y; + cluster_minor_offset = cta_n_in_cluster; } -#else - CUTLASS_ASSERT(false && "This line should never be reached"); -#endif uint64_t cluster_idx_minor, cluster_idx_major; diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp index 4049255e..ff366c32 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp @@ -141,7 +141,7 @@ public: uint32_t splits_ = 1; // Number of tiled k iterations required to compute a single output tile. - uint32_t k_iter_per_tile_ = 0; + uint32_t k_tiles_per_output_tile_ = 0; // Number of stream-K or split-K work units that compute an extra k iteration. // This is done to handle residuals in dividing up the k iteration space. @@ -160,7 +160,7 @@ public: // Number of tiled k iterations computed by each stream-K work unit. This // can potentially cover more than one output tile. 
- uint32_t k_iter_per_sk_unit_ = 0; + uint32_t k_tiles_per_sk_unit_ = 0; }; // Sink scheduler params as a member @@ -189,9 +189,9 @@ public: uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l; - // Number of k iterations each tile computes (this is just the number of k iterations - // in the problem's K dimension) - uint32_t k_iter_per_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(tile_shape) - 1) / cute::size<2>(tile_shape); + // Number of k tile iterations in each output tile + uint32_t k_tiles_per_output_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(tile_shape) - 1) / + cute::size<2>(tile_shape); UnderlyingArguments underlying_args; underlying_args.max_swizzle_size = 1; @@ -216,11 +216,11 @@ public: // splits is almost certainly nonnegative here (e.g., hw_info.sm_count, // despite being an int, is a count), so it can safely be converted to unsigned // in the comparison to avoid a signed-unsigned comparison warning-as-error. - splits = static_cast(splits) > k_iter_per_tile ? k_iter_per_tile : splits; + splits = static_cast(splits) > k_tiles_per_output_tile ? k_tiles_per_output_tile : splits; return get_params_basic( underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape, - splits, k_iter_per_tile, reduction_workspace); + splits, k_tiles_per_output_tile, reduction_workspace); } // Calculate the maximum number of blocks from clusters of shape cluster_shape that we @@ -229,7 +229,7 @@ public: uint64_t ctas_per_wave = grid.x * grid.y; // The number of output tiles to be computed in stream-K and data-parallel fashion, respectively. - uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave); + uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, k_tiles_per_output_tile); uint64_t dp_tiles = output_tiles - sk_tiles; // Calculate the number of work units covering the data-parallel and stream-K tiles. @@ -243,7 +243,7 @@ public: uint64_t dp_units = dp_tiles; // Number of k iterations computed by the stream-K units as a whole - uint64_t k_iter_sk_total = k_iter_per_tile * sk_tiles; + uint64_t k_tiles_sk_total = k_tiles_per_output_tile * sk_tiles; // If there are stream-K tiles to compute and a sufficiently large number of k iterations // across them, they will be covered by a single wave of persistent threadblocks. Thus, there @@ -255,7 +255,7 @@ public: // Calculate the number of stream-K units that would be needed if each stream-K unit // computed the minimum allowable k iterations. Truncate this to be in units of clusters. 
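For reference, the stream-K unit sizing performed in the next few statements amounts to the following arithmetic. This is a minimal, self-contained host-side sketch; the example values (CTAs per wave, cluster size, tail-tile count, and the minimum k tiles per unit) are illustrative placeholders rather than values taken from CUTLASS, and the short-circuit fallbacks handled elsewhere in this function are omitted.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>

// Sketch: spread the k tiles of the "tail" output tiles across at most one wave of
// clusters, truncate the unit count to a multiple of the cluster size, and give any
// leftover k tiles to the leading "big" units (one extra k tile each).
int main() {
  uint64_t ctas_per_wave = 132;           // illustrative: one CTA per SM on a hypothetical device
  uint64_t cluster_size = 2;              // illustrative cluster size
  uint32_t k_tiles_per_output_tile = 64;  // e.g., K = 4096 with a 64-element k tile
  uint32_t sk_tiles = 60;                 // tail tiles left over after the full waves
  uint64_t min_k_tiles_per_sk_unit = 4;   // placeholder for the scheduler's minimum, not its real value

  uint64_t k_tiles_sk_total = uint64_t(k_tiles_per_output_tile) * sk_tiles;

  // Units needed if every unit did only the minimum work, truncated to whole clusters
  uint64_t min_sized_sk_units = k_tiles_sk_total / min_k_tiles_per_sk_unit;
  min_sized_sk_units = (min_sized_sk_units / cluster_size) * cluster_size;

  uint64_t sk_units = std::min(ctas_per_wave, min_sized_sk_units);

  // Evenly divide the k tiles across the units; the remainder goes to "big" units
  uint64_t k_tiles_per_sk_unit = k_tiles_sk_total / sk_units;
  uint64_t sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units);

  // Every k tile is covered exactly once
  assert(sk_big_units * (k_tiles_per_sk_unit + 1) +
         (sk_units - sk_big_units) * k_tiles_per_sk_unit == k_tiles_sk_total);

  std::printf("units=%llu, k_tiles/unit=%llu, big units=%llu\n",
              (unsigned long long)sk_units,
              (unsigned long long)k_tiles_per_sk_unit,
              (unsigned long long)sk_big_units);
  return 0;
}

With these example numbers the sketch yields 132 units of 29 k tiles each, of which 12 "big" units compute 30; note that 12 is a multiple of the cluster size, matching the guarantee discussed in the comments below.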
- uint64_t min_sized_sk_units = (k_iter_sk_total / min_iters_per_sk_unit_); + uint64_t min_sized_sk_units = (k_tiles_sk_total / min_iters_per_sk_unit_); min_sized_sk_units = (min_sized_sk_units / cute::size(cluster_shape)) * cute::size(cluster_shape); uint64_t sk_units = min(ctas_per_wave, min_sized_sk_units); @@ -264,7 +264,7 @@ public: // Short circuit to basic data-parallel decomposition return get_params_basic( underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape, - 1, k_iter_per_tile, reduction_workspace); + 1, k_tiles_per_output_tile, reduction_workspace); } // If the number of stream-K units is a multiple of the number of stream-K tiles, then @@ -274,24 +274,24 @@ public: uint32_t sk_splits = static_cast(sk_units / sk_tiles); return get_params_basic( underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape, - sk_splits, k_iter_per_tile, reduction_workspace); + sk_splits, k_tiles_per_output_tile, reduction_workspace); } // Number of k iterations computed per stream-K units - uint64_t k_iter_per_sk_unit = k_iter_sk_total / sk_units; + uint64_t k_tiles_per_sk_unit = k_tiles_sk_total / sk_units; // Number of stream-K units that need to compute extra iterations in order to cover // the residual k iterations. This assumes that each such unit computes one additional // iteration. - uint64_t sk_big_units = k_iter_sk_total - (k_iter_per_sk_unit * sk_units); + uint64_t sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units); // The division below is guaranteed to be exact because sk_big_units is guaranteed // to be a multiple of cluster_size (cute::size(cluster_shape)). This is useful because // it allows us to use a block's linearized cluster ID to determine whether it is // a big block. The reasoning behind this guarnatee is explained as follows: - // sk_big_units = k_iter_sk_total - (k_iter_per_sk_unit * sk_units); + // sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units); // - // - k_iter_sk_total is a multiple of cluster_size because it is the product + // - k_tiles_sk_total is a multiple of cluster_size because it is the product // of number of tail tiles and the number of k iterations per tile. Because // both the number of output tiles and number of available SMs are rounded // to be multiples of cluster shape, the number of tail tiles @@ -313,12 +313,12 @@ public: underlying_params.raster_order_, cluster_shape, 1, // Static k-splitting factor. Unused for stream-K. - k_iter_per_tile, + k_tiles_per_output_tile, static_cast(sk_big_units_per_cluster), reduction_workspace, sk_tiles, static_cast(sk_units), - static_cast(k_iter_per_sk_unit) + static_cast(k_tiles_per_sk_unit) }; } @@ -338,105 +338,32 @@ public: CUTLASS_DEVICE WorkTileInfo get_current_work() const { - return get_current_work_for_linear_idx(current_work_linear_idx_); + return get_current_work_for_linear_idx(current_work_linear_idx_, scheduler_params); } CUTLASS_DEVICE - WorkTileInfo - get_current_work_for_linear_idx(uint64_t linear_idx) const { - if (linear_idx >= scheduler_params.units_per_problem_) { + static WorkTileInfo + get_current_work_for_linear_idx(uint64_t linear_idx, Params const& params) { + if (linear_idx >= params.units_per_problem_) { // Invalid work. Return an empty result. 
return {0, 0, 0, 0, false, 0}; } // Determine whether this work unit is a data-parallel or stream-K work unit - bool is_stream_k_unit = linear_idx < scheduler_params.sk_units_; + bool is_stream_k_unit = linear_idx < params.sk_units_; - bool is_split_k = scheduler_params.splits_ > 1; + bool is_split_k = params.splits_ > 1; - // Bypass the stream-K scheduling logic for basic data-parallel or split-K work if (is_split_k || !is_stream_k_unit) { - // The linearized ID space is in terms of work units, rather than tiles. However, - // to compute the correct block offset for a data-parallel tile, we must convert - // the current ID to the data-parallel tile it corresponds to. Each data-parallel - // unit maps to a single data-parallel tile, but each stream-K unit can map to more - // than one tile. Thus, we must offset the work-unit ID among the data-parallel units - // by the total number of output tiles that will be computed by stream-K units. - // - // The logic below also works for the split-K case, in which sk_units_ and sk_tiles_ - // are each 0. - uint64_t linear_work_idx = linear_idx - scheduler_params.sk_units_ + scheduler_params.sk_tiles_; - - // Map worker's linear index into the CTA-tiled problem shape to the corresponding MNL indices - uint64_t work_idx_l, remainder; - scheduler_params.divmod_batch_(work_idx_l, remainder, linear_work_idx); - - uint64_t work_idx_k = 0; - if (is_split_k) { - scheduler_params.divmod_k_(work_idx_k, remainder, remainder); - } - - uint64_t cta_per_grid_dim, dontcare; - scheduler_params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder); - - auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n( - cta_per_grid_dim, - scheduler_params.divmod_cluster_shape_major_, - scheduler_params.divmod_cluster_shape_minor_, - scheduler_params.divmod_cluster_blk_major_, - scheduler_params.log_swizzle_size_, - scheduler_params.raster_order_); - - bool is_final_split = (work_idx_k == scheduler_params.splits_ - 1); - - uint32_t k_iter = scheduler_params.k_iter_per_tile_; - if (is_split_k) { - // Determine the number of iterations and starting iteration of this split. - // Doing so requires accounting for residual iterations, which are handled - // by the first big_units_ splits (with big_units_ = tiles % sm_count). - - // Offsets for "normal" units. No additional k iterations are performed, - // and big_units_ "big" units preceded us, each of which performed one - // additional iteration. Thus, we must increase our split starting offset - // by big_units_. - int additional_k_iter = 0; - int split_start_offset = scheduler_params.big_units_; - - if (work_idx_k < scheduler_params.big_units_) { - // Offsets for "big" units. One additional k iteration is performed, - // and each split preceding us was a big unit, so we must increase - // our split starting offset by our split ID (work_idx_k). - additional_k_iter = 1; - split_start_offset = work_idx_k; - } - - // Set up k iteration count and split starting iteration assuming the - // iteration space is evenly split. 
- k_iter /= scheduler_params.splits_; - work_idx_k *= k_iter; - - // Apply any fixup needed to handle residuals - work_idx_k += split_start_offset; - k_iter += additional_k_iter; - } - - return { - work_idx_m, - work_idx_n, - static_cast(work_idx_k), - static_cast(work_idx_l), - true, - scheduler_params.k_iter_per_tile_, - k_iter, - k_iter, // remaining iterations - is_final_split - }; + // Bypass the stream-K scheduling logic for basic data-parallel or split-K work + return set_non_stream_k_work(linear_idx, params, is_split_k); + } + else { + // This is a stream-K work unit + WorkTileInfo work_tile_info; + set_stream_k_work(params, linear_idx, work_tile_info, /*new_unit = */ true); + return work_tile_info; } - - // This is a stream-K work unit - WorkTileInfo work_tile_info; - set_stream_k_work(linear_idx, work_tile_info, /*new_unit = */ true); - return work_tile_info; } // Returns whether the current work_tile_info passed in should continue to be used. This @@ -446,13 +373,24 @@ public: CUTLASS_DEVICE bool continue_current_work(WorkTileInfo& work_tile_info) const { + return continue_current_work_for_linear_idx( + current_work_linear_idx_, work_tile_info, scheduler_params); + } + + CUTLASS_DEVICE static + bool + continue_current_work_for_linear_idx( + uint64_t linear_idx, + WorkTileInfo& work_tile_info, + Params const& params) { + work_tile_info.k_tile_remaining -= work_tile_info.k_tile_count; if (work_tile_info.k_tile_remaining == 0) { return false; } - set_stream_k_work(current_work_linear_idx_, work_tile_info, /* new_unit = */ false); + set_stream_k_work(params, linear_idx, work_tile_info, /* new_unit = */ false); return true; } @@ -495,6 +433,14 @@ public: /*truncate_by_problem_size=*/false); } + // Returns whether fixup is needed for `work_tile_info`. + CUTLASS_HOST_DEVICE + static bool + requires_fixup(Params const& params, WorkTileInfo const& work_tile_info) { + // Fixup is not needed for data-parallel tiles + return work_tile_info.k_tile_count != params.k_tiles_per_output_tile_; + } + // Performs the reduction across splits for a given output tile. template CUTLASS_DEVICE @@ -505,13 +451,25 @@ public: FrgTensorC& accumulators, uint32_t num_barriers, uint32_t barrier_idx) { + using BarrierManager = NamedBarrierManager; + return fixup_helper( + params, work_tile_info, accumulators, num_barriers, barrier_idx); + } + + // Helper for performing the reduction across splits for a given output tile. 
+ template + CUTLASS_DEVICE + static void + fixup_helper( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx) { using ElementAccumulator = typename FrgTensorC::value_type; - using BarrierManager = NamedBarrierManager; - - if (work_tile_info.k_tile_count == params.k_iter_per_tile_) { - // Fixup is not needed for data-parallel tiles + if (!requires_fixup(params, work_tile_info)) { return; } @@ -619,21 +577,23 @@ public: } } else { + auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); + uint64_t cta_per_grid_dim; uint64_t cluster_dim_idx; if (params.raster_order_ == RasterOrder::AlongN) { - uint64_t block_idx_m = (work_tile_info.M_idx - blockIdx.x) / gridDim.x; + uint64_t block_idx_m = (work_tile_info.M_idx - cta_m_in_cluster) / cute::size<0>(params.cluster_shape_); uint64_t block_idx_n = work_tile_info.N_idx; cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor * params.divmod_cluster_blk_major_.divisor * block_idx_m) + block_idx_n; - cluster_dim_idx = blockIdx.x; + cluster_dim_idx = cta_m_in_cluster; } else { uint64_t block_idx_m = work_tile_info.M_idx; - uint64_t block_idx_n = (work_tile_info.N_idx - blockIdx.y) / gridDim.y; + uint64_t block_idx_n = (work_tile_info.N_idx - cta_n_in_cluster) / cute::size<1>(params.cluster_shape_); cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor * params.divmod_cluster_blk_major_.divisor * block_idx_n) + block_idx_m; - cluster_dim_idx = blockIdx.y; + cluster_dim_idx = cta_n_in_cluster; } uint64_t tile_in_batch = params.divmod_cluster_shape_minor_.divisor * cta_per_grid_dim; @@ -646,7 +606,7 @@ public: get_workspace_size( Arguments const& args, ProblemShape problem_shape, - KernelHardwareInfo const& hw_info, + KernelHardwareInfo const& hw_info, uint32_t mma_warp_groups) { int barrier_workspace_size = 0; @@ -715,7 +675,7 @@ private: // Construct a layout for the indexed tensor. The main purpose of this new layout is to // override the k extent to support cases in which the split computes a number of iterations - // not equal to total_tile_k_iter / splits. A common example of this is in stream-K is when a + // not equal to total_k_tiles / splits. A common example of this is in stream-K is when a // unit computes the final 20 of the total 32 k iterations of the output tile. In this case, // set splits = 32 and the split index (K_idx) to 11. The zipped divide above results in each // of the splits computing only one k iteration. @@ -728,12 +688,13 @@ private: // Returns the number of stream-K tiles that will be computed amongst `output_tiles` total // output tiles on a device with `ctas_per_wave` CTAs in each wave. static uint32_t - get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave) { + get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave, uint32_t k_tiles_per_output_tile) { uint32_t full_waves = static_cast(output_tiles / ctas_per_wave); uint32_t total_waves = static_cast((output_tiles + ctas_per_wave - 1) / ctas_per_wave); - if (full_waves == total_waves) { - // No quantization. All tiles will be data-parallel tiles. + if (full_waves == total_waves || k_tiles_per_output_tile == 1) { + // All tiles will be data-parallel tiles if there is either no quantization + // or if there is no work to be split. 
return 0; } @@ -811,9 +772,12 @@ private: sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); } + uint32_t k_tiles_per_output_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(TileShape{}) - 1) / + cute::size<2>(TileShape{}); + dim3 grid = get_grid_shape(problem_shape_mnkl, TileShape{}, cluster_shape, {0, sm_count}, args); uint64_t ctas_per_wave = grid.x * grid.y; - uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave); + uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, k_tiles_per_output_tile); barrier_workspace_size = get_barrier_workspace_size(sk_tiles, mma_warp_groups); reduction_workspace_size = get_reduction_workspace_size(sk_tiles); @@ -829,10 +793,10 @@ private: uint32_t blocks_l, ClusterShape cluster_shape, uint32_t splits, - uint32_t k_iter_per_tile, + uint32_t k_tiles_per_output_tile, void* reduction_workspace) { - uint32_t big_units = k_iter_per_tile % splits; + uint32_t big_units = k_tiles_per_output_tile % splits; return { underlying_params.divmod_cluster_shape_major_, @@ -845,7 +809,7 @@ private: underlying_params.raster_order_, cluster_shape, splits, - k_iter_per_tile, + k_tiles_per_output_tile, big_units, reduction_workspace }; @@ -855,8 +819,12 @@ private: // is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining // iterations) is used to find the next tile in the current work unit. CUTLASS_DEVICE - void - set_stream_k_work(uint64_t linear_idx, WorkTileInfo& work_tile_info, bool new_unit) const { + static void + set_stream_k_work( + Params const& params, + uint64_t linear_idx, + WorkTileInfo& work_tile_info, + bool new_unit) { // In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K // threadblock individually. For the most part, the set of K iterations corresponding to stream-K // work was divided amongst stream-K threadblocks, and a threadblock determined which tile @@ -872,15 +840,15 @@ private: // // To do so, we divide up the linearized stream-K units into clusters and share the same K // offsets for work within clusters. - auto cluster_linear_work_idx = linear_idx / size(scheduler_params.cluster_shape_); + auto cluster_linear_work_idx = linear_idx / size(params.cluster_shape_); // Determine the starting k iteration computed by this stream-K work unit - uint32_t unit_iter_start = scheduler_params.k_iter_per_sk_unit_ * cluster_linear_work_idx; + uint32_t unit_iter_start = params.k_tiles_per_sk_unit_ * cluster_linear_work_idx; // Adjust the starting position and number of k iterations for "big units," which // compute one extra iteration. These are the first big_units_ units in the // linearized ID space. - bool is_big_unit = cluster_linear_work_idx < scheduler_params.big_units_; + bool is_big_unit = cluster_linear_work_idx < params.big_units_; if (is_big_unit) { // Since the "big units" are the first units in the linearized ID space, each // of the units preceding this big unit computed one extra iteration. Thus, @@ -889,16 +857,16 @@ private: unit_iter_start += cluster_linear_work_idx; } else { // Increment by one for each of the big clusters (since all big units precede this unit) - unit_iter_start += scheduler_params.big_units_; + unit_iter_start += params.big_units_; } uint32_t unit_iters; if (new_unit) { - unit_iters = scheduler_params.k_iter_per_sk_unit_; + unit_iters = params.k_tiles_per_sk_unit_; // Only adjust iteration count for big unit if we are initializing this // work unit. 
For existing work units, the extra iteration for big units - // has already been accounted for in k_iter_reamaining + // has already been accounted for in k_tiles_reamaining if (is_big_unit) { ++unit_iters; } @@ -917,22 +885,21 @@ private: // for them to be computed later, so as to reduce the likelihood of blocking // on other work. uint32_t unit_iter_end = unit_iter_start + unit_iters - 1; - uint32_t true_tile_id = unit_iter_end / scheduler_params.k_iter_per_tile_; - uint32_t true_tile_iter_start = true_tile_id * scheduler_params.k_iter_per_tile_; - uint32_t true_tile_iter_end = true_tile_iter_start + scheduler_params.k_iter_per_tile_; + uint32_t true_tile_id = unit_iter_end / params.k_tiles_per_output_tile_; + uint32_t true_tile_iter_start = true_tile_id * params.k_tiles_per_output_tile_; + uint32_t true_tile_iter_end = true_tile_iter_start + params.k_tiles_per_output_tile_; // Bring the linearized tile ID back into the space of tiles, rather than clusters - true_tile_id *= size(scheduler_params.cluster_shape_); + true_tile_id *= size(params.cluster_shape_); - auto cluster_dim0 = cute::size<0>(scheduler_params.cluster_shape_); - auto cluster_dim1 = cute::size<1>(scheduler_params.cluster_shape_); + auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); // The final linearized tile ID is in units of the cluster dimension over which we rasterize. - if (scheduler_params.raster_order_ == RasterOrder::AlongN) { - true_tile_id += (blockIdx.y % cluster_dim1) * cluster_dim0; + if (params.raster_order_ == RasterOrder::AlongN) { + true_tile_id += cta_n_in_cluster * cute::size<0>(params.cluster_shape_); } else { - true_tile_id += (blockIdx.x % cluster_dim0) * cluster_dim1; + true_tile_id += cta_m_in_cluster * cute::size<1>(params.cluster_shape_); } // The unit's starting k iteration in the current tile is either the starting @@ -948,19 +915,18 @@ private: uint32_t tile_iters = tile_iter_end - tile_iter_start; uint64_t work_idx_l, remainder; - scheduler_params.divmod_batch_(work_idx_l, remainder, true_tile_id); + params.divmod_batch_(work_idx_l, remainder, true_tile_id); uint64_t cta_per_grid_dim, dontcare; - scheduler_params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder); - + params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder); auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n( cta_per_grid_dim, - scheduler_params.divmod_cluster_shape_major_, - scheduler_params.divmod_cluster_shape_minor_, - scheduler_params.divmod_cluster_blk_major_, - scheduler_params.log_swizzle_size_, - scheduler_params.raster_order_); + params.divmod_cluster_shape_major_, + params.divmod_cluster_shape_minor_, + params.divmod_cluster_blk_major_, + params.log_swizzle_size_, + params.raster_order_); // // Update the work_tile_info @@ -971,11 +937,11 @@ private: work_tile_info.N_idx = work_idx_n; work_tile_info.L_idx = static_cast(work_idx_l); - // Set the k offset to be the starting k iteration for this tile + // Set the k offset to be the starting k tile for this output tile work_tile_info.K_idx = static_cast(tile_iter_start - true_tile_iter_start); - // Set the split count to be the number of k iterations in the tile - work_tile_info.splits = scheduler_params.k_iter_per_tile_; + // Set the split count to be the number of k tiles in the output tile + work_tile_info.splits = params.k_tiles_per_output_tile_; // Any checks for invalid work units should be done prior to this call work_tile_info.is_valid_tile = true; @@ -987,6 +953,89 @@ private: // 
the output tile in question work_tile_info.is_final_split = (tile_iter_end == true_tile_iter_end); } + + // Returns a WorkTileInfo to be computed for either the data-parallel or split-K + // work unit identified by the provided linear ID. + CUTLASS_DEVICE + static WorkTileInfo + set_non_stream_k_work(uint64_t linear_idx, Params const& params, bool is_split_k) { + + // The linearized ID space is in terms of work units, rather than tiles. However, + // to compute the correct block offset for a data-parallel tile, we must convert + // the current ID to the data-parallel tile it corresponds to. Each data-parallel + // unit maps to a single data-parallel tile, but each stream-K unit can map to more + // than one tile. Thus, we must offset the work-unit ID among the data-parallel units + // by the total number of output tiles that will be computed by stream-K units. + // + // The logic below also works for the split-K case, in which sk_units_ and sk_tiles_ + // are each 0. + uint64_t linear_work_idx = linear_idx - params.sk_units_ + params.sk_tiles_; + + // Map worker's linear index into the CTA-tiled problem shape to the corresponding MNL indices + uint64_t work_idx_l, remainder; + params.divmod_batch_(work_idx_l, remainder, linear_work_idx); + + uint64_t work_idx_k = 0; + if (is_split_k) { + params.divmod_k_(work_idx_k, remainder, remainder); + } + + uint64_t cta_per_grid_dim, dontcare; + params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder); + + auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n( + cta_per_grid_dim, + params.divmod_cluster_shape_major_, + params.divmod_cluster_shape_minor_, + params.divmod_cluster_blk_major_, + params.log_swizzle_size_, + params.raster_order_); + + bool is_final_split = (work_idx_k == params.splits_ - 1); + + uint32_t k_tiles = params.k_tiles_per_output_tile_; + if (is_split_k) { + // Determine the number of iterations and starting iteration of this split. + // Doing so requires accounting for residual iterations, which are handled + // by the first big_units_ splits (with big_units_ = tiles % sm_count). + + // Offsets for "normal" units. No additional k iterations are performed, + // and big_units_ "big" units preceded us, each of which performed one + // additional iteration. Thus, we must increase our split starting offset + // by big_units_. + int additional_k_tiles = 0; + int split_start_offset = params.big_units_; + + if (work_idx_k < params.big_units_) { + // Offsets for "big" units. One additional k iteration is performed, + // and each split preceding us was a big unit, so we must increase + // our split starting offset by our split ID (work_idx_k). + additional_k_tiles = 1; + split_start_offset = work_idx_k; + } + + // Set up k iteration count and split starting iteration assuming the + // iteration space is evenly split. 
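The split-K residual handling implemented by the statements that follow can be summarized with a small standalone sketch. The function name and example values below are illustrative only; the formulas mirror the code in this branch.

#include <cstdint>
#include <cstdio>

// Sketch: k_tiles_per_output_tile k tiles are divided into `splits` splits. The first
// (k_tiles_per_output_tile % splits) "big" splits compute one extra k tile. Returns the
// starting k tile and k-tile count for split `split_idx`.
static void split_k_assignment(uint32_t k_tiles_per_output_tile,
                               uint32_t splits,
                               uint32_t split_idx,
                               uint32_t& k_start,
                               uint32_t& k_count) {
  uint32_t big_units = k_tiles_per_output_tile % splits;
  uint32_t base = k_tiles_per_output_tile / splits;

  if (split_idx < big_units) {
    // "Big" split: every preceding split was also big, so each contributed base + 1 tiles.
    k_start = split_idx * base + split_idx;
    k_count = base + 1;
  }
  else {
    // "Normal" split: the big_units splits before us each computed one extra tile.
    k_start = split_idx * base + big_units;
    k_count = base;
  }
}

int main() {
  // Example: 10 k tiles split 4 ways -> splits of size 3, 3, 2, 2 starting at k tiles 0, 3, 6, 8.
  for (uint32_t s = 0; s < 4; ++s) {
    uint32_t start = 0, count = 0;
    split_k_assignment(10, 4, s, start, count);
    std::printf("split %u: start=%u count=%u\n", (unsigned)s, (unsigned)start, (unsigned)count);
  }
  return 0;
}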
+ k_tiles /= params.splits_; + work_idx_k *= k_tiles; + + // Apply any fixup needed to handle residuals + work_idx_k += split_start_offset; + k_tiles += additional_k_tiles; + } + + return { + work_idx_m, + work_idx_n, + static_cast(work_idx_k), + static_cast(work_idx_l), + true, + params.k_tiles_per_output_tile_, + k_tiles, + k_tiles, // remaining iterations + is_final_split + }; + } }; } // namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 7ee6e03c..5618fe05 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -28,7 +28,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -/*! +/*! \file \brief Boost-like numeric conversion operator for CUTLASS numeric types */ @@ -55,7 +55,7 @@ enum class FloatRoundStyle { round_indeterminate, ///< rounding mode unknown round_toward_zero, ///< round toward zero round_to_nearest, ///< round to nearest even - round_to_nearest_satfinite, ///< round to nearest even, capping value to min and max of destination type + round_to_nearest_satfinite, ///< round to nearest even, capping value to min and max of destination type round_toward_infinity, ///< round toward infinity round_toward_neg_infinity, ///< round toward negative infinity round_half_ulp_truncate, ///< add 0.5ulp to integer representation then round toward zero @@ -561,7 +561,7 @@ struct NumericConverter { // Note, the following is intentionally commented out. TF32 // does not define the low order bits, so they may be left in - // an undefined state. + // an undefined state. // // By not truncating these bit explicitly, we avoid an extra logical // operation. 
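The truncation the comment above refers to, i.e. explicitly zeroing the low-order bits that tfloat32_t leaves undefined, would look roughly like the following standalone sketch. It is illustrative only: it shows just the masking of the 13 low-order mantissa bits, not the rounding performed by the converter.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Sketch: zero the 13 low-order mantissa bits of an IEEE-754 binary32 value, leaving the
// 19 bits (1 sign + 8 exponent + 10 mantissa) that TF32 actually defines.
static float truncate_to_tf32(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));   // type-pun via memcpy to avoid undefined behavior
  bits &= 0xffffe000u;                    // clear the low 13 mantissa bits
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

int main() {
  float x = 0.1f;
  std::printf("%.10f -> %.10f\n", x, truncate_to_tf32(x));
  return 0;
}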
@@ -657,7 +657,7 @@ template < struct NumericConverterFastF32 { // result_type holds big tfloat32_t at idx(0) and small tfloat32_t at idx(1) - using result_type = Array; + using result_type = Array; // source data type using source_type = float; @@ -708,7 +708,7 @@ struct NumericConverterClamp { NumericConverter convert_op; result_type const kClamp_max = platform::numeric_limits::max(); result_type const kClamp_min = platform::numeric_limits::lowest(); - if (s < (source_type)kClamp_min) + if (s < (source_type)kClamp_min) return kClamp_min; if (s > (source_type)kClamp_max) return kClamp_max; @@ -848,7 +848,7 @@ struct NumericArrayConverter { result[0] = convert_(source[0]); result[1] = convert_(source[1]); #endif - + return result; } @@ -1044,7 +1044,7 @@ struct NumericArrayConverter { ///////////////////////////////////////////////////////////////////////////////////////////////// -// Conditional guards to enable partial specialization for packed integers +// Conditional guards to enable partial specialization for packed integers #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \ ((__CUDACC_VER_MAJOR__ > 10) || \ ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2))) @@ -1066,7 +1066,7 @@ struct NumericArrayConverter { result_type result; result[0] = convert_element_(source[0]); - + return result; } @@ -1189,7 +1189,7 @@ struct NumericArrayConverter { result_type result; result[0] = convert_element_(source[0]); - + return result; } diff --git a/python/cutlass/backend/gemm_operation.py b/python/cutlass/backend/gemm_operation.py index 706a1467..96fefd4b 100644 --- a/python/cutlass/backend/gemm_operation.py +++ b/python/cutlass/backend/gemm_operation.py @@ -215,10 +215,17 @@ class GemmArguments2x(ArgumentBase): else: self.batch_count = 1 - self.batched_stride_A = self.problem_size.m() * self.problem_size.k() - self.batched_stride_B = self.problem_size.n() * self.problem_size.k() - self.batched_stride_C = self.problem_size.m() * self.problem_size.n() - self.batched_stride_D = self.problem_size.m() * self.problem_size.n() + if "batch_strides" in kwargs: + self.batched_stride_A = kwargs["batch_strides"]["A"] + self.batched_stride_B = kwargs["batch_strides"]["B"] + self.batched_stride_C = kwargs["batch_strides"]["C"] + self.batched_stride_D = kwargs["batch_strides"]["D"] + else: + self.batched_stride_A = self.problem_size.m() * self.problem_size.k() + self.batched_stride_B = self.problem_size.n() * self.problem_size.k() + self.batched_stride_C = self.problem_size.m() * self.problem_size.n() + self.batched_stride_D = self.problem_size.m() * self.problem_size.n() + if self.bias: self.batched_stride_C = self.problem_size.n() diff --git a/python/cutlass/library_defaults.py b/python/cutlass/library_defaults.py index 9e70d474..31150aed 100644 --- a/python/cutlass/library_defaults.py +++ b/python/cutlass/library_defaults.py @@ -132,9 +132,9 @@ class KernelsForDataType: """ # Determine the leading dimension of the shape if layout == cutlass.LayoutType.ColumnMajor: - ld = shape[0] + ld = shape[-2] elif layout == cutlass.LayoutType.RowMajor: - ld = shape[1] + ld = shape[-1] elif layout == cutlass.LayoutType.TensorNHWC: ld = shape[-1] else: diff --git a/python/cutlass/op/gemm.py b/python/cutlass/op/gemm.py index 67d1f14e..fe8c5597 100644 --- a/python/cutlass/op/gemm.py +++ b/python/cutlass/op/gemm.py @@ -114,6 +114,8 @@ args.sync() """ +from math import prod + import cutlass_bindings import cutlass @@ -442,6 +444,113 @@ class Gemm(OperationBase): compiler.add_module([self.operation,]) return 
self.operation + def _verify_rank(self, tensor): + """ + Verifies that ``tensor`` has rank greater than 1 + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + """ + if len(tensor.shape) < 2: + raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}") + + def _get_batch_count(self, A, B, C, D) -> int: + """ + Returns the batch count specified by the tensors A, B, C, and D and verifies that these + tensors match in batch size. Presence of a batch dimension is detected by one of the + tensors being rank 3. If a batch dimension is present, it must be present in one of + operands A, B, or C (but need not be in all), and must be present in D. + + :param A: tensor A + :type A: numpy/cupy/torch array/tensor object + :param B: tensor B + :type B: numpy/cupy/torch array/tensor object + :param C: tensor C + :type C: numpy/cupy/torch array/tensor object + :param D: tensor D + :type D: numpy/cupy/torch array/tensor object + + :return: tuple of batch count dimensions + :rtype: tuple + """ + A_batch = A.shape[:-2] if len(A.shape) > 2 else tuple() + B_batch = B.shape[:-2] if len(B.shape) > 2 else tuple() + C_batch = C.shape[:-2] if len(C.shape) > 2 else tuple() + D_batch = D.shape[:-2] if len(D.shape) > 2 else tuple() + + if len(D_batch) > 0 and D_batch not in [A_batch, B_batch, C_batch]: + raise Exception(f"Batch count in D must be present in one of operands A, B, and C. " + f"Batch counts are: A={A_batch}, B={B_batch}, C={C_batch}, D={D_batch}") + + for batch_shape in [A_batch, B_batch, C_batch]: + if len(batch_shape) > 0 and batch_shape != D_batch: + raise Exception(f"Batch count for all other operands must either match that of D or be zero." + f"Received batch shape of {batch_shape}, which does not match that of D of {D_batch}.") + + return D_batch + + def _get_batch_stride(self, tensor) -> int: + """ + Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0. + + :param tensor: tensor object to process + :type tensor: numpy/cupy/torch array/tensor object + + :return: stride between each matrix in the batch + :rtype: int + """ + if len(tensor.shape) > 2: + return tensor.shape[-2] * tensor.shape[-1] + else: + return 0 + + def _get_problem_args(self, A, B, C, D) -> tuple: + """ + Returns the problem size and GEMM universal mode to use for the + given operands. + + :param A: tensor A + :type A: numpy/cupy/torch array/tensor object + :param B: tensor B + :type B: numpy/cupy/torch array/tensor object + :param C: tensor C + :type C: numpy/cupy/torch array/tensor object + :param D: tensor D + :type D: numpy/cupy/torch array/tensor object + + :return: tuple containing the problem size (cutlass_bindings.gemm.GemmCoord), the GEMM mode (cutlass_bindings.gemm.Mode), and the batch count (int) + :rtype: tuple + """ + M, K = A.shape[-2:] + N = B.shape[-1] + mode = cutlass_bindings.gemm.Mode.Gemm + + batch_count = self._get_batch_count(A, B, C, D) + returned_batch_count = prod(batch_count) if len(batch_count) > 0 else 1 + + # If we are running a batched GEMM in which there is a nonzero batch stride + # only for A, then we can fold the batched dimension of A into the M dimension + # (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A + # and C are row major. 
A similar operation can be performed if only B has a nonzero + # batch dimension + if len(batch_count) > 0: + A_row = self._layout_a == cutlass.LayoutType.RowMajor + B_row = self._layout_b == cutlass.LayoutType.RowMajor + C_row = self._layout_c == cutlass.LayoutType.RowMajor + + batched = lambda x : len(x.shape) == 2 + len(batch_count) + + if batched(A) and not batched(B) and batched(C) and A_row and C_row: + M *= prod(batch_count) + returned_batch_count = 1 + elif not batched(A) and batched(B) and batched(C) and not B_row and not C_row: + N *= prod(batch_count) + returned_batch_count = 1 + else: + mode = cutlass_bindings.gemm.Mode.Batched + + return cutlass_bindings.gemm.GemmCoord(M, N, K), mode, returned_batch_count + def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name): """ Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception @@ -461,8 +570,7 @@ class Gemm(OperationBase): f'layout of ({ref_type}, {ref_layout}).') def run(self, A=None, B=None, C=None, D=None, - alpha=None, beta=None, batch_count: int = 1, - sync: bool = True, print_module: bool = False) -> GemmArguments: + alpha=None, beta=None, sync: bool = True, print_module: bool = False) -> GemmArguments: """ Runs the kernel currently specified. If it has not already been, the kernel is emitted and compiled. Tensors holding operands and outputs of the kernel are sourced either from the @@ -481,8 +589,6 @@ class Gemm(OperationBase): :param D: tensor representing data type and layout of operand D :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B :param beta: scalar parameter beta from GEMM operation that scales operand C - :param batch_count: number of GEMMs in the batch - :type batch_count: int :param sync: whether the call should wait for the kernel to complete before returning :type sync: bool :param print_module: whether to print the emitted C++ code @@ -491,9 +597,6 @@ class Gemm(OperationBase): :return: arguments passed in to the kernel :rtype: cutlass.backend.GemmArguments """ - if batch_count < 1: - raise Exception(f"Invalid batch count {batch_count}. 
Value must be an integer >= 1.") - A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A") B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B") C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C") @@ -501,20 +604,31 @@ class Gemm(OperationBase): alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") + self._verify_rank(A) + self._verify_rank(B) + self._verify_rank(C) + self._verify_rank(D) + alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a) alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b) alignment_c = self.possible_operations.find_alignment(C.shape, self._layout_c) self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b, alignment_C=alignment_c, print_module=print_module) - problem_size = cutlass_bindings.gemm.GemmCoord(A.shape[0], B.shape[1], A.shape[1]) + problem_size, mode, batch_count = self._get_problem_args(A, B, C, D) - if batch_count == 1: - mode = cutlass_bindings.gemm.Mode.Gemm + if mode == cutlass_bindings.gemm.Mode.Gemm or batch_count == 1: kwargs = {'split_k_slices': 1} else: - mode = cutlass_bindings.gemm.Mode.Batched - kwargs = {'batch': batch_count} + kwargs = { + 'batch': batch_count, + 'batch_strides': { + 'A': self._get_batch_stride(A), + 'B': self._get_batch_stride(B), + 'C': self._get_batch_stride(C), + 'D': self._get_batch_stride(D) + } + } arguments = GemmArguments( operation=self.operation, problem_size=problem_size, diff --git a/test/python/gemm/gemm_batched.py b/test/python/gemm/gemm_batched.py new file mode 100644 index 00000000..69b425c9 --- /dev/null +++ b/test/python/gemm/gemm_batched.py @@ -0,0 +1,139 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +High-level tests for running batched GEMMs +""" + +from functools import partial +from math import prod + +import cutlass +import logging +import torch +import unittest + +from cutlass.backend.test.utils import LayoutCombination, add_test_gemm +from cutlass.backend.utils.device import device_cc + +cutlass.set_log_level(logging.WARNING) + +torch.manual_seed(2023) + + +def pytorch_reference(A, B, C, alpha, beta): + # Get the batch count. Assume that any of A, B, and C + # with a batch dimension have matching batch count. Thus, + # we break out of the loop once we have found the first + # tensor containing a batch dimension. + batch_count = (1,) + for tensor in [A, B, C]: + if len(tensor.shape) > 2: + batch_count = tensor.shape[:-2] + break + + int_batch_count = prod(batch_count) + + def add_batch(tensor): + if len(tensor.shape) == 2: + return tensor.unsqueeze(0).repeat(int_batch_count, 1, 1) + else: + return tensor.reshape(-1, tensor.size(-2), tensor.size(-1)) + + # Reshape tensors to have batch dimension + A = add_batch(A) + B = add_batch(B) + C = add_batch(C) + + ret = (torch.bmm(A, B) * alpha) + (C * beta) + reshape_vals = batch_count + C.shape[-2:] + return ret.reshape(*reshape_vals) + + +def initialize(rows, cols, batch): + tensor = torch.randint(-3, 3, size=(rows*cols*prod(batch),), device='cuda').half() + if len(batch) > 0 and prod(batch) > 1: + reshape_vals = batch + (rows, cols) + return tensor.reshape(*reshape_vals) + else: + return tensor.reshape(rows, cols) + + +class GemmF16Batched(unittest.TestCase): + def run_batched(self, batch_count: tuple, batch_A: bool, batch_B: bool, batch_C: bool): + M = 512 + N = 256 + K = 128 + alpha = 1. + beta = 2.
+ + A = initialize(M, K, batch_count if batch_A else (1,)) + B = initialize(K, N, batch_count if batch_B else (1,)) + C = initialize(M, N, batch_count if batch_C else (1,)) + D = initialize(M, N, batch_count) + + plan = cutlass.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass.DataType.f32) + plan.run(A, B, C, D, alpha, beta) + reference = pytorch_reference(A, B, C, alpha, beta) + assert reference.equal(D) + + def test_batched_ABC(self): + self.run_batched((3,), True, True, True) + self.run_batched((2, 3), True, True, True) + + def test_batched_AB(self): + self.run_batched((3,), True, True, False) + self.run_batched((2, 3), True, True, False) + + def test_batched_AC(self): + self.run_batched((3,), True, False, True) + self.run_batched((2, 3), True, False, True) + + def test_batched_BC(self): + self.run_batched((3,), False, True, True) + self.run_batched((2, 3), False, True, True) + + def test_batched_A(self): + self.run_batched((3,), True, False, False) + self.run_batched((2, 3), True, False, False) + + def test_batched_B(self): + self.run_batched((3,), False, True, False) + self.run_batched((2, 3), False, True, False) + + def test_batched_C(self): + self.run_batched((3,), False, False, True) + self.run_batched((2, 3), False, False, True) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/unit/core/numeric_conversion.cu b/test/unit/core/numeric_conversion.cu index 8d7a2968..4faea525 100644 --- a/test/unit/core/numeric_conversion.cu +++ b/test/unit/core/numeric_conversion.cu @@ -61,7 +61,7 @@ __global__ void convert( ///////////////////////////////////////////////////////////////////////////////////////////////// template -void run_test() { +void run_test(const char dest_name[], const char source_name[]) { const int kN = Count; dim3 grid(1, 1); @@ -84,7 +84,10 @@ void run_test() { destination.sync_host(); for (int i = 0; i < kN; ++i) { - EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i])); + EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i])) + << "Destination type: " << dest_name + << ", Source type: " << source_name + << ", Count: " << Count; } } @@ -97,15 +100,19 @@ void run_test() { TEST(NumericConversion, f32_to_f16_rn) { int const kN = 1; using Source = float; + const char source_name[] = "float"; using Destination = cutlass::half_t; - test::core::kernel::run_test(); + const char dest_name[] = "half_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, f32x8_to_f16x8_rn) { int const kN = 8; using Source = float; + const char source_name[] = "float"; using Destination = cutlass::half_t; - test::core::kernel::run_test(); + const char dest_name[] = "half_t"; + test::core::kernel::run_test(dest_name, source_name); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -113,15 +120,19 @@ TEST(NumericConversion, f32x8_to_f16x8_rn) { TEST(NumericConversion, f16_to_f32_rn) { int const kN = 1; using Source = cutlass::half_t; + const char source_name[] = "half_t"; using Destination = float; - test::core::kernel::run_test(); + const char dest_name[] = "float"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, f16x8_to_f32x8_rn) { int const kN = 8; using Source = cutlass::half_t; + const char source_name[] = "half_t"; using Destination = float; - test::core::kernel::run_test(); + const char dest_name[] = "float"; + test::core::kernel::run_test(dest_name, source_name); } 
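The source/destination name strings threaded through run_test above make element-wise mismatches self-describing in the test output. A hypothetical alternative, sketched below, would be to centralize those names in a small trait so that each TEST names its types only once; TypeName and its specializations are not part of CUTLASS and are shown here with standard types purely for illustration.

#include <cstdio>

// Hypothetical helper: map a type to a printable name at compile time. Only the types
// exercised by these tests would need a specialization (e.g., cutlass::half_t, float_e4m3_t).
template <typename T> struct TypeName;
template <> struct TypeName<float>  { static constexpr const char* value = "float"; };
template <> struct TypeName<double> { static constexpr const char* value = "double"; };

// Usage sketch: a run_test<Destination, Source, kN>() helper could then stream
// TypeName<Destination>::value and TypeName<Source>::value into its failure message
// instead of taking the names as function arguments.
int main() {
  std::printf("%s -> %s\n", TypeName<double>::value, TypeName<float>::value);
  return 0;
}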
///////////////////////////////////////////////////////////////////////////////////////////////// @@ -129,86 +140,109 @@ TEST(NumericConversion, f16x8_to_f32x8_rn) { TEST(NumericConversion, f32_to_fe4m3_rn) { int const kN = 1; using Source = float; + const char source_name[] = "float"; using Destination = cutlass::float_e4m3_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e4m3_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, f32_to_fe4m3_rn_array) { int const kN = 27; using Source = float; + const char source_name[] = "float"; using Destination = cutlass::float_e4m3_t; - - test::core::kernel::run_test(); + const char dest_name[] = "float_e4m3_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, f32_to_fe5m2_rn) { int const kN = 1; using Source = float; + const char source_name[] = "float"; using Destination = cutlass::float_e5m2_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e5m2_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, f32_to_fe5m2_rn_array) { int const kN = 27; using Source = float; + const char source_name[] = "float"; using Destination = cutlass::float_e5m2_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e5m2_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, f16_to_fe4m3_rn) { int const kN = 1; using Source = cutlass::half_t; + const char source_name[] = "half_t"; using Destination = cutlass::float_e4m3_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e4m3_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, f16_to_fe4m3_rn_array) { int const kN = 27; using Source = cutlass::half_t; + const char source_name[] = "half_t"; using Destination = cutlass::float_e4m3_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e4m3_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, f16_to_fe5m2_rn) { int const kN = 1; using Source = cutlass::half_t; + const char source_name[] = "half_t"; using Destination = cutlass::float_e5m2_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e5m2_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, f16_to_fe5m2_rn_array) { int const kN = 27; using Source = cutlass::half_t; + const char source_name[] = "half_t"; using Destination = cutlass::float_e5m2_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e5m2_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, bf16_to_fe4m3_rn) { int const kN = 1; using Source = cutlass::bfloat16_t; + const char source_name[] = "bfloat16_t"; using Destination = cutlass::float_e4m3_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e4m3_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, bf16_to_fe4m3_rn_array) { int const kN = 27; using Source = cutlass::bfloat16_t; + const char source_name[] = "bfloat16_t"; using Destination = cutlass::float_e4m3_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e4m3_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, bf16_to_fe5m2_rn) { int const kN = 1; using Source = cutlass::bfloat16_t; + const char source_name[] = "bfloat16_t"; using Destination = cutlass::float_e5m2_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e5m2_t"; + 
test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, bf16_to_fe5m2_rn_array) { int const kN = 27; using Source = cutlass::bfloat16_t; + const char source_name[] = "bfloat16_t"; using Destination = cutlass::float_e5m2_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e5m2_t"; + test::core::kernel::run_test(dest_name, source_name); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -216,36 +250,46 @@ TEST(NumericConversion, bf16_to_fe5m2_rn_array) { TEST(NumericConversion, fe4m3_to_fe5m2_rn) { int const kN = 1; using Source = cutlass::float_e4m3_t; + const char source_name[] = "float_e4m3_t"; using Destination = cutlass::float_e5m2_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e5m2_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe4m3_to_fe5m2_array) { int const kN = 27; using Source = cutlass::float_e4m3_t; + const char source_name[] = "float_e4m3_t"; using Destination = cutlass::float_e5m2_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e5m2_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe5m2_to_fe4m3_rn) { int const kN = 1; using Source = cutlass::float_e5m2_t; + const char source_name[] = "float_e5m2_t"; using Destination = cutlass::float_e4m3_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e4m3_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe5m2_to_fe4m3_array) { int const kN = 27; using Source = cutlass::float_e5m2_t; + const char source_name[] = "float_e5m2_t"; using Destination = cutlass::float_e4m3_t; - test::core::kernel::run_test(); + const char dest_name[] = "float_e4m3_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe4m3_to_f32_rn) { int const kN = 1; using Source = cutlass::float_e4m3_t; + const char source_name[] = "float_e4m3_t"; using Destination = float; - test::core::kernel::run_test(); + const char dest_name[] = "float"; + test::core::kernel::run_test(dest_name, source_name); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -254,78 +298,100 @@ TEST(NumericConversion, f32x8_to_s8x8_rn) { int const kN = 8; using Source = float; + const char source_name[] = "float"; using Destination = int8_t; - test::core::kernel::run_test(); + const char dest_name[] = "int8_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe4m3_to_f32_array) { int const kN = 27; using Source = cutlass::float_e4m3_t; + const char source_name[] = "float_e4m3_t"; using Destination = float; - test::core::kernel::run_test(); + const char dest_name[] = "float"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe5m2_to_f32_array) { int const kN = 27; using Source = cutlass::float_e5m2_t; + const char source_name[] = "float_e5m2_t"; using Destination = float; - test::core::kernel::run_test(); + const char dest_name[] = "float"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe4m3_to_f16_rn) { int const kN = 1; using Source = cutlass::float_e4m3_t; + const char source_name[] = "float_e4m3_t"; using Destination = cutlass::half_t; - test::core::kernel::run_test(); + const char dest_name[] = "half_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe4m3_to_f16_array) { int const kN = 27; using Source = 
cutlass::float_e4m3_t; + const char source_name[] = "float_e4m3_t"; using Destination = cutlass::half_t; - test::core::kernel::run_test(); + const char dest_name[] = "half_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe5m2_to_f16_rn) { int const kN = 1; using Source = cutlass::float_e5m2_t; + const char source_name[] = "float_e5m2_t"; using Destination = cutlass::half_t; - test::core::kernel::run_test(); + const char dest_name[] = "half_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe5m2_to_f16_array) { int const kN = 27; using Source = cutlass::float_e5m2_t; + const char source_name[] = "float_e5m2_t"; using Destination = cutlass::half_t; - test::core::kernel::run_test(); + const char dest_name[] = "half_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe4m3_to_bf16_rn) { int const kN = 1; using Source = cutlass::float_e4m3_t; + const char source_name[] = "float_e4m3_t"; using Destination = cutlass::bfloat16_t; - test::core::kernel::run_test(); + const char dest_name[] = "bfloat16_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe4m3_to_bf16_array) { int const kN = 27; using Source = cutlass::float_e4m3_t; + const char source_name[] = "float_e4m3_t"; using Destination = cutlass::bfloat16_t; - test::core::kernel::run_test(); + const char dest_name[] = "bfloat16_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe5m2_to_bf16_rn) { int const kN = 1; using Source = cutlass::float_e5m2_t; + const char source_name[] = "float_e5m2_t"; using Destination = cutlass::bfloat16_t; - test::core::kernel::run_test(); + const char dest_name[] = "bfloat16_t"; + test::core::kernel::run_test(dest_name, source_name); } TEST(NumericConversion, fe5m2_to_bf16_array) { int const kN = 27; using Source = cutlass::float_e5m2_t; + const char source_name[] = "float_e5m2_t"; using Destination = cutlass::bfloat16_t; - test::core::kernel::run_test(); + const char dest_name[] = "bfloat16_t"; + test::core::kernel::run_test(dest_name, source_name); } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/cute/core/CMakeLists.txt b/test/unit/cute/core/CMakeLists.txt index 30c05d24..a8eb4fbe 100644 --- a/test/unit/cute/core/CMakeLists.txt +++ b/test/unit/cute/core/CMakeLists.txt @@ -36,6 +36,7 @@ cutlass_test_unit_add_executable( compare.cpp complement.cpp composition.cpp + constant_arithmetic.cpp core_unit.cpp inverse_left.cpp inverse_right.cpp diff --git a/test/unit/cute/core/constant_arithmetic.cpp b/test/unit/cute/core/constant_arithmetic.cpp new file mode 100644 index 00000000..c11da9e3 --- /dev/null +++ b/test/unit/cute/core/constant_arithmetic.cpp @@ -0,0 +1,106 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" +#include +#include + +TEST(CuTe_core, ConstantArithmetic) { + using namespace cute; + + constexpr cute::integral_constant uzero{}; + + // This extra test exists historically as part of the diagnosis + // of a possible Clang 14 bug. However, it's a nice test for + // cute::integral_constant's arithmetic operators, so it's saved here. + // It also demonstrates how to work with cute::integral_constant + // and lambda captures. Microsoft Visual Studio ("MSVC") tends to + // disagree with other compilers about the meaning of decltype + // for variables captured by reference. MSVC and GCC 8.3.0 + // also tend to disagree with other compilers (and other GCC versions) + // about whether expressions involving such variables + // are constant expressions. + // + // A typical CuTe idiom is to do lambda captures by reference [&]. + // This test changes them to capture by value, except for + // the innermost lambda's capture of S1, which is by reference. + // The point is to show that MSVC and GCC 8 have issues with this + // that other compilers do not. For example, + // + // 1. MSVC needs remove_cvref_t around decltype(S1) + // in order to access decltype(S1)::value, and + // 2. MSVC and GCC 8.3.0 both report a build error with S1() + // (that is, calling operator() on S1, which returns the + // same thing as S1.value). + // + // The reason for (2) is that neither compiler thinks + // that S1() is a constant expression. + // + // This leaves S1.value as the most concise portable expression + // for the "value" member of a cute::integral_constant. + for_each(make_integer_sequence{}, [uzero](auto S0) { + for_each(make_integer_sequence{}, [uzero,S0](auto F0) { + for_each(make_integer_sequence{}, [uzero,S0,F0](auto S1) { + for_each(make_integer_sequence{}, [uzero,S0,F0,&S1](auto F1) { + static_assert((decltype(S0)::value & decltype(F0)::value) == decltype(S0 & F0)::value); + + // Using S1.value means you don't have to use remove_cvref_t + // with a captured-by-reference variable. 
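        // For quick reference, the spellings discussed above, shown here only as
        // an illustrative summary (S1 is the by-reference capture):
        //
        //   decltype(S1)::value                          // MSVC wants remove_cvref_t around decltype(S1)
        //   cute::remove_cvref_t<decltype(S1)>::value    // portable, but verbose
        //   S1.value                                     // portable and concise
        //   S1()                                         // rejected by MSVC 2022 and GCC 8.3.0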
+ static_assert((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) == decltype(S1 & F1)::value); + static_assert((S1.value & decltype(F1)::value) == decltype(S1 & F1)::value); + // S1() _should_ work, but does not with Visual Studio 2022, + // which emits C2131 ("expression did not evaluate to a constant"). + // It also does not with GCC 8.3.0, which emits an error with messages + // "non-constant condition for static assertion" and + // "'this' is not a constant expression." + // + //static_assert((S1() & decltype(F1)::value) == decltype(S1 & F1)::value); + static_assert(decltype((S0 & F0) != uzero)::value == ((decltype(S0)::value & decltype(F0)::value) != 0)); + + static_assert(decltype((S1 & F1) != uzero)::value == ((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) != 0)); + static_assert(decltype((S1 & F1) != uzero)::value == ((S1.value & decltype(F1)::value) != 0)); + + constexpr bool left = decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value; + constexpr bool right = + ((decltype(S0)::value & decltype(F0)::value) != 0) || + ((cute::remove_cvref_t<decltype(S1)>::value & decltype(F1)::value) != 0); + constexpr bool right2 = + ((decltype(S0)::value & decltype(F0)::value) != 0) || + ((S1.value & decltype(F1)::value) != 0); + static_assert(right == right2); + static_assert(left == right); + constexpr bool left2 = decltype((S0 & F0) != uzero)::value || decltype((S1 & F1) != uzero)::value; + static_assert(left == left2); + }); + }); + }); + }); +} diff --git a/test/unit/cute/core/mixedbits.cpp b/test/unit/cute/core/mixedbits.cpp index 4fd79650..f439566e 100644 --- a/test/unit/cute/core/mixedbits.cpp +++ b/test/unit/cute/core/mixedbits.cpp @@ -31,9 +31,58 @@ #include "cutlass_unit_test.h" +// C::value_type is not uint32_t for GCC 7.5.0. +// This test is thus disabled for GCC < 8. +#if !defined(__GNUC__) || (__GNUC__ >= 8) + #include #include +namespace { // (anonymous) + + // This function exists to work around a Clang 14 issue, in which + // the compiler tries to instantiate code that lives inside the + // "else" branch of an "if constexpr," even when the "else" branch + // is false. That triggers a spurious static_assert in MixedBits. + // The work-around is to make the body of the "else" branch a + // function, rather than leaving it in line. + // + // Some compilers strangely deduce the first two terms of + // make_integer_sequence as C and C, and + // the remaining terms as C<2>, C<3>, etc. Making this function take + // cute::integral_constant, etc. doesn't work + // with those compilers.
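  // In isolation, the hoisting pattern described above looks roughly like the
  // following sketch (names are illustrative and not taken from this file):
  //
  //   template <class T>
  //   void else_branch_body(T t) {
  //     // Code that is only well-formed when the "else" branch is actually
  //     // taken lives here, out of line, so an eager compiler does not
  //     // inspect it while discarding the branch.
  //   }
  //
  //   template <class T>
  //   void dispatch(T t) {
  //     if constexpr (T::use_fast_path) {
  //       // fast path
  //     }
  //     else {
  //       else_branch_body(t);  // call a helper instead of inlining the body
  //     }
  //   }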
+ template + void clang14_workaround(cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant) + { + constexpr cute::C(S0_value)> S0{}; + constexpr cute::C(F0_value)> F0{}; + constexpr cute::C(S1_value)> S1{}; + constexpr cute::C(F1_value)> F1{}; + + for (uint32_t d0 = 0; d0 < 8; ++d0) { + if ((d0 & F0) != d0) { continue; } // Skip repeats + for (uint32_t d1 = 0; d1 < 8; ++d1) { + if ((d1 & F1) != d1) { continue; } // Skip repeats + auto m0 = make_mixed_bits(S0, d0, F0); + auto m1 = make_mixed_bits(S1, d1, F1); + //print(m0); print(" & "); print(m1); print(" = "); print(m0 & m1); print("\n"); + EXPECT_EQ(uint32_t(m0 & m1), uint32_t(m0) & uint32_t(m1)); + //print(m0); print(" | "); print(m1); print(" = "); print(m0 | m1); print("\n"); + EXPECT_EQ(uint32_t(m0 | m1), uint32_t(m0) | uint32_t(m1)); + //print(m0); print(" ^ "); print(m1); print(" = "); print(m0 ^ m1); print("\n"); + EXPECT_EQ(uint32_t(m0 ^ m1), uint32_t(m0) ^ uint32_t(m1)); + } + } + } +} // namespace (anonymous) + TEST(CuTe_core, MixedBits) { using namespace cute; @@ -48,23 +97,21 @@ TEST(CuTe_core, MixedBits) { } else if constexpr (decltype((S0 & F0) != uzero || (S1 & F1) != uzero)::value) { return; } else { - for (uint32_t d0 = 0; d0 < 8; ++d0) { - if ((d0 & F0) != d0) { continue; } // Skip repeats - for (uint32_t d1 = 0; d1 < 8; ++d1) { - if ((d1 & F1) != d1) { continue; } // Skip repeats - auto m0 = make_mixed_bits(S0, d0, F0); - auto m1 = make_mixed_bits(S1, d1, F1); - //print(m0); print(" & "); print(m1); print(" = "); print(m0 & m1); print("\n"); - EXPECT_EQ(uint32_t(m0 & m1), uint32_t(m0) & uint32_t(m1)); - //print(m0); print(" | "); print(m1); print(" = "); print(m0 | m1); print("\n"); - EXPECT_EQ(uint32_t(m0 | m1), uint32_t(m0) | uint32_t(m1)); - //print(m0); print(" ^ "); print(m1); print(" = "); print(m0 ^ m1); print("\n"); - EXPECT_EQ(uint32_t(m0 ^ m1), uint32_t(m0) ^ uint32_t(m1)); - } - } + clang14_workaround(S0, F0, S1, F1); } }); }); }); }); } + +TEST(CuTe_core, MakeIntegerSequence) { + cute::for_each(cute::make_integer_sequence{}, [](auto c) { + using c_type = decltype(c); + constexpr auto c_value = c_type::value; + using expected_type = cute::integral_constant; + static_assert(cute::is_same_v); + }); +} + +#endif // defined(__GNUC__) && (__GNUC__ < 8) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cooperative_stream_k.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cooperative_stream_k.cu index f307a682..fc4e3f31 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cooperative_stream_k.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cooperative_stream_k.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cooperative_stream_k.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cooperative_stream_k.cu index f879927d..45b9d023 100644 --- a/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cooperative_stream_k.cu +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cooperative_stream_k.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu b/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu index 8e6a40d5..989d60aa 100644 --- a/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu +++ b/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu @@ -32,6 +32,7 @@ \brief Tests that the stream-K scheduler covers the entire problem space. */ +#include "cutlass/cluster_launch.hpp" #include "cutlass/kernel_hardware_info.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" #include "cutlass/util/device_memory.h" @@ -39,6 +40,10 @@ #include "../../common/cutlass_unit_test.h" +// Grids are launched with clusters enabled in these tests, +// so the CTK version must support cluster launching. +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + using namespace cute; using ProblemShape_MNKL = Shape; @@ -60,7 +65,7 @@ run_scheduler(int* visit_counters, typename Scheduler::Params params, TileShape while (work_tile_info.is_valid_tile) { // Increment counters to indicate coverage auto tile_idx = Scheduler::output_tile_index(params, work_tile_info); - auto offset = tile_idx * params.k_iter_per_tile_ + work_tile_info.K_idx; + auto offset = tile_idx * params.k_tiles_per_output_tile_ + work_tile_info.K_idx; for (auto i = 0; i < work_tile_info.k_tile_count; ++i) { // Use atomicAdd because the visit counters are shared by multiple thread blocks. // While having more than one block increment the same counter indicates failure, @@ -103,7 +108,7 @@ test_scheduler( // Allocate counters indicating the number of times each k iteration of each output tile has been visited auto [blk_m, blk_n, blk_l] = Scheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape); - auto total_counters = blk_m * blk_n * blk_l * params.k_iter_per_tile_; + auto total_counters = blk_m * blk_n * blk_l * params.k_tiles_per_output_tile_; cutlass::DeviceAllocation visit_counters(total_counters); // Initialize counters to zero @@ -118,12 +123,55 @@ test_scheduler( // Set up the grid for the problem dim3 grid = Scheduler::get_grid_shape(problem_shape_mnkl, tile_shape, cluster_shape, hw_info, args); + // Set up cluster and cluster launch. This is needed even for this simple kernel because + // the SM90 scheduler needs to be able to query the CTA id within a cluster, which requires + // explicitly launching with clusters. 
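  // For context: inside the kernel, a CTA's rank within its cluster can be read
  // from the PTX special register %cluster_ctarank, roughly as sketched below
  // (the helper the scheduler actually uses may differ):
  //
  //   __device__ unsigned block_rank_in_cluster() {
  //     unsigned rank;
  //     asm volatile("mov.u32 %0, %%cluster_ctarank;" : "=r"(rank));
  //     return rank;
  //   }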
+ dim3 cluster{ + static_cast(cute::get<0>(ClusterShape{})), + static_cast(cute::get<1>(ClusterShape{})), + static_cast(cute::get<2>(ClusterShape{})) + }; + + cudaLaunchConfig_t launch_config; + launch_config.gridDim = grid; + launch_config.blockDim = {1, 1, 1}; + launch_config.dynamicSmemBytes = 0; + launch_config.stream = NULL; + + cudaLaunchAttribute launch_attribute[1]; + launch_attribute[0].id = cudaLaunchAttributeClusterDimension; + launch_attribute[0].val.clusterDim.x = cluster.x; + launch_attribute[0].val.clusterDim.y = cluster.y; + launch_attribute[0].val.clusterDim.z = cluster.z; + + launch_config.attrs = launch_attribute; + launch_config.numAttrs = 1; + + void const* kernel = (void const*) run_scheduler; + int* counters_ptr = visit_counters.get(); + void* kernel_params[] = { + &counters_ptr, + ¶ms, + &tile_shape, + &cluster_shape, + &problem_shape_mnkl + }; + // Run the scheduler to completion and log visits to each k iteration - run_scheduler<<>>( - visit_counters.get(), params, tile_shape, cluster_shape, problem_shape_mnkl); + err = cudaLaunchKernelExC(&launch_config, kernel, kernel_params); + + if (err != cudaSuccess) { + std::cerr << __FILE__ << ":" << __LINE__ + << " cudaLaunchKernelExC failed with error: " + << cudaGetErrorString(err) << std::endl; + return false; + } + err = cudaDeviceSynchronize(); if (err != cudaSuccess) { - std::cerr << __FILE__ << ":" << __LINE__ << " scheduler kernel failed with error: " << cudaGetErrorString(err) << std::endl; + std::cerr << __FILE__ << ":" << __LINE__ + << " scheduler kernel failed with error: " + << cudaGetErrorString(err) << std::endl; return false; } @@ -143,11 +191,11 @@ test_scheduler( << " and grid size " << grid.x << "x" << grid.y << "x" << grid.z << " splits=" << params.splits_ - << " k_iter=" << params.k_iter_per_tile_ + << " k_iter=" << params.k_tiles_per_output_tile_ << " big_units=" << params.big_units_ << " sk_tiles=" << params.sk_tiles_ << " sk_units=" << params.sk_units_ - << " k_iter_per_sk_unit=" << params.k_iter_per_sk_unit_ << std::endl; + << " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_ << std::endl; std::cout << "Error at idx: " << i << ". Got count " << host_visit_counts[i] << std::endl; return false; } @@ -274,4 +322,6 @@ TEST(SM90_Device_Gemm_stream_k_scheduler, 128x128x64_2x1x1) { EXPECT_TRUE(test_scheduler({128, 512, 2048, 1}, tile_shape, cluster_shape, 114)); } +#endif // defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/gemm_operation.py b/tools/library/scripts/gemm_operation.py index a0370ede..d248643d 100644 --- a/tools/library/scripts/gemm_operation.py +++ b/tools/library/scripts/gemm_operation.py @@ -179,7 +179,7 @@ class GemmOperation: ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] if self.arch >= 90: - kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}{t}" + kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{t}{k}{e}" return kernel_name_template.format( p = self.prefix, ar = self.arch, @@ -194,9 +194,9 @@ class GemmOperation: l = self.tile_description.stages, s = self.layout_name_3x(), al = str(max(self.A.alignment, self.B.alignment)), + t = TileSchedulerSuffixes[self.tile_scheduler], k = self.kernel_schedule_name_3x(), - e = self.epilogue_schedule_name_3x(), - t = TileSchedulerSuffixes[self.tile_scheduler]) + e = self.epilogue_schedule_name_3x()) else: threadblock = self.tile_description.procedural_name() return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format( @@ -661,8 +661,7 @@ using ${operation_name}_mainloop = ${element_accumulator}, cute::Shape, cute::Shape, - cutlass::gemm::collective::StageCountAutoCarveout< - sizeof(typename ${operation_name}_epilogue::SharedStorage)>, + ${stages}, ${kernel_schedule} >::CollectiveOp; @@ -697,7 +696,7 @@ ${compile_guard_end} if operation.tile_description.stages > 0: stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>" else: - stage_count_string = "cutlass::gemm::collective::StageCountAuto" + stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout" warp_shape = [tile_shape[idx] // warp_count[idx] for idx in range(3)] instance_layout_A, instance_layout_B, instance_layout_C , instance_layout_D = \ diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 630364c1..be169bf0 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -4218,6 +4218,8 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): layout[2][1] = 8 CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, schedules) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK]) + # persistent kernels with TMA epilogues if CudaToolkitVersionSatisfies(cuda_version, 12, 1): CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed,