Collection of changes to fix clang build. (#1200)

* Remove unused variables

* Qualify calls to make_fragment_? from the templated base class (see the lookup sketch below).

Fixes clang build error.

* Add missing `#include <cstdio>`

* Various changes to fix clang compile errors (one recurring case, the `template` disambiguator, is sketched below).

* More changes to fix clang build.

Remaining issues:

- `params` initializer of `CollectiveEpilogue`.
- `ops` initializer of `Sm90VisitorImplBase`.
- `__usAtomicCAS` needs to be added to clang upstream.

* Fix remaining clang build issues.

* Qualify `cute::rank()` calls.

* Qualify some more calls that are otherwise ambiguous between the `cute` and `std` namespaces (see the ambiguity sketch below).

* Double-escape special registers in inline asm (see the asm sketch below).

* small change
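For context on the base-class qualification: clang applies two-phase name lookup strictly, so an unqualified call to a member of a dependent base class is not found at template definition time. A minimal sketch with hypothetical names (not the actual CUTLASS classes):

    // Minimal sketch (hypothetical names): members of a dependent base are not
    // found by unqualified lookup, so calls must go through this-> or the base name.
    template <class T>
    struct FragmentBase {
      int make_fragment_x() const { return 0; }
    };

    template <class T>
    struct Collective : FragmentBase<T> {
      int load() const {
        // return make_fragment_x();     // nvcc may accept this; clang rejects it
        return this->make_fragment_x();  // OK under both: lookup deferred to instantiation
      }
    };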
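Among the various clang fixes, one recurring pattern in the diffs below is the `template` disambiguator: when a member template is named through a dependent type or object, clang requires the `template` keyword before the member name. A minimal sketch with hypothetical names:

    // Minimal sketch (hypothetical names): naming a member template of a dependent
    // object requires the 'template' keyword, e.g. obj.template get<...>() or
    // Schedule::template ActivationFunctor.
    template <class T>
    struct Callbacks {
      template <bool Flag>
      int get() const { return Flag ? 1 : 0; }
    };

    template <class T>
    int use(Callbacks<T> const& cb) {
      // return cb.get<true>();          // parsed as (cb.get < true) > ...; clang rejects
      return cb.template get<true>();    // OK: 'template' marks get as a member template
    }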
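For context on the `cute::rank()` qualification: with both namespaces visible (e.g. `using namespace cute;` plus headers such as `<type_traits>`, which declares the `std::rank` trait), the unqualified name `rank` can be ambiguous for clang. A minimal sketch, using a simplified stand-in for `cute::rank`:

    // Minimal sketch (simplified stand-in for cute::rank): once both namespaces are
    // pulled in, the unqualified name 'rank' refers to std::rank (a trait) and
    // cute::rank (a function), so clang reports an ambiguity; qualifying resolves it.
    #include <type_traits>

    namespace cute {
      template <class T> constexpr int rank(T const&) { return 1; }
    }

    using namespace std;
    using namespace cute;

    // constexpr int r_bad = rank(0);    // ambiguous: std::rank vs cute::rank
    constexpr int r_ok = cute::rank(0);  // unambiguous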
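For context on the inline-asm change: in GCC/clang extended asm, `%` introduces an operand reference (such as `%0`), so a literal percent sign, as in PTX special registers like `%nclusterid.x`, must be written as `%%` whenever the asm statement has operands. A minimal device-side sketch (hypothetical helper name):

    // Minimal sketch (hypothetical helper name): the doubled %% emits a literal '%'
    // in the generated PTX, i.e. "mov.u32 %r, %nclusterid.x;".
    __device__ unsigned int ncluster_x() {
      unsigned int x;
      asm volatile("mov.u32 %0, %%nclusterid.x;\n" : "=r"(x));
      return x;
    }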

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Authored by Christian Sigg on 2023-12-08 20:42:12 +01:00; committed by GitHub
parent f4a0216601
commit e1483d5fa0
46 changed files with 308 additions and 273 deletions


@@ -186,15 +186,15 @@ main(int argc, char const* argv[]) {
 using ElementEpilogue = float;
 // The following constexpr values set the max number of modes in each MNKL mode
-constexpr int MaxRank_M = rank(RowModeStridesA{}); // Max row modes
+constexpr int MaxRank_M = cute::rank(RowModeStridesA{}); // Max row modes
-constexpr int MaxRank_N = rank(ColModeStridesB{}); // Max column modes
+constexpr int MaxRank_N = cute::rank(ColModeStridesB{}); // Max column modes
-constexpr int MaxRank_K = rank(RedModeStridesA{}); // Max contraction modes
+constexpr int MaxRank_K = cute::rank(RedModeStridesA{}); // Max contraction modes
-constexpr int MaxRank_L = rank(BatModeStridesA{}); // Max batch modes
+constexpr int MaxRank_L = cute::rank(BatModeStridesA{}); // Max batch modes
-static_assert(rank(RowModeStridesA{}) == rank(RowModeStridesC{}));
+static_assert(cute::rank(RowModeStridesA{}) == cute::rank(RowModeStridesC{}));
-static_assert(rank(ColModeStridesB{}) == rank(RowModeStridesC{}));
+static_assert(cute::rank(ColModeStridesB{}) == cute::rank(RowModeStridesC{}));
-static_assert(rank(RedModeStridesA{}) == rank(RedModeStridesB{}));
+static_assert(cute::rank(RedModeStridesA{}) == cute::rank(RedModeStridesB{}));
-static_assert(rank(BatModeStridesA{}) == rank(BatModeStridesC{}));
+static_assert(cute::rank(BatModeStridesA{}) == cute::rank(BatModeStridesC{}));
-static_assert(rank(BatModeStridesB{}) == rank(BatModeStridesC{}));
+static_assert(cute::rank(BatModeStridesB{}) == cute::rank(BatModeStridesC{}));
 // Parse command line to get modes, extents, and strides
 cutlass::GettCommandLine cmd;


@@ -58,7 +58,7 @@ public:
 // Type Aliases
 //
 using ProblemShape = ProblemShape_;
-static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
+static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
 "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
 // Mainloop derived types
@@ -180,7 +180,7 @@ public:
 bool
 can_implement(Arguments const& args) {
 bool implementable = (args.mode == GemmUniversalMode::kGemm) or
-(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
+(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
 if (!implementable) {
 CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
 return implementable;
@@ -288,10 +288,10 @@ public:
 PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
 // Preconditions
-static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
+static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
-static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
+static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
-static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
+static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
-static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
+static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
 // Separate out problem shape for convenience
 // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)


@@ -86,8 +86,8 @@ public:
 static const int kOutputAlignment = ThreadEpilogueOp::kCount;
 using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
-static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
 struct SharedStorage { };
@@ -151,10 +151,10 @@ public:
 using namespace cute;
 using X = Underscore;
-static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
 static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
-static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
+static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
-static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
+static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
 (void) smem_buf;
 ThreadEpilogueOp epilogue_op{params.thread_params};


@@ -197,14 +197,14 @@ template<class ... Shapes>
 auto
 select_mode_shape(Shapes const & ... shapes) {
 auto permuted_shapes = filter_tuple(cute::make_tuple(shapes...), [](auto shape) {
-if constexpr (rank(shape) > 1) {
+if constexpr (cute::rank(shape) > 1) {
 return cute::make_tuple(shape);
 }
 else {
 return cute::make_tuple();
 }
 });
-if constexpr (rank(permuted_shapes) == 0) {
+if constexpr (cute::rank(permuted_shapes) == 0) {
 return get<0>(cute::make_tuple(shapes...));
 }
 else {
@@ -251,7 +251,7 @@ auto
 select_tile_shape(TileSize size, Shape const& shape)
 {
 static_assert(is_static<TileSize>::value, "Tile size must be static");
-if constexpr (rank(Shape{}) == 0) {
+if constexpr (cute::rank(Shape{}) == 0) {
 return cute::make_tuple(size);
 }
 else {


@@ -78,7 +78,7 @@ reshape(Shape const& shape, TargetShape const& target_shape)
 template<class Permute, bool Transpose, class Shape, class Stride>
 constexpr auto
 make_permute_layout(Layout<Shape,Stride> const& layout) {
-static_assert(rank(Shape{}) == 3, "Only rank-3 layouts are supported");
+static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
 if constexpr (Transpose) {
 // Deal with tensor B by transposing appropriately before and after computing the permute layout.
 // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
@@ -135,7 +135,7 @@ using inverse_t = decltype(inverse(T{}));
 template<class Permute, bool Transpose, class Shape, class Stride>
 constexpr auto
 make_original_layout(Layout<Shape,Stride> const& layout) {
-static_assert(rank(Shape{}) == 3, "Only rank-3 layouts are supported");
+static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
 if constexpr (Transpose) {
 // Deal with tensor B by transposing appropriately before and after computing the permute layout.
 // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].


@@ -86,9 +86,9 @@ CUTE_DEVICE dim3 cluster_grid_dims()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
 uint32_t x, y, z;
-asm volatile("mov.u32 %0, %nclusterid.x;\n" : "=r"(x) : );
+asm volatile("mov.u32 %0, %%nclusterid.x;\n" : "=r"(x) : );
-asm volatile("mov.u32 %0, %nclusterid.y;\n" : "=r"(y) : );
+asm volatile("mov.u32 %0, %%nclusterid.y;\n" : "=r"(y) : );
-asm volatile("mov.u32 %0, %nclusterid.z;\n" : "=r"(z) : );
+asm volatile("mov.u32 %0, %%nclusterid.z;\n" : "=r"(z) : );
 return {x, y, z};
 #elif defined(__CUDA_ARCH__)
 // MSVC requires protecting use of gridDim with __CUDA_ARCH__.
@@ -105,9 +105,9 @@ CUTE_DEVICE dim3 cluster_id_in_grid()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
 uint32_t x, y, z;
-asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(x) : );
+asm volatile("mov.u32 %0, %%clusterid.x;\n" : "=r"(x) : );
-asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(y) : );
+asm volatile("mov.u32 %0, %%clusterid.y;\n" : "=r"(y) : );
-asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(z) : );
+asm volatile("mov.u32 %0, %%clusterid.z;\n" : "=r"(z) : );
 return {x, y, z};
 #elif defined(__CUDA_ARCH__)
 // MSVC requires protecting use of blockIdx with __CUDA_ARCH__.
@@ -124,9 +124,9 @@ CUTE_DEVICE dim3 block_id_in_cluster()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
 uint32_t x, y, z;
-asm volatile("mov.u32 %0, %cluster_ctaid.x;\n" : "=r"(x) : );
+asm volatile("mov.u32 %0, %%cluster_ctaid.x;\n" : "=r"(x) : );
-asm volatile("mov.u32 %0, %cluster_ctaid.y;\n" : "=r"(y) : );
+asm volatile("mov.u32 %0, %%cluster_ctaid.y;\n" : "=r"(y) : );
-asm volatile("mov.u32 %0, %cluster_ctaid.z;\n" : "=r"(z) : );
+asm volatile("mov.u32 %0, %%cluster_ctaid.z;\n" : "=r"(z) : );
 return {x, y, z};
 #else
 return {0,0,0};
@@ -138,9 +138,9 @@ CUTE_DEVICE dim3 cluster_shape()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
 uint32_t x, y, z;
-asm volatile("mov.u32 %0, %cluster_nctaid.x;\n" : "=r"(x) : );
+asm volatile("mov.u32 %0, %%cluster_nctaid.x;\n" : "=r"(x) : );
-asm volatile("mov.u32 %0, %cluster_nctaid.y;\n" : "=r"(y) : );
+asm volatile("mov.u32 %0, %%cluster_nctaid.y;\n" : "=r"(y) : );
-asm volatile("mov.u32 %0, %cluster_nctaid.z;\n" : "=r"(z) : );
+asm volatile("mov.u32 %0, %%cluster_nctaid.z;\n" : "=r"(z) : );
 return {x, y, z};
 #else
 return {1,1,1};
@@ -152,7 +152,7 @@ CUTLASS_DEVICE uint32_t block_rank_in_cluster()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
 uint32_t rank;
-asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(rank) :);
+asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank) :);
 return rank;
 #else
 return 0;


@@ -30,7 +30,7 @@
 **************************************************************************************************/
 #pragma once
-#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)
+#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
 # define CUTE_HOST_DEVICE __forceinline__ __host__ __device__
 # define CUTE_DEVICE __forceinline__ __device__
 # define CUTE_HOST __forceinline__ __host__
@@ -46,10 +46,11 @@
 # define CUTE_HOST_RTC CUTE_HOST
 #endif
-#if !defined(__CUDACC_RTC__) && (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA))
+#if !defined(__CUDACC_RTC__) && !defined(__clang__) && \
+(defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA))
 # define CUTE_UNROLL #pragma unroll
 # define CUTE_NO_UNROLL #pragma unroll 1
-#elif defined(__CUDACC_RTC__)
+#elif defined(__CUDACC_RTC__) || defined(__clang__)
 # define CUTE_UNROLL _Pragma("unroll")
 # define CUTE_NO_UNROLL _Pragma("unroll 1")
 #else


@@ -35,6 +35,7 @@
 #pragma once
+#include <cstdio>
 #include <cuda_runtime_api.h>
 #include "cutlass/cutlass.h"
 #include "cutlass/trace.h"


@@ -595,7 +595,7 @@ CollectiveBuilder<
 cute::is_base_of_v<TmaWarpSpecializedCooperativeElementwiseBase, Schedule> >> {
 private:
 using FusionOp =
-fusion::LinCombEltAct<Schedule::ActivationFunctor, ElementD, ElementCompute, ElementCompute, Schedule::Round>;
+fusion::LinCombEltAct<Schedule::template ActivationFunctor, ElementD, ElementCompute, ElementCompute, Schedule::Round>;
 using ImplSchedule =
 cute::conditional_t<cute::is_base_of_v<TmaWarpSpecializedElementwiseBase, Schedule>,
 TmaWarpSpecialized, TmaWarpSpecializedCooperative>;
@@ -676,7 +676,7 @@ private:
 using SmemCopyOpAux = decltype(detail::sm90_get_smem_store_op_for_accumulator<
 GmemStrideTypeAux, typename Schedule::ElementT>());
 using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux<
-GmemLayoutTagD, Schedule::ActivationFunctor, ElementD, ElementCompute,
+GmemLayoutTagD, Schedule::template ActivationFunctor, ElementD, ElementCompute,
 typename Schedule::ElementT, typename Schedule::ElementBias, ElementCompute
 >;
 using FusionCallbacksAux = fusion::FusionCallbacks<
@@ -684,7 +684,7 @@ private:
 >;
 using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct<
-Schedule::ActivationFunctor, ElementD, ElementCompute,
+Schedule::template ActivationFunctor, ElementD, ElementCompute,
 typename Schedule::ElementBias, ElementCompute
 >;
 using FusionCallbacksNoAux = fusion::FusionCallbacks<


@@ -81,8 +81,8 @@ public:
 static const int kOutputAlignment = ThreadEpilogueOp::kCount;
 using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
-static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
 struct SharedStorage { };
@@ -163,10 +163,10 @@ public:
 using namespace cute;
 using X = Underscore;
-static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
 static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
-static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
+static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
-static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
+static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
 // Separate out problem shape for convenience
 auto M = get<0>(problem_shape_mnkl);


@@ -204,12 +204,12 @@ public:
 int thread_idx,
 TensorStorage& shared_tensors)
 {
-constexpr int BLK_M_RANK = rank<0>(tile_shape_MNK);
+constexpr int BLK_M_RANK = cute::rank<0>(tile_shape_MNK);
 auto m_max_coord = unwrap(cute::transform(make_seq<BLK_M_RANK>{}, [&](auto i) {
 return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl);
 }));
-constexpr int BLK_N_RANK = rank<1>(tile_shape_MNK);
+constexpr int BLK_N_RANK = cute::rank<1>(tile_shape_MNK);
 auto n_max_coord = unwrap(cute::transform(make_seq<BLK_N_RANK>{}, [&](auto i) {
 return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl);
 }));


@@ -91,8 +91,8 @@ public:
 using StrideD = StrideD_;
 using ActivationFunctor = typename ThreadEpilogueOp::ActivationFunctor;
-static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
 static constexpr int kOutputAlignment = ThreadEpilogueOp::kCount;
 using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
@@ -182,10 +182,10 @@ public:
 using namespace cute;
 using X = Underscore;
-static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
 static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
-static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
+static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
-static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
+static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
 // Separate out problem shape for convenience
 auto M = get<0>(problem_shape_mnkl);


@@ -87,8 +87,8 @@ public:
 static const int kOutputAlignment = ThreadEpilogueOp::kCount;
 using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
-static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
 struct SharedStorage
 {
@@ -172,10 +172,10 @@ public:
 using namespace cute;
 using X = Underscore;
-static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
 static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
-static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
+static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
-static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
+static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
 // synchronizing function for smem reads/writes
 #if CUDA_BARRIER_ENABLED


@@ -113,12 +113,12 @@ public:
 using GmemTiledCopyD = SM90_TMA_STORE;
 static_assert(!is_layout<EpilogueTile>::value && is_tuple<EpilogueTile>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
-static_assert(rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
+static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
-static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]");
+static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]");
 static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M");
 static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N");
-static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
-static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");
 private:
 using SmemElementC = cute::conditional_t<cute::is_void_v<ElementC>,ElementD,ElementC>; // prevents void ref breakages
@@ -340,10 +340,10 @@ public:
 auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
 // Tile residue
-auto m_max_coord = unwrap(cute::transform(make_seq<rank<0>(tile_shape_MNK)>{}, [&](auto i) {
+auto m_max_coord = unwrap(cute::transform(make_seq<cute::rank<0>(tile_shape_MNK)>{}, [&](auto i) {
 return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl);
 }));
-auto n_max_coord = unwrap(cute::transform(make_seq<rank<1>(tile_shape_MNK)>{}, [&](auto i) {
+auto n_max_coord = unwrap(cute::transform(make_seq<cute::rank<1>(tile_shape_MNK)>{}, [&](auto i) {
 return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl);
 }));
 auto residue_mn = make_coord(m_max_coord, n_max_coord);
@@ -456,11 +456,11 @@ public:
 using ElementCompute = cute::conditional_t<cute::is_void_v<ElementCompute_>,ElementAccumulator,ElementCompute_>;
 static_assert(is_rmem<AccEngine>::value, "Accumulator must be RF resident.");
-static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)");
+static_assert(cute::rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)");
-static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
 static_assert(is_static<TileShapeMNK>::value, "TileShapeMNK must be static");
-static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3");
+static_assert(cute::rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3");
-static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4");
+static_assert(cute::rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4");
 // Indexing variables
 auto [M, N, K, L] = problem_shape_mnkl;
@@ -530,11 +530,11 @@ public:
 Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
 // Coordinate tensors and residue for tile quantization
-auto m_max_coord = unwrap(cute::transform(make_seq<rank<0>(CtaTileMNK{})>{}, [&](auto i) {
+auto m_max_coord = unwrap(cute::transform(make_seq<cute::rank<0>(CtaTileMNK{})>{}, [&](auto i) {
 auto c_m = get<0,i>(problem_shape_mnkl) - get<0,i>(CtaTileMNK{}) * get<0,i>(tile_coord_mnkl);
 return cute::max(0, c_m);
 }));
-auto n_max_coord = unwrap(cute::transform(make_seq<rank<1>(CtaTileMNK{})>{}, [&](auto i) {
+auto n_max_coord = unwrap(cute::transform(make_seq<cute::rank<1>(CtaTileMNK{})>{}, [&](auto i) {
 auto c_n = get<1,i>(problem_shape_mnkl) - get<1,i>(CtaTileMNK{}) * get<1,i>(tile_coord_mnkl);
 return cute::max(0, c_n);
 }));
@@ -559,7 +559,7 @@ public:
 tRS_cD,
 tRS_rC
 };
-auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks<RefSrc>(cst_args);
+auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks<RefSrc>(cst_args);
 bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed();
 bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();


@@ -695,7 +695,7 @@ template<
 FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
 >
 using Sm90ScaledLinCombPerRowBiasEltAct =
-Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
+Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
 Sm90EVT<Sm90Compute<ActivationFn, ElementCompute, ElementCompute, RoundStyle>, // activation(Z)
 // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias
 Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
@@ -829,7 +829,7 @@ using Sm90ScaledLinCombPerRowBiasEltActAmaxAux =
 // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
 Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>,
 // D = activation(Z) * scale_d, amax_d = max(abs(elements in D))
-Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
+Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
 Sm90EVT<Sm90ScalarReduction<detail::amax, atomic_maximum, ElementAmax, ElementCompute, RoundStyle>, // amax_d
 Sm90EVT<Sm90Compute<ActivationFn, ElementCompute, ElementCompute, RoundStyle>, // activation(Z)
 Sm90SplitTreeFetch // Z
@@ -839,7 +839,7 @@ using Sm90ScaledLinCombPerRowBiasEltActAmaxAux =
 >,
 // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux))
 Sm90EVT<Sm90AuxStore<StagesD, EpilogueTile, ElementAux, RoundStyle, StrideAux, SmemLayoutAtom, CopyOpR2S, AlignmentAux>, // store(Aux)
-Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementAux>::Op, ElementCompute, ElementCompute, RoundStyle>, // Z * scale_aux
+Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementAux>::template Op, ElementCompute, ElementCompute, RoundStyle>, // Z * scale_aux
 Sm90EVT<Sm90ScalarReduction<detail::amax, atomic_maximum, ElementAmax, ElementCompute, RoundStyle>, // amax_aux
 Sm90SplitTreeFetch // Z
 >,
@@ -1021,7 +1021,7 @@ template<
 using Sm90LinCombDeEltAct =
 Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>, // activation(beta * C + (alpha * acc), aux)
 Sm90LinearCombination<ElementCompute, ElementCompute, ElementScalar, RoundStyle>, // beta * C + (alpha * acc)
-Sm90AuxLoad<Stages, EpilogueTile, ElementAux, StrideAux, SmemLayoutAtom, CopyOpS2R, AlignmentAux>, // aux
+Sm90AuxLoad<Stages, EpilogueTile, ElementAux, StrideAux, SmemLayoutAtom, CopyOpS2R, AlignmentAux> // aux
 >;
 template <


@@ -237,6 +237,18 @@ struct Sm90TreeVisitor<
 Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>
 >;
+using Params = typename Impl::Params;
+using SharedStorage = typename Impl::SharedStorage;
+CUTLASS_HOST_DEVICE
+Sm90TreeVisitor() {}
+CUTLASS_HOST_DEVICE
+Sm90TreeVisitor(
+Params const& params,
+SharedStorage const& shared_storage)
+: Impl(params, shared_storage) {}
 CUTLASS_DEVICE bool
 is_producer_load_needed() const {
 auto const& bcast_op = get<0>(Impl::ops);
@@ -252,8 +264,6 @@ struct Sm90TreeVisitor<
 return bcast_op.scalar != 0 || added_op.is_C_load_needed();
 }
-using Impl::Sm90VisitorImpl;
 template <class CallbacksImpl>
 struct ConsumerStoreCallbacks : CallbacksImpl {
 CUTLASS_DEVICE
@@ -301,10 +311,9 @@ struct Sm90TreeVisitor<
 >
 CUTLASS_DEVICE auto
 get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
-return ConsumerStoreCallbacks(
-is_C_load_needed(),
-Impl::get_consumer_store_callbacks<ReferenceSrc>(args)
-);
+auto callbacks_tuple = Impl::template get_consumer_store_callbacks<ReferenceSrc>(args);
+return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(
+is_C_load_needed(), std::move(callbacks_tuple));
 }
 };
@@ -475,7 +484,8 @@ struct Sm90ReLUAuxStore {
 gAux, args.epi_tile, args.tiled_copy, args.thread_idx);
 Tensor tC_rAux = make_tensor<cutlass::uint1b_t>(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
-return ConsumerStoreCallbacks(cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params);
+return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tC_gAux), decltype(args.tCcD), decltype(args.residue_mn)>(
+cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params);
 }
 };
 } // namespace detail
@@ -532,7 +542,17 @@ struct Sm90TreeVisitor<
 Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle>
 >;
-using Impl::Sm90VisitorImpl;
+using Params = typename Impl::Params;
+using SharedStorage = typename Impl::SharedStorage;
+CUTLASS_HOST_DEVICE
+Sm90TreeVisitor() {}
+CUTLASS_HOST_DEVICE
+Sm90TreeVisitor(
+Params const& params,
+SharedStorage const& shared_storage)
+: Impl(params, shared_storage) {}
 template <class CallbacksImpl>
 struct ConsumerStoreCallbacks : CallbacksImpl {
@@ -556,9 +576,8 @@ struct Sm90TreeVisitor<
 >
 CUTLASS_DEVICE auto
 get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
-return ConsumerStoreCallbacks(
-Impl::get_consumer_store_callbacks<ReferenceSrc>(args)
-);
+auto callbacks_tuple = Impl::template get_consumer_store_callbacks<ReferenceSrc>(args);
+return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
 }
 };
@@ -654,7 +673,7 @@ struct Sm90AuxLoad<
 CUTLASS_DEVICE void
 begin() {
-if constexpr (decltype(rank(tC_rAux))::value == 5) {
+if constexpr (decltype(cute::rank(tC_rAux))::value == 5) {
 if constexpr (EnableNullptr) {
 if (params.ptr_aux == nullptr) {
 return;
@@ -669,7 +688,7 @@ struct Sm90AuxLoad<
 CUTLASS_DEVICE void
 previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
-if constexpr (decltype(rank(tC_rAux))::value == 3) {
+if constexpr (decltype(cute::rank(tC_rAux))::value == 3) {
 if constexpr (EnableNullptr) {
 if (params.ptr_aux == nullptr) {
 return;
@@ -686,7 +705,7 @@ struct Sm90AuxLoad<
 CUTLASS_DEVICE auto
 visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
 using ElementRegister = typename remove_cvref_t<RTensor>::value_type;
-if constexpr (decltype(rank(tC_rAux))::value == 3) {
+if constexpr (decltype(cute::rank(tC_rAux))::value == 3) {
 return recast<Array<ElementRegister, FragmentSize>>(coalesce(tC_rAux))(epi_v);
 }
 else {
@@ -727,7 +746,8 @@ struct Sm90AuxLoad<
 }
 }
-return ConsumerStoreCallbacks(cute::move(tC_rAux), cute::move(tC_gAux), args.residue_mn, params);
+return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tC_gAux), decltype(args.residue_mn)>(
+cute::move(tC_rAux), cute::move(tC_gAux), args.residue_mn, params);
 }
 };


@@ -280,7 +280,8 @@ struct Sm90AuxLoad {
 Tensor bGS_gAux = thrblk_g2s.partition_S(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N)
 Tensor bGS_sAux = thrblk_g2s.partition_D(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE)
-return ProducerLoadCallbacks(cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr);
+return ProducerLoadCallbacks<decltype(bGS_gAux), decltype(bGS_sAux)>(
+cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr);
 }
 template <class RTensor, class TiledS2R, class STensorS2R>
@@ -344,7 +345,8 @@ struct Sm90AuxLoad {
 auto tSR_sAux = tiled_s2r.get_slice(args.thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE)
-return ConsumerStoreCallbacks(cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr);
+return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tiled_s2r), decltype(tSR_sAux)>(
+cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr);
 }
 };


@@ -268,7 +268,7 @@ struct Sm90AuxStore {
 Tensor bSG_sAux = thrblk_s2g.partition_S(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE)
 Tensor bSG_gAux = thrblk_s2g.partition_D(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N)
-return ConsumerStoreCallbacks(
+return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tiled_r2s), decltype(tRS_sAux), decltype(bSG_sAux), decltype(bSG_gAux)>(
 cute::move(tC_rAux),
 tiled_r2s,
 cute::move(tRS_sAux),
@@ -1109,12 +1109,11 @@ public:
 Tensor gBuf_nl = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(m,_,_)); // (CTA_M,CTA_N,REST_N,L)
 Layout sBuf_layout = blocked_product(gBuf_layout,make_layout(make_shape(_1{},_1{},size<1>(warp_layout_MN)))); // (CTA_M,CTA_N,WARPS_N)
-return ConsumerStoreCallbacks(
-make_tuple(bool_constant<ReferenceSrc>{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout,
+auto args_tuple = make_tuple(
+bool_constant<ReferenceSrc>{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout,
 lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
-args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx),
-params
-);
+args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx);
+return ConsumerStoreCallbacks<decltype(args_tuple)>(std::move(args_tuple), params);
 }
 };


@@ -272,8 +272,18 @@ struct Sm90VisitorImplBase {
 template <class... Ops>
 struct Sm90VisitorImpl : Sm90VisitorImplBase<Ops...> {
-using Sm90VisitorImplBase<Ops...>::Sm90VisitorImplBase;
-using Sm90VisitorImplBase<Ops...>::ops;
+using Impl = Sm90VisitorImplBase<Ops...>;
+using Params = typename Impl::Params;
+using SharedStorage = typename Impl::SharedStorage;
+CUTLASS_HOST_DEVICE
+Sm90VisitorImpl() {}
+CUTLASS_HOST_DEVICE
+Sm90VisitorImpl(Params const& params, SharedStorage const& shared_storage)
+: Impl(params, shared_storage) {}
+using Impl::ops;
 //
 // Queries for kernel runtime
@@ -506,7 +516,18 @@ using namespace detail;
 template <class NodeOp, class... ChildOps>
 struct Sm90TreeVisitor : Sm90VisitorImpl<ChildOps..., NodeOp> {
-using Sm90VisitorImpl<ChildOps..., NodeOp>::Sm90VisitorImpl;
+using Impl = Sm90VisitorImpl<ChildOps..., NodeOp>;
+using Params = typename Impl::Params;
+using SharedStorage = typename Impl::SharedStorage;
+CUTLASS_HOST_DEVICE
+Sm90TreeVisitor() {}
+CUTLASS_HOST_DEVICE
+Sm90TreeVisitor(
+Params const& params,
+SharedStorage const& shared_storage)
+: Impl(params, shared_storage) {}
 template<class CallbacksImpl>
 struct ConsumerStoreCallbacks : CallbacksImpl {
@@ -538,10 +559,9 @@ struct Sm90TreeVisitor : Sm90VisitorImpl<ChildOps..., NodeOp> {
 >
 CUTLASS_DEVICE auto
 get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
-return ConsumerStoreCallbacks(
-Sm90VisitorImpl<ChildOps..., NodeOp>::
-get_consumer_store_callbacks<ReferenceSrc>(args)
-);
+auto callbacks_tuple = Sm90VisitorImpl<ChildOps..., NodeOp>::
+template get_consumer_store_callbacks<ReferenceSrc>(args);
+return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
 }
 };
@@ -590,10 +610,9 @@ struct Sm90SplitTreeVisitor : Sm90VisitorImpl<InputTree, AuxOutTrees..., OutputT
 >
 CUTLASS_DEVICE auto
 get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
-return ConsumerStoreCallbacks(
-Sm90VisitorImpl<InputTree, AuxOutTrees..., OutputTree>::
-get_consumer_store_callbacks<ReferenceSrc>(args)
-);
+auto callbacks_tuple = Sm90VisitorImpl<InputTree, AuxOutTrees..., OutputTree>::
+template get_consumer_store_callbacks<ReferenceSrc>(args);
+return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
 }
 };
@@ -609,7 +628,7 @@ template<
 >
 struct Sm90TopologicalVisitor : Sm90VisitorImpl<Ops...> {
 static_assert(is_static_v<EdgeTuple>);
-static_assert(rank(EdgeTuple{}) == sizeof...(Ops));
+static_assert(cute::rank(EdgeTuple{}) == sizeof...(Ops));
 static_assert(sizeof...(Ops) > 1);
 using Sm90VisitorImpl<Ops...>::Sm90VisitorImpl;
@@ -669,10 +688,9 @@ struct Sm90TopologicalVisitor : Sm90VisitorImpl<Ops...> {
 >
 CUTLASS_DEVICE auto
 get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
-return ConsumerStoreCallbacks(
-Sm90VisitorImpl<Ops...>::
-get_consumer_store_callbacks<ReferenceSrc>(args)
-);
+auto callbacks_tuple = Sm90VisitorImpl<Ops...>::
+template get_consumer_store_callbacks<ReferenceSrc>(args);
+return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
 }
 };


@@ -232,7 +232,7 @@ template<
 >
 struct TopologicalVisitor2x : VisitorImpl2x<Ops...> {
 static_assert(is_static_v<EdgeTuple>);
-static_assert(rank(EdgeTuple{}) == sizeof...(Ops));
+static_assert(cute::rank(EdgeTuple{}) == sizeof...(Ops));
 static_assert(sizeof...(Ops) > 1);
 using VisitorImpl2x<Ops...>::VisitorImpl2x;


@@ -100,11 +100,11 @@ struct CollectiveMma<
 using TransformB = TransformB_;
 using ArchTag = typename DispatchPolicy::ArchTag;
-static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
-static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -173,9 +173,9 @@ struct CollectiveMma<
 static_assert(is_gmem<TensorA>::value, "A tensor must be gmem resident.");
 static_assert(is_gmem<TensorB>::value, "B tensor must be gmem resident.");
 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutA{}) == 2,
+static_assert(cute::rank(SmemLayoutA{}) == 2,
 "MainloopTwoStage must not have a smem shape with a pipeline mode.");
-static_assert(rank(SmemLayoutB{}) == 2,
+static_assert(cute::rank(SmemLayoutB{}) == 2,
 "MainloopTwoStage must not have a smem shape with a pipeline mode.");
 // Construct shared memory tiles
@@ -343,11 +343,11 @@ struct CollectiveMma<
 using TransformB = TransformB_;
 using ArchTag = typename DispatchPolicy::ArchTag;
-static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
-static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -414,9 +414,9 @@ struct CollectiveMma<
 static_assert(is_gmem<TensorA>::value, "A tensor must be gmem resident.");
 static_assert(is_gmem<TensorB>::value, "B tensor must be gmem resident.");
 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutA{}) == 2,
+static_assert(cute::rank(SmemLayoutA{}) == 2,
 "MainloopTwoStage must not have a smem shape with a pipeline mode.");
-static_assert(rank(SmemLayoutB{}) == 2,
+static_assert(cute::rank(SmemLayoutB{}) == 2,
 "MainloopTwoStage must not have a smem shape with a pipeline mode.");
 // Construct shared memory tiles


@@ -101,11 +101,11 @@ struct CollectiveMma<
 using TransformB = TransformB_;
 using ArchTag = typename DispatchPolicy::ArchTag;
-static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
-static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -174,9 +174,9 @@ struct CollectiveMma<
 static_assert(is_gmem<TensorA>::value, "A tensor must be gmem resident.");
 static_assert(is_gmem<TensorB>::value, "B tensor must be gmem resident.");
 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutA{}) == 3,
+static_assert(cute::rank(SmemLayoutA{}) == 3,
 "MainloopSm80CpAsync must have a pipeline mode in the smem layout.");
-static_assert(rank(SmemLayoutB{}) == 3,
+static_assert(cute::rank(SmemLayoutB{}) == 3,
 "MainloopSm80CpAsync must have a pipeline mode in the smem layout.");
 // Construct shared memory tiles
@@ -390,11 +390,11 @@ struct CollectiveMma<
 using TransformB = TransformB_;
 using ArchTag = typename DispatchPolicy::ArchTag;
-static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
-static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -463,8 +463,8 @@ struct CollectiveMma<
 static_assert(is_gmem<TensorA>::value, "A tensor must be gmem resident.");
 static_assert(is_gmem<TensorB>::value, "B tensor must be gmem resident.");
 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
-static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
 // Construct shared memory tiles
 SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);


@@ -138,11 +138,11 @@ struct CollectiveMma<
 using PipelineState = typename MainloopPipeline::PipelineState;
 using PipelineParams = typename MainloopPipeline::Params;
-static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
-static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -418,10 +418,10 @@ struct CollectiveMma<
 {
 using namespace cute;
 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
-static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
-static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
+static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
-static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
+static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
 static_assert(!cute::is_void_v<InternalSmemCopyAtomA>,
 "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions.");
 static_assert(cute::is_void_v<InternalSmemCopyAtomB>,
View File
@ -112,11 +112,11 @@ struct CollectiveMma<
using PipelineState = typename MainloopPipeline::PipelineState; using PipelineState = typename MainloopPipeline::PipelineState;
using PipelineParams = typename MainloopPipeline::Params; using PipelineParams = typename MainloopPipeline::Params;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -346,8 +346,8 @@ struct CollectiveMma<
using namespace cute; using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident."); static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>, static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>, static_assert(cute::is_void_v<SmemCopyAtomB>,
View File
@ -141,11 +141,11 @@ struct CollectiveMma<
using PipelineParams = typename MainloopPipeline::Params; using PipelineParams = typename MainloopPipeline::Params;
static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -402,7 +402,7 @@ struct CollectiveMma<
// Prepare the TMA loads for A and B // Prepare the TMA loads for A and B
// //
constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(tiled_tensors); Tensor gA_mkl = get<0>(tiled_tensors);
@ -502,10 +502,10 @@ struct CollectiveMma<
{ {
using namespace cute; using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident."); static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2."); static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2."); static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
static_assert(!cute::is_void_v<InternalSmemCopyAtomA>, static_assert(!cute::is_void_v<InternalSmemCopyAtomA>,
"SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions."); "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<InternalSmemCopyAtomB>, static_assert(cute::is_void_v<InternalSmemCopyAtomB>,
View File
@ -183,11 +183,11 @@ public:
using PipelineParams = typename MainloopPipeline::Params; using PipelineParams = typename MainloopPipeline::Params;
static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -443,7 +443,7 @@ public:
// Prepare the TMA loads for A and B // Prepare the TMA loads for A and B
// //
constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(tiled_tensors); Tensor gA_mkl = get<0>(tiled_tensors);
@ -541,10 +541,10 @@ public:
Params const& mainloop_params) { Params const& mainloop_params) {
using namespace cute; using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident."); static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2."); static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2."); static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
static_assert(!cute::is_void_v<InternalSmemCopyAtomA>, static_assert(!cute::is_void_v<InternalSmemCopyAtomA>,
"SM90 GMMA mainloops must specify a non-void copy atom for RF sourced instructions."); "SM90 GMMA mainloops must specify a non-void copy atom for RF sourced instructions.");
static_assert(cute::is_void_v<InternalSmemCopyAtomB>, static_assert(cute::is_void_v<InternalSmemCopyAtomB>,
View File
@ -113,11 +113,11 @@ struct CollectiveMma<
using PipelineParams = typename MainloopPipeline::Params; using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename cutlass::PipelineState<DispatchPolicy::Stages>; using PipelineState = typename cutlass::PipelineState<DispatchPolicy::Stages>;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -271,10 +271,10 @@ struct CollectiveMma<
using namespace cute; using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident."); static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2."); static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>, static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>, static_assert(cute::is_void_v<SmemCopyAtomB>,
@ -288,7 +288,7 @@ struct CollectiveMma<
// Prepare the TMA loads for A and B // Prepare the TMA loads for A and B
// //
constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y);
View File
@ -114,11 +114,11 @@ struct CollectiveMma<
using PipelineParams = typename MainloopPipeline::Params; using PipelineParams = typename MainloopPipeline::Params;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -319,7 +319,7 @@ struct CollectiveMma<
// Prepare the TMA loads for A and B // Prepare the TMA loads for A and B
// //
constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(tiled_tensors); Tensor gA_mkl = get<0>(tiled_tensors);
@ -423,8 +423,8 @@ struct CollectiveMma<
using namespace cute; using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident."); static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>, static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>, static_assert(cute::is_void_v<SmemCopyAtomB>,
View File
@ -115,11 +115,11 @@ struct CollectiveMma<
using PipelineParams = typename MainloopPipeline::Params; using PipelineParams = typename MainloopPipeline::Params;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -317,7 +317,7 @@ struct CollectiveMma<
// Prepare the TMA loads for A and B // Prepare the TMA loads for A and B
// //
constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(tiled_tensors); Tensor gA_mkl = get<0>(tiled_tensors);
@ -421,8 +421,8 @@ struct CollectiveMma<
using namespace cute; using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident."); static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>, static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>, static_assert(cute::is_void_v<SmemCopyAtomB>,
View File
@ -665,7 +665,7 @@ protected:
int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM; int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM;
int m_end = params.block_mapping.problem_size.m(); int m_end = params.block_mapping.problem_size.m();
return Mma::IteratorA( return typename Mma::IteratorA(
params.params_A, params.params_A,
ptr_A, ptr_A,
{ m_end, tile_work.k_end }, { m_end, tile_work.k_end },
@ -694,7 +694,7 @@ protected:
int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN; int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN;
int n_end = params.block_mapping.problem_size.n(); int n_end = params.block_mapping.problem_size.n();
return Mma::IteratorB( return typename Mma::IteratorB(
params.params_B, params.params_B,
ptr_B, ptr_B,
{ tile_work.k_end, n_end }, { tile_work.k_end, n_end },
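The two hunks above add typename in front of Mma::IteratorA and Mma::IteratorB, and several mainloop hunks earlier replace DispatchPolicy::ClusterShape() either with the ClusterShape alias or with typename DispatchPolicy::ClusterShape(). The underlying rule is the same: inside a template, a nested type of a template parameter is a dependent name, and clang requires the typename keyword before such a name can be used as a type or constructed. A minimal sketch with hypothetical Policy/ClusterShape names:

#include <tuple>

template <class Policy>
struct MainloopSketch {
  using ClusterShape = typename Policy::ClusterShape;       // alias declared with typename once

  static unsigned cluster_shape_x() {
    // return std::get<0>(Policy::ClusterShape{});           // clang: missing 'typename' before dependent type
    return std::get<0>(typename Policy::ClusterShape{});     // explicit keyword, as in the hunks above
    // return std::get<0>(ClusterShape{});                   // equivalent: the member alias needs no keyword
  }
};

struct PolicySketch { using ClusterShape = std::tuple<unsigned, unsigned>; };

int main() { return static_cast<int>(MainloopSketch<PolicySketch>::cluster_shape_x()); }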
View File
@ -60,7 +60,7 @@ public:
// //
using ProblemShape = ProblemShape_; using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>"); "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types // Mainloop derived types
@ -142,7 +142,7 @@ public:
static bool static bool
can_implement(Arguments const& args) { can_implement(Arguments const& args) {
return args.mode == GemmUniversalMode::kGemm or return args.mode == GemmUniversalMode::kGemm or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
} }
static int static int
@ -159,7 +159,7 @@ public:
static dim3 static dim3
get_grid_shape(Params const& params) { get_grid_shape(Params const& params) {
int batch_count = 1; int batch_count = 1;
if constexpr (rank(ProblemShape{}) == 4) { if constexpr (cute::rank(ProblemShape{}) == 4) {
batch_count = cute::size<3>(params.problem_shape); batch_count = cute::size<3>(params.problem_shape);
} }
@ -193,10 +193,10 @@ public:
auto L = get<3>(problem_shape_MNKL); auto L = get<3>(problem_shape_MNKL);
// Preconditions // Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
// Get the appropriate blocks for this thread block -- potential for thread block locality // Get the appropriate blocks for this thread block -- potential for thread block locality
int thread_idx = int(threadIdx.x); int thread_idx = int(threadIdx.x);
View File
@ -80,7 +80,7 @@ public:
// Type Aliases // Type Aliases
// //
using ProblemShape = ProblemShape_; using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>"); "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types // Mainloop derived types
@ -169,7 +169,7 @@ public:
bool bool
can_implement(Arguments const& args) { can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) { if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable; return implementable;
@ -219,10 +219,10 @@ public:
#endif #endif
// Preconditions // Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
int thread_idx = int(threadIdx.x); int thread_idx = int(threadIdx.x);
int warp_idx = canonical_warp_idx_sync(); int warp_idx = canonical_warp_idx_sync();
@ -285,13 +285,13 @@ public:
params.mainloop params.mainloop
); );
constexpr int BLK_M_RANK = rank<0>(blk_shape); constexpr int BLK_M_RANK = cute::rank<0>(blk_shape);
bool m_oob = int(blockIdx.x) >= size<2>(gA_mkl); bool m_oob = int(blockIdx.x) >= size<2>(gA_mkl);
auto m_max_coord = unwrap(cute::transform(make_seq<BLK_M_RANK>{}, [&](auto i) { auto m_max_coord = unwrap(cute::transform(make_seq<BLK_M_RANK>{}, [&](auto i) {
return m_oob ? 0 : get<i>(M) - get<0,i>(blk_shape) * get<i>(m_coord); return m_oob ? 0 : get<i>(M) - get<0,i>(blk_shape) * get<i>(m_coord);
})); }));
constexpr int BLK_N_RANK = rank<1>(blk_shape); constexpr int BLK_N_RANK = cute::rank<1>(blk_shape);
bool n_oob = int(blockIdx.y) >= size<2>(gB_nkl); bool n_oob = int(blockIdx.y) >= size<2>(gB_nkl);
auto n_max_coord = unwrap(cute::transform(make_seq<BLK_N_RANK>{}, [&](auto i) { auto n_max_coord = unwrap(cute::transform(make_seq<BLK_N_RANK>{}, [&](auto i) {
return n_oob ? 0 : get<i>(N) - get<1,i>(blk_shape) * get<i>(n_coord); return n_oob ? 0 : get<i>(N) - get<1,i>(blk_shape) * get<i>(n_coord);
View File
@ -69,7 +69,7 @@ public:
// Type Aliases // Type Aliases
// //
using ProblemShape = ProblemShape_; using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>"); "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types // Mainloop derived types
@ -176,7 +176,7 @@ public:
bool bool
can_implement(Arguments const& args) { can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) { if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable; return implementable;
@ -318,10 +318,10 @@ public:
} (); } ();
// Preconditions // Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
@ -338,7 +338,7 @@ public:
// get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) // get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) // get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape); auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape);
static_assert(tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); static_assert(cute::tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");
// Extract out partitioned A and B. // Extract out partitioned A and B.
Tensor gA_mkl = get<0>(tiled_tensors); Tensor gA_mkl = get<0>(tiled_tensors);
View File
@ -69,7 +69,7 @@ public:
// Type Aliases // Type Aliases
// //
using ProblemShape = ProblemShape_; using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>"); "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types // Mainloop derived types
@ -219,7 +219,7 @@ public:
bool bool
can_implement(Arguments const& args) { can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) { if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable; return implementable;
@ -303,10 +303,10 @@ public:
static_assert(size<0>(TileShape{}) >= 128, static_assert(size<0>(TileShape{}) >= 128,
"Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */
enum class WarpGroupRole { enum class WarpGroupRole {
@ -423,7 +423,7 @@ public:
// get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) // get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) // get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape); auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape);
static_assert(tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); static_assert(cute::tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");
// Extract out partitioned A and B. // Extract out partitioned A and B.
Tensor gA_mkl = get<0>(tiled_tensors); Tensor gA_mkl = get<0>(tiled_tensors);
View File
@ -70,7 +70,7 @@ public:
// Type Aliases // Type Aliases
// //
using ProblemShape = ProblemShape_; using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>"); "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types // Mainloop derived types
@ -225,7 +225,7 @@ public:
bool bool
can_implement(Arguments const& args) { can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) { if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable; return implementable;
@ -305,10 +305,10 @@ public:
#endif #endif
// Preconditions // Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
enum class WarpGroupRole { enum class WarpGroupRole {
Producer = 0, Producer = 0,
@ -427,7 +427,7 @@ public:
// get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) // get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) // get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape); auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape);
static_assert(tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); static_assert(cute::tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");
// Extract out partitioned A and B. // Extract out partitioned A and B.
Tensor gA_mkl = get<0>(tiled_tensors); Tensor gA_mkl = get<0>(tiled_tensors);
View File
@ -67,7 +67,7 @@ public:
// Type Aliases // Type Aliases
// //
using ProblemShape = ProblemShape_; using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>"); "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types // Mainloop derived types
@ -180,7 +180,7 @@ public:
bool bool
can_implement(Arguments const& args) { can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) { if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable; return implementable;
@ -289,10 +289,10 @@ public:
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>(); PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
// Preconditions // Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
// Separate out problem shape for convenience // Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)
View File
@ -67,7 +67,7 @@ public:
// Type Aliases // Type Aliases
// //
using ProblemShape = ProblemShape_; using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>"); "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types // Mainloop derived types
@ -200,7 +200,7 @@ public:
bool bool
can_implement(Arguments const& args) { can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) { if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable; return implementable;
@ -256,10 +256,10 @@ public:
} }
#endif #endif
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
/* In the Cooperative kernel, one or multiple Consumers collaborate on the same tile */ /* In the Cooperative kernel, one or multiple Consumers collaborate on the same tile */
enum class WarpGroupRole { enum class WarpGroupRole {
View File
@ -69,7 +69,7 @@ public:
// Type Aliases // Type Aliases
// //
using ProblemShape = ProblemShape_; using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>"); "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types // Mainloop derived types
@ -212,7 +212,7 @@ public:
bool bool
can_implement(Arguments const& args) { can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) { if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable; return implementable;
@ -265,10 +265,10 @@ public:
#endif #endif
// Preconditions // Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
enum class WarpGroupRole { enum class WarpGroupRole {
Producer = 0, Producer = 0,
View File
@ -35,6 +35,7 @@
#pragma once #pragma once
#include "cute/layout.hpp"
#include "cutlass/gemm_coord.h" #include "cutlass/gemm_coord.h"
namespace cutlass { namespace cutlass {
View File
@ -192,7 +192,7 @@ struct NumericConverter<int8_t, float, FloatRoundStyle::round_to_nearest> {
return static_cast<result_type>(intermediate); return static_cast<result_type>(intermediate);
} }
CUTLASS_DEVICE CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) const { result_type operator()(source_type const &s) const {
return convert(s); return convert(s);
} }
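The hunk above changes the converter's operator() from a device-only annotation to CUTLASS_HOST_DEVICE, presumably so host-side code paths (for example reference checks) can invoke it as well; under clang's CUDA front end, a host call to a __device__-only function is a hard error. A hedged sketch of the shape of the fix, with the macros written out and a hypothetical converter type:

struct NarrowingConverterSketch {
  __host__ __device__
  static int convert(float s) { return static_cast<int>(s); }

  __host__ __device__   // was __device__-only; a host caller then fails to compile under clang
  int operator()(float s) const { return convert(s); }
};

int main() { return NarrowingConverterSketch{}(0.0f); }   // host-side call is now well-formed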
View File
@ -193,8 +193,8 @@ struct TestbedImpl {
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static constexpr uint32_t mma_promotion_interval = 4; static constexpr uint32_t mma_promotion_interval = 4;
@ -523,9 +523,6 @@ struct TestbedImpl {
Gemm& gemm_op, Gemm& gemm_op,
typename Gemm::Arguments& arguments, typename Gemm::Arguments& arguments,
cutlass::device_memory::allocation<uint8_t>& workspace) { cutlass::device_memory::allocation<uint8_t>& workspace) {
int M = cute::size<0>(problem_size);
int N = cute::size<1>(problem_size);
int K = cute::size<2>(problem_size);
int L = 1; int L = 1;
if constexpr(cute::rank(ProblemShapeType{}) == 4) { if constexpr(cute::rank(ProblemShapeType{}) == 4) {
L = cute::size<3>(problem_size); L = cute::size<3>(problem_size);
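The deleted lines above (M, N, K) are locals that were initialized but never read; the same cleanup later removes coord_0 and an unused problem_shape_MNKL in the other testbed files. Clang flags such definitions (-Wunused-variable), and in builds that promote warnings to errors they fail the build, so the fix is simply to drop them. A trivial illustration of the warning class, not CUTLASS code:

int problem_volume(int m, int n) {
  // int k = m * 4;   // warning: unused variable 'k' [-Wunused-variable]; fatal under -Werror
  return m * n;
}

int main() { return problem_volume(2, 3) == 6 ? 0 : 1; }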
@ -581,7 +578,7 @@ struct TestbedImpl {
cutlass::KernelHardwareInfo hw_info; cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0; hw_info.device_id = 0;
if (not profiling) { if (not profiling) {
this->sm_count = min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); this->sm_count = std::min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
hw_info.sm_count = this->sm_count; hw_info.sm_count = this->sm_count;
} }
else { else {
@ -1240,7 +1237,7 @@ struct Testbed3xFusionOperation {
hw_info.device_id = 0; hw_info.device_id = 0;
if (not profiling) { if (not profiling) {
impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
hw_info.sm_count = impl_.sm_count; hw_info.sm_count = impl_.sm_count;
} }
else { else {
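The two preceding hunks also replace unqualified min(...) with std::min(...). An unqualified min only compiles when some header happens to provide a global overload; clang evidently sees none here, and std::min from <algorithm> is the portable spelling. A small sketch of the qualified form, with hypothetical names:

#include <algorithm>

int clamped_sm_count(int max_sm_count, int queried_sm_count) {
  // return min(max_sm_count, queried_sm_count);    // relies on some header injecting a global ::min
  return std::min(max_sm_count, queried_sm_count);  // portable, as in the hunks above
}

int main() { return clamped_sm_count(16, 8) == 8 ? 0 : 1; }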
View File
@ -173,7 +173,7 @@ public:
HostScalarBroadcast(){} HostScalarBroadcast(){}
template<typename ProblemShapeType, typename TestBedImpl> template<typename ProblemShapeType, typename TestBedImpl>
HostScalarBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) HostScalarBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
:_scalar(ElementCompute(Value)), Base(check_relative_equality) {} : Base(check_relative_equality), _scalar(ElementCompute(Value)) {}
template <class ElementAccumulator> template <class ElementAccumulator>
ElementCompute visit( ElementCompute visit(
@ -232,7 +232,7 @@ public:
HostRowBroadcast(){} HostRowBroadcast(){}
template<typename ProblemShapeType> template<typename ProblemShapeType>
HostRowBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) HostRowBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
:impl_(impl), Base(check_relative_equality) { : Base(check_relative_equality), impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
_N = cute::get<1>(problem_shape_MNKL); _N = cute::get<1>(problem_shape_MNKL);
_bias.resize(cutlass::Coord<1>(_N)); _bias.resize(cutlass::Coord<1>(_N));
@ -300,7 +300,7 @@ public:
HostColBroadcast(){} HostColBroadcast(){}
template<typename ProblemShapeType> template<typename ProblemShapeType>
HostColBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) HostColBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
:impl_(impl), Base(check_relative_equality) { : Base(check_relative_equality), impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
_M = cute::get<0>(problem_shape_MNKL); _M = cute::get<0>(problem_shape_MNKL);
_bias.resize(cutlass::Coord<1>(_M)); _bias.resize(cutlass::Coord<1>(_M));
@ -382,7 +382,7 @@ public:
HostAuxLoad(){} HostAuxLoad(){}
template<typename ProblemShapeType> template<typename ProblemShapeType>
HostAuxLoad(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) HostAuxLoad(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
:impl_(impl), Base(check_relative_equality){ : Base(check_relative_equality), impl_(impl){
auto problem_shape_NMKL = cute::append<4>(problem_size, 1); auto problem_shape_NMKL = cute::append<4>(problem_size, 1);
auto [_M, _N, K, _L] = problem_shape_NMKL; auto [_M, _N, K, _L] = problem_shape_NMKL;
auto aux_coord = cutlass::make_Coord(_M * _L, _N); auto aux_coord = cutlass::make_Coord(_M * _L, _N);
@ -513,8 +513,8 @@ public:
HostUnaryCompute(){} HostUnaryCompute(){}
template <typename ProblemShapeType, typename TestBedImpl> template <typename ProblemShapeType, typename TestBedImpl>
HostUnaryCompute(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): HostUnaryCompute(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
_child_0(problem_size, impl, check_relative_equality), Base(check_relative_equality),
Base(check_relative_equality) { } _child_0(problem_size, impl, check_relative_equality) { }
template <class ElementAccumulator> template <class ElementAccumulator>
ElementCompute visit( ElementCompute visit(
@ -578,8 +578,8 @@ public:
HostAuxStore(){} HostAuxStore(){}
template <typename ProblemShapeType> template <typename ProblemShapeType>
HostAuxStore(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): HostAuxStore(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
impl_(impl), Base(check_relative_equality),
Base(check_relative_equality) { impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
auto [_M, _N, K, _L] = problem_shape_MNKL; auto [_M, _N, K, _L] = problem_shape_MNKL;
auto aux_coord = cutlass::make_Coord(_M * _L, _N); auto aux_coord = cutlass::make_Coord(_M * _L, _N);
@ -677,8 +677,8 @@ public:
HostRowReduce(){} HostRowReduce(){}
template <typename ProblemShapeType> template <typename ProblemShapeType>
HostRowReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): HostRowReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
impl_(impl), Base(check_relative_equality),
Base(check_relative_equality) { impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
_N = cute::get<1>(problem_shape_MNKL); _N = cute::get<1>(problem_shape_MNKL);
_tensor_row_reduce.resize(cutlass::Coord<1>(_N)); _tensor_row_reduce.resize(cutlass::Coord<1>(_N));
@ -764,8 +764,8 @@ public:
HostColumnReduce(){} HostColumnReduce(){}
template <typename ProblemShapeType> template <typename ProblemShapeType>
HostColumnReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): HostColumnReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
impl_(impl), Base(check_relative_equality),
Base(check_relative_equality) { impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
_M = cute::get<0>(problem_shape_MNKL); _M = cute::get<0>(problem_shape_MNKL);
_tensor_column_reduce.resize(cutlass::Coord<1>(_M)); _tensor_column_reduce.resize(cutlass::Coord<1>(_M));
@ -850,9 +850,8 @@ public:
HostScalarReduce(){} HostScalarReduce(){}
template <typename ProblemShapeType> template <typename ProblemShapeType>
HostScalarReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): HostScalarReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
impl_(impl), Base(check_relative_equality),
Base(check_relative_equality) { impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
_tensor_scalar_reduce.resize(cutlass::Coord<1>(1)); _tensor_scalar_reduce.resize(cutlass::Coord<1>(1));
_reference_scalar_reduce.resize(cutlass::Coord<1>(1)); _reference_scalar_reduce.resize(cutlass::Coord<1>(1));
_reduce_buffer.resize(cutlass::Coord<1>(1)); _reduce_buffer.resize(cutlass::Coord<1>(1));
@ -1229,7 +1228,6 @@ public:
auto N = cute::get<1>(problem_shape_MNKL); auto N = cute::get<1>(problem_shape_MNKL);
auto K = cute::get<2>(problem_shape_MNKL); auto K = cute::get<2>(problem_shape_MNKL);
auto L = cute::get<3>(problem_shape_MNKL); auto L = cute::get<3>(problem_shape_MNKL);
auto coord_0 = cutlass::make_Coord(0);
auto A = cute::make_tensor(impl_.tensor_A.host_data(), auto A = cute::make_tensor(impl_.tensor_A.host_data(),
cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a)); cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a));
@ -1307,7 +1305,7 @@ public:
cutlass::KernelHardwareInfo hw_info; cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0; hw_info.device_id = 0;
if (not profiling) { if (not profiling) {
impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
hw_info.sm_count = impl_.sm_count; hw_info.sm_count = impl_.sm_count;
} }
else { else {
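The min -> std::min change here (repeated in the Testbed3xTensorBroadcast file below) qualifies the call explicitly: the unqualified form appears to rely on a global min overload that nvcc's CUDA headers make visible, which clang as a host compiler does not provide. A portable sketch using only the standard library (function and parameter names are placeholders):

    #include <algorithm>

    int clamp_sm_count(int requested, int max_sm_count) {
      // std::min from <algorithm> works with any conforming host compiler.
      return std::min(max_sm_count, requested);
    }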
@@ -158,7 +158,6 @@ struct Testbed3xTensorBroadcast {
bool use_bias) bool use_bias)
{ {
auto [M, N, K, L] = problem_shape_MNKL; auto [M, N, K, L] = problem_shape_MNKL;
auto coord_0 = cutlass::make_Coord(0);
impl_.tensor_D.sync_host(); impl_.tensor_D.sync_host();
EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_A.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_A.host_view()), 0);
@@ -218,7 +217,6 @@ struct Testbed3xTensorBroadcast {
auto N = cute::get<1>(problem_shape_MNKL); auto N = cute::get<1>(problem_shape_MNKL);
auto K = cute::get<2>(problem_shape_MNKL); auto K = cute::get<2>(problem_shape_MNKL);
auto L = cute::get<3>(problem_shape_MNKL); auto L = cute::get<3>(problem_shape_MNKL);
auto coord_0 = cutlass::make_Coord(0);
auto A = cute::make_tensor(impl_.tensor_A.host_data(), auto A = cute::make_tensor(impl_.tensor_A.host_data(),
cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a)); cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a));
@@ -338,7 +336,7 @@ struct Testbed3xTensorBroadcast {
cutlass::KernelHardwareInfo hw_info; cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0; hw_info.device_id = 0;
if (not profiling) { if (not profiling) {
impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
hw_info.sm_count = impl_.sm_count; hw_info.sm_count = impl_.sm_count;
} }
else { else {
@@ -163,7 +163,7 @@ public:
using EVTModule = HEVT< using EVTModule = HEVT<
HostAuxStore<Gemm, true>, HostAuxStore<Gemm, true>,
HEVT< HEVT<
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>, // activation(Z) * scaled_d HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::template Op>, // activation(Z) * scaled_d
HEVT< HEVT<
HostCompute<Gemm, ActivationFn>, // activation(Z) HostCompute<Gemm, ActivationFn>, // activation(Z)
HEVT< HEVT<
@@ -174,11 +174,11 @@ public:
HostCompute<Gemm, cutlass::homogeneous_multiply_add>, HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha
HostAccumulator<Gemm>, HostAccumulator<Gemm>,
HostColBroadcast<Gemm, ElementD>, HostColBroadcast<Gemm, ElementD>
> >
> >
>, >,
HostScalarBroadcast<Gemm, 1>, // scale_d HostScalarBroadcast<Gemm, 1> // scale_d
> >
>; >;
}; };
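The ScaleOutOp&lt;ElementD&gt;::Op -> ::template Op change above (and again twice in the next hunk) adds the dependent-name disambiguator: Op is a member template of a type that depends on ElementD, and clang requires the template keyword before it can be used as a template-template argument, whereas nvcc's front end accepts the unadorned form. A standalone sketch with hypothetical names:

    template <class T>
    struct Outer {
      template <class U> struct Op {};   // member template of a dependent type
    };

    template <template <class> class F> struct Apply {};

    template <class T>
    struct User {
      // Both uses name a member template through a dependent type, so
      // clang insists on the 'template' keyword:
      using direct  = typename Outer<T>::template Op<int>;
      using applied = Apply<Outer<T>::template Op>;
    };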
@@ -211,26 +211,26 @@ public:
HostCompute<Gemm, cutlass::homogeneous_multiply_add>, HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha
HostAccumulator<Gemm>, HostAccumulator<Gemm>,
HostColBroadcast<Gemm, ElementD>, HostColBroadcast<Gemm, ElementD>
> >
>, >,
// D = activation(Z) * scaled_d, amax_d = max(abs(elements in D)) // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D))
HEVT< HEVT<
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>, HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::template Op>,
HEVT< HEVT<
HostScalarReduce<Gemm, amax, float>, HostScalarReduce<Gemm, amax, float>,
HEVT< HEVT<
HostCompute<Gemm, ActivationFn>, //activation(Z) * scaled_d HostCompute<Gemm, ActivationFn>, //activation(Z) * scaled_d
HostAccumulator<Gemm>, // Z HostAccumulator<Gemm> // Z
> >
>, >,
HostScalarBroadcast<Gemm, 1>, // scale_d HostScalarBroadcast<Gemm, 1> // scale_d
>, >,
// Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux))
HEVT< HEVT<
HostAuxStore<Gemm, false, ElementD, cutlass::layout::RowMajor>, HostAuxStore<Gemm, false, ElementD, cutlass::layout::RowMajor>,
HEVT< HEVT<
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>, HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::template Op>,
HEVT< HEVT<
HostScalarReduce<Gemm, amax, float>, HostScalarReduce<Gemm, amax, float>,
HostAccumulator<Gemm> HostAccumulator<Gemm>
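These EVT hunks also drop the trailing comma before the closing > of the HEVT argument lists (for example after HostColBroadcast&lt;Gemm, ElementD&gt; and HostScalarBroadcast&lt;Gemm, 1&gt;). A trailing comma in a template argument list is not valid C++; nvcc's front end evidently tolerated it here, clang does not. A minimal sketch (HEVTLike is a placeholder):

    template <class... Ts> struct HEVTLike {};

    // using Bad  = HEVTLike<int, float,>;   // clang: error, expected a type
    using Good = HEVTLike<int, float>;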
@@ -126,7 +126,7 @@ gett(
cudaStream_t stream = 0) { cudaStream_t stream = 0) {
using namespace cute; using namespace cute;
static_assert(rank(ProblemShapeMNKL{}) == 4); static_assert(cute::rank(ProblemShapeMNKL{}) == 4);
auto M = get<0>(problem_shape_mnkl); auto M = get<0>(problem_shape_mnkl);
auto N = get<1>(problem_shape_mnkl); auto N = get<1>(problem_shape_mnkl);
auto K = get<2>(problem_shape_mnkl); auto K = get<2>(problem_shape_mnkl);
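This hunk and the Gemm3x hunk below replace unqualified rank(...) with cute::rank(...). With using namespace cute in scope alongside the standard headers, plain rank can refer both to the cute::rank function and to the std::rank trait from &lt;type_traits&gt;, which clang reports as an ambiguous reference; qualifying the call sidesteps the lookup problem. A small sketch with stand-in namespaces for cute and std:

    #include <cstddef>

    namespace ns_a { template <class T> constexpr std::size_t rank(T const&) { return 1; } }
    namespace ns_b { template <class T> struct rank {}; }   // a class template, like std::rank

    using namespace ns_a;
    using namespace ns_b;

    // constexpr auto r_bad = rank(0);          // clang: reference to 'rank' is ambiguous
    constexpr auto r_good = ns_a::rank(0);      // qualified call is unambiguous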
@@ -431,11 +431,11 @@ void Gemm3x(
{ {
using namespace cute; using namespace cute;
static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename MainloopParams::LayoutB{})); static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
static_assert(rank(typename EpilogueParams::LayoutC{}) == rank(typename EpilogueParams::LayoutD{})); static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename EpilogueParams::LayoutC{})); static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
if constexpr (rank(typename MainloopParams::LayoutA{}) == 2) { if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) {
Layout layout_A = make_layout_rank3(mainloop_params.A); Layout layout_A = make_layout_rank3(mainloop_params.A);
Layout layout_B = make_layout_rank3(mainloop_params.B); Layout layout_B = make_layout_rank3(mainloop_params.B);
Layout layout_C = make_layout_rank3(epilogue_params.C); Layout layout_C = make_layout_rank3(epilogue_params.C);