From e1483d5fa0c9eeee4589d11978d5908314b4bb56 Mon Sep 17 00:00:00 2001
From: Christian Sigg
Date: Fri, 8 Dec 2023 20:42:12 +0100
Subject: [PATCH] Collection of changes to fix clang build. (#1200)

* Remove unused variables

* Qualify calls to make_fragment_? from templated base class. Fixes clang build error.

* Add missing `#include `

* Various changes to fix clang compile errors.

* More changes to fix clang build. Remaining issues:
  - `params` initializer of `CollectiveEpilogue`.
  - `ops` initializer of `Sm90VisitorImplBase`.
  - `__usAtomicCAS` needs to be added to clang upstream.

* Fix remaining clang build issues.

* Qualify `cute::rank()` calls.

* Qualify some more calls that are otherwise ambiguous between `cute` and `std` namespace.

* Double-escape special registers in inline asm.

* small change

---------

Co-authored-by: Haicheng Wu
---
 examples/51_hopper_gett/51_hopper_gett.cu     | 18 +++---
 .../gather_gemm.hpp                           | 12 ++---
 .../scatter_epilogue.hpp                      | 10 ++--
 .../53_hopper_gemm_permute.cu                 |  6 +--
 .../53_hopper_gemm_permute/permute_traits.hpp |  4 +-
 include/cute/arch/cluster_sm90.hpp            | 26 +++++-----
 include/cute/config.hpp                       |  7 +--
 include/cutlass/cluster_launch.hpp            |  1 +
 .../collective/builders/sm90_builder.inl      |  6 +--
 .../epilogue/collective/default_epilogue.hpp  | 10 ++--
 .../cutlass/epilogue/collective/detail.hpp    |  4 +-
 .../collective/epilogue_tensor_broadcast.hpp  | 10 ++--
 .../collective/sm70_epilogue_vectorized.hpp   | 10 ++--
 .../sm90_epilogue_tma_warpspecialized.hpp     | 26 +++++-----
 .../sm90_callbacks_tma_warpspecialized.hpp    |  8 +--
 ...90_visitor_compute_tma_warpspecialized.hpp | 50 +++++++++++++------
 .../sm90_visitor_load_tma_warpspecialized.hpp |  6 ++-
 ...sm90_visitor_store_tma_warpspecialized.hpp | 13 +++--
 .../sm90_visitor_tma_warpspecialized.hpp      | 50 +++++++++++++------
 .../threadblock/fusion/visitor_2x.hpp         |  2 +-
 .../gemm/collective/sm70_mma_twostage.hpp     | 16 +++---
 .../gemm/collective/sm80_mma_multistage.hpp   | 16 +++---
 ...mma_multistage_gmma_rs_warpspecialized.hpp | 12 ++---
 ...mma_multistage_gmma_ss_warpspecialized.hpp |  8 +--
 .../sm90_mma_tma_gmma_rs_warpspecialized.hpp  | 14 +++---
 ...ma_gmma_rs_warpspecialized_mixed_input.hpp | 14 +++---
 .../gemm/collective/sm90_mma_tma_gmma_ss.hpp  | 14 +++---
 .../sm90_mma_tma_gmma_ss_warpspecialized.hpp  | 10 ++--
 ...90_mma_tma_gmma_ss_warpspecialized_fp8.hpp | 10 ++--
 .../gemm/kernel/gemm_universal_streamk.h      |  4 +-
 include/cutlass/gemm/kernel/sm70_gemm.hpp     | 14 +++---
 include/cutlass/gemm/kernel/sm90_gemm_tma.hpp | 16 +++---
 .../kernel/sm90_gemm_tma_warpspecialized.hpp  | 14 +++---
 ...0_gemm_tma_warpspecialized_cooperative.hpp | 14 +++---
 ...sm90_gemm_tma_warpspecialized_pingpong.hpp | 14 +++---
 .../gemm/kernel/sm90_gemm_warpspecialized.hpp | 12 ++---
 .../sm90_gemm_warpspecialized_cooperative.hpp | 12 ++---
 .../sm90_gemm_warpspecialized_pingpong.hpp    | 12 ++---
 include/cutlass/gemm_coord.hpp                |  1 +
 include/cutlass/numeric_conversion.h          |  2 +-
 test/unit/gemm/device/gemm_testbed_3x.hpp     | 11 ++--
 test/unit/gemm/device/gemm_testbed_3x_evt.hpp | 32 ++++++------
 .../gemm_testbed_3x_tensor_broadcast.hpp      |  4 +-
 test/unit/gemm/device/sm90_evt_operations.hpp | 16 +++---
 .../cutlass/util/reference/device/gett.hpp    |  2 +-
 .../cutlass/util/reference/host/gett.hpp      |  8 +--
 46 files changed, 308 insertions(+), 273 deletions(-)

diff --git a/examples/51_hopper_gett/51_hopper_gett.cu b/examples/51_hopper_gett/51_hopper_gett.cu
index e9969505..0a2622e4 100644
--- a/examples/51_hopper_gett/51_hopper_gett.cu
+++ b/examples/51_hopper_gett/51_hopper_gett.cu
@@ -186,15
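Two clang-specific patterns recur throughout the diff below; the minimal standalone sketches here are illustrative only (the `Base`, `Derived`, `get_callbacks`, and `cluster_count_x` names are placeholders, not CUTLASS identifiers), showing the shape of the fixes rather than the patched code itself.

Clang applies two-phase name lookup strictly, so unqualified calls that nvcc resolved through a dependent base class must be qualified, and a member template of a dependent base needs the `template` disambiguator:

    template <class T>
    struct Base {
      template <class U>
      int get_callbacks(U u) { return static_cast<int>(u); }
    };

    template <class T>
    struct Derived : Base<T> {
      int f() {
        // nvcc accepted the unqualified `get_callbacks<float>(1.0f)`; clang
        // requires naming the dependent base and the `template` keyword.
        return Base<T>::template get_callbacks<float>(1.0f);
      }
    };

The same reasoning motivates the `cute::rank(...)` qualifications: per the commit message, with both the `cute` and `std` namespaces in scope the unqualified calls are ambiguous under clang.

Clang's integrated assembler also treats a single `%` in an extended inline-asm template as the start of an operand reference, so PTX special registers are written with a doubled `%%`, which is emitted as a literal `%`:

    __device__ unsigned int cluster_count_x() {
      unsigned int x;
      // "%%nclusterid.x" reaches ptxas as "%nclusterid.x".
      asm volatile("mov.u32 %0, %%nclusterid.x;\n" : "=r"(x));
      return x;
    }
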
+186,15 @@ main(int argc, char const* argv[]) { using ElementEpilogue = float; // The following constexpr values set the max number of modes in each MNKL mode - constexpr int MaxRank_M = rank(RowModeStridesA{}); // Max row modes - constexpr int MaxRank_N = rank(ColModeStridesB{}); // Max column modes - constexpr int MaxRank_K = rank(RedModeStridesA{}); // Max contraction modes - constexpr int MaxRank_L = rank(BatModeStridesA{}); // Max batch modes - static_assert(rank(RowModeStridesA{}) == rank(RowModeStridesC{})); - static_assert(rank(ColModeStridesB{}) == rank(RowModeStridesC{})); - static_assert(rank(RedModeStridesA{}) == rank(RedModeStridesB{})); - static_assert(rank(BatModeStridesA{}) == rank(BatModeStridesC{})); - static_assert(rank(BatModeStridesB{}) == rank(BatModeStridesC{})); + constexpr int MaxRank_M = cute::rank(RowModeStridesA{}); // Max row modes + constexpr int MaxRank_N = cute::rank(ColModeStridesB{}); // Max column modes + constexpr int MaxRank_K = cute::rank(RedModeStridesA{}); // Max contraction modes + constexpr int MaxRank_L = cute::rank(BatModeStridesA{}); // Max batch modes + static_assert(cute::rank(RowModeStridesA{}) == cute::rank(RowModeStridesC{})); + static_assert(cute::rank(ColModeStridesB{}) == cute::rank(RowModeStridesC{})); + static_assert(cute::rank(RedModeStridesA{}) == cute::rank(RedModeStridesB{})); + static_assert(cute::rank(BatModeStridesA{}) == cute::rank(BatModeStridesC{})); + static_assert(cute::rank(BatModeStridesB{}) == cute::rank(BatModeStridesC{})); // Parse command line to get modes, extents, and strides cutlass::GettCommandLine cmd; diff --git a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp index 07de1639..067b1dce 100644 --- a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp +++ b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp @@ -58,7 +58,7 @@ public: // Type Aliases // using ProblemShape = ProblemShape_; - static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); // Mainloop derived types @@ -180,7 +180,7 @@ public: bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); return implementable; @@ -288,10 +288,10 @@ public: PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); // Preconditions - static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. 
If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) diff --git a/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp b/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp index f08c107c..4dc555e4 100644 --- a/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp +++ b/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp @@ -86,8 +86,8 @@ public: static const int kOutputAlignment = ThreadEpilogueOp::kCount; using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; - static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); struct SharedStorage { }; @@ -151,10 +151,10 @@ public: using namespace cute; using X = Underscore; - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); static_assert(is_static::value, "ThreadBlock tile shape must be static"); - static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); (void) smem_buf; ThreadEpilogueOp epilogue_op{params.thread_params}; diff --git a/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu index 8800615f..aa93304b 100644 --- a/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu +++ b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu @@ -197,14 +197,14 @@ template auto select_mode_shape(Shapes const & ... 
shapes) { auto permuted_shapes = filter_tuple(cute::make_tuple(shapes...), [](auto shape) { - if constexpr (rank(shape) > 1) { + if constexpr (cute::rank(shape) > 1) { return cute::make_tuple(shape); } else { return cute::make_tuple(); } }); - if constexpr (rank(permuted_shapes) == 0) { + if constexpr (cute::rank(permuted_shapes) == 0) { return get<0>(cute::make_tuple(shapes...)); } else { @@ -251,7 +251,7 @@ auto select_tile_shape(TileSize size, Shape const& shape) { static_assert(is_static::value, "Tile size must be static"); - if constexpr (rank(Shape{}) == 0) { + if constexpr (cute::rank(Shape{}) == 0) { return cute::make_tuple(size); } else { diff --git a/examples/53_hopper_gemm_permute/permute_traits.hpp b/examples/53_hopper_gemm_permute/permute_traits.hpp index 55c76418..4ec6094e 100644 --- a/examples/53_hopper_gemm_permute/permute_traits.hpp +++ b/examples/53_hopper_gemm_permute/permute_traits.hpp @@ -78,7 +78,7 @@ reshape(Shape const& shape, TargetShape const& target_shape) template constexpr auto make_permute_layout(Layout const& layout) { - static_assert(rank(Shape{}) == 3, "Only rank-3 layouts are supported"); + static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported"); if constexpr (Transpose) { // Deal with tensor B by transposing appropriately before and after computing the permute layout. // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. @@ -135,7 +135,7 @@ using inverse_t = decltype(inverse(T{})); template constexpr auto make_original_layout(Layout const& layout) { - static_assert(rank(Shape{}) == 3, "Only rank-3 layouts are supported"); + static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported"); if constexpr (Transpose) { // Deal with tensor B by transposing appropriately before and after computing the permute layout. // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. diff --git a/include/cute/arch/cluster_sm90.hpp b/include/cute/arch/cluster_sm90.hpp index 7e909712..b034b2cd 100644 --- a/include/cute/arch/cluster_sm90.hpp +++ b/include/cute/arch/cluster_sm90.hpp @@ -86,9 +86,9 @@ CUTE_DEVICE dim3 cluster_grid_dims() { #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) uint32_t x, y, z; - asm volatile("mov.u32 %0, %nclusterid.x;\n" : "=r"(x) : ); - asm volatile("mov.u32 %0, %nclusterid.y;\n" : "=r"(y) : ); - asm volatile("mov.u32 %0, %nclusterid.z;\n" : "=r"(z) : ); + asm volatile("mov.u32 %0, %%nclusterid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%nclusterid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%nclusterid.z;\n" : "=r"(z) : ); return {x, y, z}; #elif defined(__CUDA_ARCH__) // MSVC requires protecting use of gridDim with __CUDA_ARCH__. @@ -105,9 +105,9 @@ CUTE_DEVICE dim3 cluster_id_in_grid() { #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) uint32_t x, y, z; - asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(x) : ); - asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(y) : ); - asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(z) : ); + asm volatile("mov.u32 %0, %%clusterid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%clusterid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%clusterid.z;\n" : "=r"(z) : ); return {x, y, z}; #elif defined(__CUDA_ARCH__) // MSVC requires protecting use of blockIdx with __CUDA_ARCH__. 
@@ -124,9 +124,9 @@ CUTE_DEVICE dim3 block_id_in_cluster() { #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) uint32_t x, y, z; - asm volatile("mov.u32 %0, %cluster_ctaid.x;\n" : "=r"(x) : ); - asm volatile("mov.u32 %0, %cluster_ctaid.y;\n" : "=r"(y) : ); - asm volatile("mov.u32 %0, %cluster_ctaid.z;\n" : "=r"(z) : ); + asm volatile("mov.u32 %0, %%cluster_ctaid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%cluster_ctaid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%cluster_ctaid.z;\n" : "=r"(z) : ); return {x, y, z}; #else return {0,0,0}; @@ -138,9 +138,9 @@ CUTE_DEVICE dim3 cluster_shape() { #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) uint32_t x, y, z; - asm volatile("mov.u32 %0, %cluster_nctaid.x;\n" : "=r"(x) : ); - asm volatile("mov.u32 %0, %cluster_nctaid.y;\n" : "=r"(y) : ); - asm volatile("mov.u32 %0, %cluster_nctaid.z;\n" : "=r"(z) : ); + asm volatile("mov.u32 %0, %%cluster_nctaid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%cluster_nctaid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%cluster_nctaid.z;\n" : "=r"(z) : ); return {x, y, z}; #else return {1,1,1}; @@ -152,7 +152,7 @@ CUTLASS_DEVICE uint32_t block_rank_in_cluster() { #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) uint32_t rank; - asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(rank) :); + asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank) :); return rank; #else return 0; diff --git a/include/cute/config.hpp b/include/cute/config.hpp index ba2504cd..e4bda683 100644 --- a/include/cute/config.hpp +++ b/include/cute/config.hpp @@ -30,7 +30,7 @@ **************************************************************************************************/ #pragma once -#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) +#if defined(__CUDACC__) || defined(_NVHPC_CUDA) # define CUTE_HOST_DEVICE __forceinline__ __host__ __device__ # define CUTE_DEVICE __forceinline__ __device__ # define CUTE_HOST __forceinline__ __host__ @@ -46,10 +46,11 @@ # define CUTE_HOST_RTC CUTE_HOST #endif -#if !defined(__CUDACC_RTC__) && (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)) +#if !defined(__CUDACC_RTC__) && !defined(__clang__) && \ + (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)) # define CUTE_UNROLL #pragma unroll # define CUTE_NO_UNROLL #pragma unroll 1 -#elif defined(__CUDACC_RTC__) +#elif defined(__CUDACC_RTC__) || defined(__clang__) # define CUTE_UNROLL _Pragma("unroll") # define CUTE_NO_UNROLL _Pragma("unroll 1") #else diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index 28611d51..21923641 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -35,6 +35,7 @@ #pragma once +#include #include #include "cutlass/cutlass.h" #include "cutlass/trace.h" diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index 8580547b..54940f67 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -595,7 +595,7 @@ CollectiveBuilder< cute::is_base_of_v >> { private: using FusionOp = - fusion::LinCombEltAct; + fusion::LinCombEltAct; using ImplSchedule = cute::conditional_t, TmaWarpSpecialized, TmaWarpSpecializedCooperative>; @@ -676,7 +676,7 @@ private: using SmemCopyOpAux = decltype(detail::sm90_get_smem_store_op_for_accumulator< GmemStrideTypeAux, typename Schedule::ElementT>()); using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux< - GmemLayoutTagD, Schedule::ActivationFunctor, 
ElementD, ElementCompute, + GmemLayoutTagD, Schedule::template ActivationFunctor, ElementD, ElementCompute, typename Schedule::ElementT, typename Schedule::ElementBias, ElementCompute >; using FusionCallbacksAux = fusion::FusionCallbacks< @@ -684,7 +684,7 @@ private: >; using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct< - Schedule::ActivationFunctor, ElementD, ElementCompute, + Schedule::template ActivationFunctor, ElementD, ElementCompute, typename Schedule::ElementBias, ElementCompute >; using FusionCallbacksNoAux = fusion::FusionCallbacks< diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index fbfde723..99286cec 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -81,8 +81,8 @@ public: static const int kOutputAlignment = ThreadEpilogueOp::kCount; using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; - static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); struct SharedStorage { }; @@ -163,10 +163,10 @@ public: using namespace cute; using X = Underscore; - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); static_assert(is_static::value, "ThreadBlock tile shape must be static"); - static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); // Separate out problem shape for convenience auto M = get<0>(problem_shape_mnkl); diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index d871dc23..e6106343 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -204,12 +204,12 @@ public: int thread_idx, TensorStorage& shared_tensors) { - constexpr int BLK_M_RANK = rank<0>(tile_shape_MNK); + constexpr int BLK_M_RANK = cute::rank<0>(tile_shape_MNK); auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl); })); - constexpr int BLK_N_RANK = rank<1>(tile_shape_MNK); + constexpr int BLK_N_RANK = cute::rank<1>(tile_shape_MNK); auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl); })); diff --git a/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp index b9001b12..accc6d9d 100644 --- a/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp +++ b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp @@ -91,8 +91,8 @@ public: using StrideD = StrideD_; using ActivationFunctor = typename ThreadEpilogueOp::ActivationFunctor; - static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(StrideD{}) == 3, "StrideCD must be 
rank-3: [M, N, L]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static constexpr int kOutputAlignment = ThreadEpilogueOp::kCount; using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; @@ -182,10 +182,10 @@ public: using namespace cute; using X = Underscore; - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); static_assert(is_static::value, "ThreadBlock tile shape must be static"); - static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4"); // Separate out problem shape for convenience auto M = get<0>(problem_shape_mnkl); diff --git a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp index 9e91f834..8f4c11f8 100644 --- a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp @@ -87,8 +87,8 @@ public: static const int kOutputAlignment = ThreadEpilogueOp::kCount; using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; - static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); struct SharedStorage { @@ -172,10 +172,10 @@ public: using namespace cute; using X = Underscore; - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); static_assert(is_static::value, "ThreadBlock tile shape must be static"); - static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); // synchronizing function for smem reads/writes #if CUDA_BARRIER_ENABLED diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index 27b5f37b..dcb6a07a 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -113,12 +113,12 @@ public: using GmemTiledCopyD = SM90_TMA_STORE; static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); - static_assert(rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); - static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); 
static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); - static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); - static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); private: using SmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages @@ -340,10 +340,10 @@ public: auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; // Tile residue - auto m_max_coord = unwrap(cute::transform(make_seq(tile_shape_MNK)>{}, [&](auto i) { + auto m_max_coord = unwrap(cute::transform(make_seq(tile_shape_MNK)>{}, [&](auto i) { return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl); })); - auto n_max_coord = unwrap(cute::transform(make_seq(tile_shape_MNK)>{}, [&](auto i) { + auto n_max_coord = unwrap(cute::transform(make_seq(tile_shape_MNK)>{}, [&](auto i) { return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl); })); auto residue_mn = make_coord(m_max_coord, n_max_coord); @@ -456,11 +456,11 @@ public: using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; static_assert(is_rmem::value, "Accumulator must be RF resident."); - static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(cute::rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); static_assert(is_static::value, "TileShapeMNK must be static"); - static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); - static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + static_assert(cute::rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); + static_assert(cute::rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); // Indexing variables auto [M, N, K, L] = problem_shape_mnkl; @@ -530,11 +530,11 @@ public: Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) // Coordinate tensors and residue for tile quantization - auto m_max_coord = unwrap(cute::transform(make_seq(CtaTileMNK{})>{}, [&](auto i) { + auto m_max_coord = unwrap(cute::transform(make_seq(CtaTileMNK{})>{}, [&](auto i) { auto c_m = get<0,i>(problem_shape_mnkl) - get<0,i>(CtaTileMNK{}) * get<0,i>(tile_coord_mnkl); return cute::max(0, c_m); })); - auto n_max_coord = unwrap(cute::transform(make_seq(CtaTileMNK{})>{}, [&](auto i) { + auto n_max_coord = unwrap(cute::transform(make_seq(CtaTileMNK{})>{}, [&](auto i) { auto c_n = get<1,i>(problem_shape_mnkl) - get<1,i>(CtaTileMNK{}) * get<1,i>(tile_coord_mnkl); return cute::max(0, c_n); })); @@ -559,7 +559,7 @@ public: tRS_cD, tRS_rC }; - auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks(cst_args); + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp 
b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index a3767b78..da392ce1 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -695,7 +695,7 @@ template< FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > using Sm90ScaledLinCombPerRowBiasEltAct = - Sm90EVT::Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d Sm90EVT, // activation(Z) // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias Sm90ScaledLinCombPerRowBias @@ -829,7 +829,7 @@ using Sm90ScaledLinCombPerRowBiasEltActAmaxAux = // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias Sm90ScaledLinCombPerRowBias, // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) - Sm90EVT::Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d Sm90EVT, // amax_d Sm90EVT, // activation(Z) Sm90SplitTreeFetch // Z @@ -839,7 +839,7 @@ using Sm90ScaledLinCombPerRowBiasEltActAmaxAux = >, // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) Sm90EVT, // store(Aux) - Sm90EVT::Op, ElementCompute, ElementCompute, RoundStyle>, // Z * scale_aux + Sm90EVT::template Op, ElementCompute, ElementCompute, RoundStyle>, // Z * scale_aux Sm90EVT, // amax_aux Sm90SplitTreeFetch // Z >, @@ -1021,7 +1021,7 @@ template< using Sm90LinCombDeEltAct = Sm90EVT, // activation(beta * C + (alpha * acc), aux) Sm90LinearCombination, // beta * C + (alpha * acc) - Sm90AuxLoad, // aux + Sm90AuxLoad // aux >; template < diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 687cb293..e5332fc1 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -237,6 +237,18 @@ struct Sm90TreeVisitor< Sm90Compute >; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor() {} + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor( + Params const& params, + SharedStorage const& shared_storage) + : Impl(params, shared_storage) {} + CUTLASS_DEVICE bool is_producer_load_needed() const { auto const& bcast_op = get<0>(Impl::ops); @@ -252,8 +264,6 @@ struct Sm90TreeVisitor< return bcast_op.scalar != 0 || added_op.is_C_load_needed(); } - using Impl::Sm90VisitorImpl; - template struct ConsumerStoreCallbacks : CallbacksImpl { CUTLASS_DEVICE @@ -301,10 +311,9 @@ struct Sm90TreeVisitor< > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks( - is_C_load_needed(), - Impl::get_consumer_store_callbacks(args) - ); + auto callbacks_tuple = Impl::template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks( + is_C_load_needed(), std::move(callbacks_tuple)); } }; @@ -475,7 +484,8 @@ struct Sm90ReLUAuxStore { gAux, args.epi_tile, args.tiled_copy, args.thread_idx); Tensor tC_rAux = make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - return ConsumerStoreCallbacks(cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params); + return ConsumerStoreCallbacks( + cute::move(tC_rAux), 
cute::move(tC_gAux), args.tCcD, args.residue_mn, params); } }; } // namespace detail @@ -532,7 +542,17 @@ struct Sm90TreeVisitor< Sm90Compute >; - using Impl::Sm90VisitorImpl; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor() {} + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor( + Params const& params, + SharedStorage const& shared_storage) + : Impl(params, shared_storage) {} template struct ConsumerStoreCallbacks : CallbacksImpl { @@ -556,9 +576,8 @@ struct Sm90TreeVisitor< > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks( - Impl::get_consumer_store_callbacks(args) - ); + auto callbacks_tuple = Impl::template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks(std::move(callbacks_tuple)); } }; @@ -654,7 +673,7 @@ struct Sm90AuxLoad< CUTLASS_DEVICE void begin() { - if constexpr (decltype(rank(tC_rAux))::value == 5) { + if constexpr (decltype(cute::rank(tC_rAux))::value == 5) { if constexpr (EnableNullptr) { if (params.ptr_aux == nullptr) { return; @@ -669,7 +688,7 @@ struct Sm90AuxLoad< CUTLASS_DEVICE void previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { - if constexpr (decltype(rank(tC_rAux))::value == 3) { + if constexpr (decltype(cute::rank(tC_rAux))::value == 3) { if constexpr (EnableNullptr) { if (params.ptr_aux == nullptr) { return; @@ -686,7 +705,7 @@ struct Sm90AuxLoad< CUTLASS_DEVICE auto visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { using ElementRegister = typename remove_cvref_t::value_type; - if constexpr (decltype(rank(tC_rAux))::value == 3) { + if constexpr (decltype(cute::rank(tC_rAux))::value == 3) { return recast>(coalesce(tC_rAux))(epi_v); } else { @@ -727,7 +746,8 @@ struct Sm90AuxLoad< } } - return ConsumerStoreCallbacks(cute::move(tC_rAux), cute::move(tC_gAux), args.residue_mn, params); + return ConsumerStoreCallbacks( + cute::move(tC_rAux), cute::move(tC_gAux), args.residue_mn, params); } }; diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index b60dc2c7..df3b988b 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -280,7 +280,8 @@ struct Sm90AuxLoad { Tensor bGS_gAux = thrblk_g2s.partition_S(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) Tensor bGS_sAux = thrblk_g2s.partition_D(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) - return ProducerLoadCallbacks(cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr); + return ProducerLoadCallbacks( + cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr); } template @@ -344,7 +345,8 @@ struct Sm90AuxLoad { auto tSR_sAux = tiled_s2r.get_slice(args.thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE) - return ConsumerStoreCallbacks(cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr); + return ConsumerStoreCallbacks( + cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr); } }; diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index 2330c30b..13780f3e 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -268,7 +268,7 @@ struct 
Sm90AuxStore { Tensor bSG_sAux = thrblk_s2g.partition_S(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) Tensor bSG_gAux = thrblk_s2g.partition_D(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) - return ConsumerStoreCallbacks( + return ConsumerStoreCallbacks( cute::move(tC_rAux), tiled_r2s, cute::move(tRS_sAux), @@ -1109,12 +1109,11 @@ public: Tensor gBuf_nl = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(m,_,_)); // (CTA_M,CTA_N,REST_N,L) Layout sBuf_layout = blocked_product(gBuf_layout,make_layout(make_shape(_1{},_1{},size<1>(warp_layout_MN)))); // (CTA_M,CTA_N,WARPS_N) - return ConsumerStoreCallbacks( - make_tuple(bool_constant{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, - args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx), - params - ); + auto args_tuple = make_tuple( + bool_constant{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx); + return ConsumerStoreCallbacks(std::move(args_tuple), params); } }; diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp index bca7d1ce..30ac4761 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -272,8 +272,18 @@ struct Sm90VisitorImplBase { template struct Sm90VisitorImpl : Sm90VisitorImplBase { - using Sm90VisitorImplBase::Sm90VisitorImplBase; - using Sm90VisitorImplBase::ops; + using Impl = Sm90VisitorImplBase; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90VisitorImpl() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImpl(Params const& params, SharedStorage const& shared_storage) + : Impl(params, shared_storage) {} + + using Impl::ops; // // Queries for kernel runtime @@ -506,7 +516,18 @@ using namespace detail; template struct Sm90TreeVisitor : Sm90VisitorImpl { - using Sm90VisitorImpl::Sm90VisitorImpl; + using Impl = Sm90VisitorImpl; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor() {} + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor( + Params const& params, + SharedStorage const& shared_storage) + : Impl(params, shared_storage) {} template struct ConsumerStoreCallbacks : CallbacksImpl { @@ -538,10 +559,9 @@ struct Sm90TreeVisitor : Sm90VisitorImpl { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks( - Sm90VisitorImpl:: - get_consumer_store_callbacks(args) - ); + auto callbacks_tuple = Sm90VisitorImpl:: + template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks(std::move(callbacks_tuple)); } }; @@ -590,10 +610,9 @@ struct Sm90SplitTreeVisitor : Sm90VisitorImpl CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks( - Sm90VisitorImpl:: - get_consumer_store_callbacks(args) - ); + auto callbacks_tuple = Sm90VisitorImpl:: + template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks(std::move(callbacks_tuple)); } }; @@ -609,7 +628,7 @@ template< > struct Sm90TopologicalVisitor : Sm90VisitorImpl { static_assert(is_static_v); - static_assert(rank(EdgeTuple{}) == 
sizeof...(Ops)); + static_assert(cute::rank(EdgeTuple{}) == sizeof...(Ops)); static_assert(sizeof...(Ops) > 1); using Sm90VisitorImpl::Sm90VisitorImpl; @@ -669,10 +688,9 @@ struct Sm90TopologicalVisitor : Sm90VisitorImpl { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks( - Sm90VisitorImpl:: - get_consumer_store_callbacks(args) - ); + auto callbacks_tuple = Sm90VisitorImpl:: + template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks(std::move(callbacks_tuple)); } }; diff --git a/include/cutlass/epilogue/threadblock/fusion/visitor_2x.hpp b/include/cutlass/epilogue/threadblock/fusion/visitor_2x.hpp index cb602400..f5da084f 100644 --- a/include/cutlass/epilogue/threadblock/fusion/visitor_2x.hpp +++ b/include/cutlass/epilogue/threadblock/fusion/visitor_2x.hpp @@ -232,7 +232,7 @@ template< > struct TopologicalVisitor2x : VisitorImpl2x { static_assert(is_static_v); - static_assert(rank(EdgeTuple{}) == sizeof...(Ops)); + static_assert(cute::rank(EdgeTuple{}) == sizeof...(Ops)); static_assert(sizeof...(Ops) > 1); using VisitorImpl2x::VisitorImpl2x; diff --git a/include/cutlass/gemm/collective/sm70_mma_twostage.hpp b/include/cutlass/gemm/collective/sm70_mma_twostage.hpp index ffe1ea6d..f2c5ca99 100644 --- a/include/cutlass/gemm/collective/sm70_mma_twostage.hpp +++ b/include/cutlass/gemm/collective/sm70_mma_twostage.hpp @@ -100,11 +100,11 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); @@ -173,9 +173,9 @@ struct CollectiveMma< static_assert(is_gmem::value, "A tensor must be gmem resident."); static_assert(is_gmem::value, "B tensor must be gmem resident."); static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(rank(SmemLayoutA{}) == 2, + static_assert(cute::rank(SmemLayoutA{}) == 2, "MainloopTwoStage must not have a smem shape with a pipeline mode."); - static_assert(rank(SmemLayoutB{}) == 2, + static_assert(cute::rank(SmemLayoutB{}) == 2, "MainloopTwoStage must not have a smem shape with a pipeline mode."); // Construct shared memory tiles @@ -343,11 +343,11 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - 
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); @@ -414,9 +414,9 @@ struct CollectiveMma< static_assert(is_gmem::value, "A tensor must be gmem resident."); static_assert(is_gmem::value, "B tensor must be gmem resident."); static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(rank(SmemLayoutA{}) == 2, + static_assert(cute::rank(SmemLayoutA{}) == 2, "MainloopTwoStage must not have a smem shape with a pipeline mode."); - static_assert(rank(SmemLayoutB{}) == 2, + static_assert(cute::rank(SmemLayoutB{}) == 2, "MainloopTwoStage must not have a smem shape with a pipeline mode."); // Construct shared memory tiles diff --git a/include/cutlass/gemm/collective/sm80_mma_multistage.hpp b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp index dc98823c..be70ac7c 100644 --- a/include/cutlass/gemm/collective/sm80_mma_multistage.hpp +++ b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp @@ -101,11 +101,11 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); @@ -174,9 +174,9 @@ struct CollectiveMma< static_assert(is_gmem::value, "A tensor must be gmem resident."); static_assert(is_gmem::value, "B tensor must be gmem resident."); static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(rank(SmemLayoutA{}) == 3, + static_assert(cute::rank(SmemLayoutA{}) == 3, "MainloopSm80CpAsync must have a pipeline mode in the smem layout."); - static_assert(rank(SmemLayoutB{}) == 3, + static_assert(cute::rank(SmemLayoutB{}) == 3, "MainloopSm80CpAsync must have a pipeline mode in the smem layout."); // Construct shared memory tiles @@ -390,11 +390,11 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + 
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); @@ -463,8 +463,8 @@ struct CollectiveMma< static_assert(is_gmem::value, "A tensor must be gmem resident."); static_assert(is_gmem::value, "B tensor must be gmem resident."); static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); // Construct shared memory tiles SharedStorage& storage = *reinterpret_cast(smem_buf); diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp index 1e1c5e6d..a7cc5ddb 100644 --- a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp @@ -138,11 +138,11 @@ struct CollectiveMma< using PipelineState = typename MainloopPipeline::PipelineState; using PipelineParams = typename MainloopPipeline::Params; - static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); @@ -418,10 +418,10 @@ struct CollectiveMma< { using namespace cute; static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2."); - static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2."); + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2."); static_assert(!cute::is_void_v, "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions."); static_assert(cute::is_void_v, diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp 
b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp index 1b74153f..1312d731 100644 --- a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp @@ -112,11 +112,11 @@ struct CollectiveMma< using PipelineState = typename MainloopPipeline::PipelineState; using PipelineParams = typename MainloopPipeline::Params; - static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); @@ -346,8 +346,8 @@ struct CollectiveMma< using namespace cute; static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); static_assert(cute::is_void_v, "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); static_assert(cute::is_void_v, diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp index 2928192b..e9867351 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp @@ -141,11 +141,11 @@ struct CollectiveMma< using PipelineParams = typename MainloopPipeline::Params; - static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); @@ -402,7 +402,7 @@ struct CollectiveMma< // Prepare the TMA loads for A and B // - constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % 
cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
    Tensor gA_mkl = get<0>(tiled_tensors);
@@ -502,10 +502,10 @@ struct CollectiveMma<
  {
    using namespace cute;
    static_assert(is_rmem::value, "C tensor must be rmem resident.");
-    static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
-    static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
-    static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
-    static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
+    static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
+    static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
+    static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
+    static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
    static_assert(!cute::is_void_v,
      "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions.");
    static_assert(cute::is_void_v,
diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp
index f4295070..4d50793e 100644
--- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp
+++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp
@@ -183,11 +183,11 @@ public:
  using PipelineParams = typename MainloopPipeline::Params;
-  static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+  static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
-  static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+  static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -443,7 +443,7 @@ public:
    // Prepare the TMA loads for A and B
    //
-    constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
+    constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
    uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
    Tensor gA_mkl = get<0>(tiled_tensors);
@@ -541,10 +541,10 @@ public:
      Params const& mainloop_params) {
    using namespace cute;
    static_assert(is_rmem::value, "C tensor must be rmem resident.");
-    static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
-    static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
-    static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
-    static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
+    static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
+    static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
+    static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
+    static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
    static_assert(!cute::is_void_v,
      "SM90 GMMA mainloops must specify a non-void copy atom for RF sourced instructions.");
    static_assert(cute::is_void_v,
diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp
index af14f137..38e6ca1d 100644
--- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp
+++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp
@@ -113,11 +113,11 @@ struct CollectiveMma<
  using PipelineParams = typename MainloopPipeline::Params;
  using PipelineState = typename cutlass::PipelineState;
-  static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+  static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
-  static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+  static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -271,10 +271,10 @@ struct CollectiveMma<
    using namespace cute;
    static_assert(is_rmem::value, "C tensor must be rmem resident.");
-    static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2.");
-    static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2.");
-    static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
-    static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
+    static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2.");
+    static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2.");
+    static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
+    static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
    static_assert(cute::is_void_v,
      "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
    static_assert(cute::is_void_v,
@@ -288,7 +288,7 @@ struct CollectiveMma<
    // Prepare the TMA loads for A and B
    //
-    constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
+    constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
    uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
    auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y);
diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp
index b0656ca4..1dffea46 100644
--- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp
+++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp
@@ -114,11 +114,11 @@ struct CollectiveMma<
  using PipelineParams = typename MainloopPipeline::Params;
-  static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+  static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
-  static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+  static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -319,7 +319,7 @@ struct CollectiveMma<
    // Prepare the TMA loads for A and B
    //
-    constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
+    constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
    uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
    Tensor gA_mkl = get<0>(tiled_tensors);
@@ -423,8 +423,8 @@ struct CollectiveMma<
    using namespace cute;
    static_assert(is_rmem::value, "C tensor must be rmem resident.");
-    static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
-    static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
+    static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
+    static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
    static_assert(cute::is_void_v,
      "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
    static_assert(cute::is_void_v,
diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp
index 7a67b245..09f97f8e 100644
--- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp
+++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp
@@ -115,11 +115,11 @@ struct CollectiveMma<
  using PipelineParams = typename MainloopPipeline::Params;
-  static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+  static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
-  static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+  static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -317,7 +317,7 @@ struct CollectiveMma<
    // Prepare the TMA loads for A and B
    //
-    constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
+    constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
    uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
    Tensor gA_mkl = get<0>(tiled_tensors);
@@ -421,8 +421,8 @@ struct CollectiveMma<
    using namespace cute;
    static_assert(is_rmem::value,
"C tensor must be rmem resident."); - static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); static_assert(cute::is_void_v, "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); static_assert(cute::is_void_v, diff --git a/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_streamk.h index eaa3cfc1..9c35f270 100644 --- a/include/cutlass/gemm/kernel/gemm_universal_streamk.h +++ b/include/cutlass/gemm/kernel/gemm_universal_streamk.h @@ -665,7 +665,7 @@ protected: int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM; int m_end = params.block_mapping.problem_size.m(); - return Mma::IteratorA( + return typename Mma::IteratorA( params.params_A, ptr_A, { m_end, tile_work.k_end }, @@ -694,7 +694,7 @@ protected: int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN; int n_end = params.block_mapping.problem_size.n(); - return Mma::IteratorB( + return typename Mma::IteratorB( params.params_B, ptr_B, { tile_work.k_end, n_end }, diff --git a/include/cutlass/gemm/kernel/sm70_gemm.hpp b/include/cutlass/gemm/kernel/sm70_gemm.hpp index 21393cda..e22fb436 100644 --- a/include/cutlass/gemm/kernel/sm70_gemm.hpp +++ b/include/cutlass/gemm/kernel/sm70_gemm.hpp @@ -60,7 +60,7 @@ public: // using ProblemShape = ProblemShape_; - static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); // Mainloop derived types @@ -142,7 +142,7 @@ public: static bool can_implement(Arguments const& args) { return args.mode == GemmUniversalMode::kGemm or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); } static int @@ -159,7 +159,7 @@ public: static dim3 get_grid_shape(Params const& params) { int batch_count = 1; - if constexpr (rank(ProblemShape{}) == 4) { + if constexpr (cute::rank(ProblemShape{}) == 4) { batch_count = cute::size<3>(params.problem_shape); } @@ -193,10 +193,10 @@ public: auto L = get<3>(problem_shape_MNKL); // Preconditions - static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); // Get the appropriate blocks for this thread block -- potential for thread block locality int thread_idx = int(threadIdx.x); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp index 8091672f..14cb47e1 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp @@ -80,7 +80,7 @@ public: // Type Aliases // using ProblemShape = ProblemShape_; - static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); // Mainloop derived types @@ -169,7 +169,7 @@ public: bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); return implementable; @@ -219,10 +219,10 @@ public: #endif // Preconditions - static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); int thread_idx = int(threadIdx.x); int warp_idx = canonical_warp_idx_sync(); @@ -285,13 +285,13 @@ public: params.mainloop ); - constexpr int BLK_M_RANK = rank<0>(blk_shape); + constexpr int BLK_M_RANK = cute::rank<0>(blk_shape); bool m_oob = int(blockIdx.x) >= size<2>(gA_mkl); auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { return m_oob ? 0 : get(M) - get<0,i>(blk_shape) * get(m_coord); })); - constexpr int BLK_N_RANK = rank<1>(blk_shape); + constexpr int BLK_N_RANK = cute::rank<1>(blk_shape); bool n_oob = int(blockIdx.y) >= size<2>(gB_nkl); auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { return n_oob ? 
0 : get(N) - get<1,i>(blk_shape) * get(n_coord); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp index cb4baf4d..2ec1aa0e 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -69,7 +69,7 @@ public: // Type Aliases // using ProblemShape = ProblemShape_; - static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); // Mainloop derived types @@ -176,7 +176,7 @@ public: bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); return implementable; @@ -318,10 +318,10 @@ public: } (); // Preconditions - static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); @@ -338,7 +338,7 @@ public: // get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) // get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape); - static_assert(tuple_size_v >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); + static_assert(cute::tuple_size_v >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); // Extract out partitioned A and B. 
Tensor gA_mkl = get<0>(tiled_tensors); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 61e94e92..551bb231 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -69,7 +69,7 @@ public: // Type Aliases // using ProblemShape = ProblemShape_; - static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); // Mainloop derived types @@ -219,7 +219,7 @@ public: bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); return implementable; @@ -303,10 +303,10 @@ public: static_assert(size<0>(TileShape{}) >= 128, "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); - static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ enum class WarpGroupRole { @@ -423,7 +423,7 @@ public: // get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) // get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape); - static_assert(tuple_size_v >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); + static_assert(cute::tuple_size_v >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); // Extract out partitioned A and B. 
Tensor gA_mkl = get<0>(tiled_tensors); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index b34e56f2..dc92e931 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -70,7 +70,7 @@ public: // Type Aliases // using ProblemShape = ProblemShape_; - static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); // Mainloop derived types @@ -225,7 +225,7 @@ public: bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); return implementable; @@ -305,10 +305,10 @@ public: #endif // Preconditions - static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); enum class WarpGroupRole { Producer = 0, @@ -427,7 +427,7 @@ public: // get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) // get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape); - static_assert(tuple_size_v >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); + static_assert(cute::tuple_size_v >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); // Extract out partitioned A and B. 
Tensor gA_mkl = get<0>(tiled_tensors); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp index bdcfa4ef..4c8901bc 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp @@ -67,7 +67,7 @@ public: // Type Aliases // using ProblemShape = ProblemShape_; - static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); // Mainloop derived types @@ -180,7 +180,7 @@ public: bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); return implementable; @@ -289,10 +289,10 @@ public: PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); // Preconditions - static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp index 7a8c5fa7..d1c6dc84 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp @@ -67,7 +67,7 @@ public: // Type Aliases // using ProblemShape = ProblemShape_; - static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); // Mainloop derived types @@ -200,7 +200,7 @@ public: bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); return implementable; @@ -256,10 +256,10 @@ public: } #endif - static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); /* In the Cooperative kernel, one or multiple Consumers collaborate on the same tile */ enum class WarpGroupRole { diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp index f43ff562..5a6571d8 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp @@ -69,7 +69,7 @@ public: // Type Aliases // using ProblemShape = ProblemShape_; - static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); // Mainloop derived types @@ -212,7 +212,7 @@ public: bool can_implement(Arguments const& args) { bool implementable = (args.mode == GemmUniversalMode::kGemm) or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); return implementable; @@ -265,10 +265,10 @@ public: #endif // Preconditions - static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); enum class WarpGroupRole { Producer = 0, diff --git a/include/cutlass/gemm_coord.hpp b/include/cutlass/gemm_coord.hpp index 390de0c5..a0d2babe 100644 --- a/include/cutlass/gemm_coord.hpp +++ b/include/cutlass/gemm_coord.hpp @@ -35,6 +35,7 @@ #pragma once +#include "cute/layout.hpp" #include "cutlass/gemm_coord.h" namespace cutlass { diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 6de5d038..4d2faab0 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -192,7 +192,7 @@ struct NumericConverter { return static_cast(intermediate); } - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index 68953804..86e146a5 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -193,8 +193,8 @@ struct TestbedImpl { using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; - static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static constexpr uint32_t mma_promotion_interval = 4; @@ -523,9 +523,6 @@ struct TestbedImpl { Gemm& gemm_op, typename Gemm::Arguments& arguments, cutlass::device_memory::allocation& workspace) { - int M = cute::size<0>(problem_size); - int N = cute::size<1>(problem_size); - int K = cute::size<2>(problem_size); int L = 1; if constexpr(cute::rank(ProblemShapeType{}) == 4) { L = cute::size<3>(problem_size); @@ -581,7 +578,7 @@ struct TestbedImpl { cutlass::KernelHardwareInfo hw_info; hw_info.device_id = 0; if (not profiling) { - this->sm_count = min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + this->sm_count = std::min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); hw_info.sm_count = this->sm_count; } else { @@ -1240,7 +1237,7 @@ struct Testbed3xFusionOperation { hw_info.device_id = 0; if (not profiling) { - impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); hw_info.sm_count = impl_.sm_count; } else { diff --git a/test/unit/gemm/device/gemm_testbed_3x_evt.hpp b/test/unit/gemm/device/gemm_testbed_3x_evt.hpp index 9127d407..1a21840d 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_evt.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_evt.hpp @@ -173,7 +173,7 @@ public: HostScalarBroadcast(){} template HostScalarBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) - :_scalar(ElementCompute(Value)), Base(check_relative_equality) {} + : Base(check_relative_equality), _scalar(ElementCompute(Value)) {} template ElementCompute visit( @@ -232,7 +232,7 @@ public: HostRowBroadcast(){} template HostRowBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) - :impl_(impl), Base(check_relative_equality) { + : Base(check_relative_equality), impl_(impl) { auto 
problem_shape_MNKL = cute::append<4>(problem_size, 1); _N = cute::get<1>(problem_shape_MNKL); _bias.resize(cutlass::Coord<1>(_N)); @@ -300,7 +300,7 @@ public: HostColBroadcast(){} template HostColBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) - :impl_(impl), Base(check_relative_equality) { + : Base(check_relative_equality), impl_(impl) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); _M = cute::get<0>(problem_shape_MNKL); _bias.resize(cutlass::Coord<1>(_M)); @@ -382,7 +382,7 @@ public: HostAuxLoad(){} template HostAuxLoad(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) - :impl_(impl), Base(check_relative_equality){ + : Base(check_relative_equality), impl_(impl){ auto problem_shape_NMKL = cute::append<4>(problem_size, 1); auto [_M, _N, K, _L] = problem_shape_NMKL; auto aux_coord = cutlass::make_Coord(_M * _L, _N); @@ -513,8 +513,8 @@ public: HostUnaryCompute(){} template HostUnaryCompute(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): - _child_0(problem_size, impl, check_relative_equality), - Base(check_relative_equality) { } + Base(check_relative_equality), + _child_0(problem_size, impl, check_relative_equality) { } template ElementCompute visit( @@ -578,8 +578,8 @@ public: HostAuxStore(){} template HostAuxStore(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): - impl_(impl), - Base(check_relative_equality) { + Base(check_relative_equality), + impl_(impl) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [_M, _N, K, _L] = problem_shape_MNKL; auto aux_coord = cutlass::make_Coord(_M * _L, _N); @@ -677,8 +677,8 @@ public: HostRowReduce(){} template HostRowReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): - impl_(impl), - Base(check_relative_equality) { + Base(check_relative_equality), + impl_(impl) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); _N = cute::get<1>(problem_shape_MNKL); _tensor_row_reduce.resize(cutlass::Coord<1>(_N)); @@ -764,8 +764,8 @@ public: HostColumnReduce(){} template HostColumnReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): - impl_(impl), - Base(check_relative_equality) { + Base(check_relative_equality), + impl_(impl) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); _M = cute::get<0>(problem_shape_MNKL); _tensor_column_reduce.resize(cutlass::Coord<1>(_M)); @@ -850,9 +850,8 @@ public: HostScalarReduce(){} template HostScalarReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): - impl_(impl), - Base(check_relative_equality) { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + Base(check_relative_equality), + impl_(impl) { _tensor_scalar_reduce.resize(cutlass::Coord<1>(1)); _reference_scalar_reduce.resize(cutlass::Coord<1>(1)); _reduce_buffer.resize(cutlass::Coord<1>(1)); @@ -1229,7 +1228,6 @@ public: auto N = cute::get<1>(problem_shape_MNKL); auto K = cute::get<2>(problem_shape_MNKL); auto L = cute::get<3>(problem_shape_MNKL); - auto coord_0 = cutlass::make_Coord(0); auto A = cute::make_tensor(impl_.tensor_A.host_data(), cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a)); @@ -1307,7 +1305,7 @@ public: cutlass::KernelHardwareInfo hw_info; hw_info.device_id = 0; if (not profiling) { - impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + 
impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); hw_info.sm_count = impl_.sm_count; } else { diff --git a/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp b/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp index 145f8747..ec7fbea7 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp @@ -158,7 +158,6 @@ struct Testbed3xTensorBroadcast { bool use_bias) { auto [M, N, K, L] = problem_shape_MNKL; - auto coord_0 = cutlass::make_Coord(0); impl_.tensor_D.sync_host(); EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_A.host_view()), 0); @@ -218,7 +217,6 @@ struct Testbed3xTensorBroadcast { auto N = cute::get<1>(problem_shape_MNKL); auto K = cute::get<2>(problem_shape_MNKL); auto L = cute::get<3>(problem_shape_MNKL); - auto coord_0 = cutlass::make_Coord(0); auto A = cute::make_tensor(impl_.tensor_A.host_data(), cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a)); @@ -338,7 +336,7 @@ struct Testbed3xTensorBroadcast { cutlass::KernelHardwareInfo hw_info; hw_info.device_id = 0; if (not profiling) { - impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); hw_info.sm_count = impl_.sm_count; } else { diff --git a/test/unit/gemm/device/sm90_evt_operations.hpp b/test/unit/gemm/device/sm90_evt_operations.hpp index b7e7231b..71c6f2bb 100644 --- a/test/unit/gemm/device/sm90_evt_operations.hpp +++ b/test/unit/gemm/device/sm90_evt_operations.hpp @@ -163,7 +163,7 @@ public: using EVTModule = HEVT< HostAuxStore, HEVT< - HostCompute::Op>, // activation(Z) * scaled_d + HostCompute::template Op>, // activation(Z) * scaled_d HEVT< HostCompute, // activation(Z) HEVT< @@ -174,11 +174,11 @@ public: HostCompute, HostScalarBroadcast, // scale_a * scale_b * alpha HostAccumulator, - HostColBroadcast, + HostColBroadcast > > >, - HostScalarBroadcast, // scale_d + HostScalarBroadcast // scale_d > >; }; @@ -211,26 +211,26 @@ public: HostCompute, HostScalarBroadcast, // scale_a * scale_b * alpha HostAccumulator, - HostColBroadcast, + HostColBroadcast > >, // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D)) HEVT< - HostCompute::Op>, + HostCompute::template Op>, HEVT< HostScalarReduce, HEVT< HostCompute, //activation(Z) * scaled_d - HostAccumulator, // Z + HostAccumulator // Z > >, - HostScalarBroadcast, // scale_d + HostScalarBroadcast // scale_d >, // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) HEVT< HostAuxStore, HEVT< - HostCompute::Op>, + HostCompute::template Op>, HEVT< HostScalarReduce, HostAccumulator diff --git a/tools/util/include/cutlass/util/reference/device/gett.hpp b/tools/util/include/cutlass/util/reference/device/gett.hpp index 84b7037e..dc247ecb 100644 --- a/tools/util/include/cutlass/util/reference/device/gett.hpp +++ b/tools/util/include/cutlass/util/reference/device/gett.hpp @@ -126,7 +126,7 @@ gett( cudaStream_t stream = 0) { using namespace cute; - static_assert(rank(ProblemShapeMNKL{}) == 4); + static_assert(cute::rank(ProblemShapeMNKL{}) == 4); auto M = get<0>(problem_shape_mnkl); auto N = get<1>(problem_shape_mnkl); auto K = get<2>(problem_shape_mnkl); diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp index 
60a22814..5e75d08b 100644 --- a/tools/util/include/cutlass/util/reference/host/gett.hpp +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -431,11 +431,11 @@ void Gemm3x( { using namespace cute; - static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename MainloopParams::LayoutB{})); - static_assert(rank(typename EpilogueParams::LayoutC{}) == rank(typename EpilogueParams::LayoutD{})); - static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename EpilogueParams::LayoutC{})); + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); - if constexpr (rank(typename MainloopParams::LayoutA{}) == 2) { + if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) { Layout layout_A = make_layout_rank3(mainloop_params.A); Layout layout_B = make_layout_rank3(mainloop_params.B); Layout layout_C = make_layout_rank3(epilogue_params.C);