Collection of changes to fix clang build. (#1200)

* Remove unused variables

* Qualify calls to `make_fragment_?` from templated base class.

Fixes a clang build error caused by two-phase (dependent-name) lookup; see the dependent-name sketch after this list.

* Add missing `#include <cstdio>`

* Various changes to fix clang compile errors.

* More changes to fix clang build.

Remaining issues:

- `params` initializer of `CollectiveEpilogue`.
- `ops` initializer of `Sm90VisitorImplBase`.
- `__usAtomicCAS` needs to be added to clang upstream.

* Fix remaining clang build issues.

* Qualify `cute::rank()` calls.

* Qualify some more calls that are otherwise ambiguous between the `cute` and `std` namespaces; see the namespace-ambiguity sketch after this list.

* Double-escape special registers in inline asm; see the inline-asm sketch after this list.

* small change
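
The `make_fragment_?` qualification (and the `Impl::template get_consumer_store_callbacks<...>` and `typename Mma::IteratorA` changes further down in the diff) all stem from two-phase name lookup: members of a dependent base class are not visible to unqualified lookup, and dependent member templates and types need the `template` and `typename` keywords. nvcc's front end accepts the unqualified spellings; clang does not. A minimal dependent-name sketch with hypothetical types, not code from this commit:

```cpp
template <class T>
struct Base {
  using Fragment = T;

  template <int N>
  Fragment make_fragment_A() const { return Fragment{}; }
};

template <class T>
struct Derived : Base<T> {
  using Impl = Base<T>;

  typename Impl::Fragment get() const {          // 'typename': Fragment is a dependent type
    // return make_fragment_A<4>();              // rejected: Base<T> is a dependent base, name not found
    return this->template make_fragment_A<4>();  // qualify via 'this->' (or 'Impl::') and add 'template'
  }
};

int main() {
  Derived<int> d;
  return d.get();  // returns 0
}
```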
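
The `cute::rank()` bullets cover unqualified calls that become ambiguous once both CuTe and the standard library declare the same name, for example the function template `cute::rank()` versus the `std::rank` type trait when both namespaces are reachable through using-directives or argument-dependent lookup. A minimal namespace-ambiguity sketch using a stand-in `lib::rank` rather than the real CuTe function:

```cpp
#include <type_traits>  // declares the type trait std::rank<T>

namespace lib {  // stand-in for the cute namespace
template <class T>
constexpr int rank(T const&) { return 1; }
}  // namespace lib

using namespace lib;
using namespace std;  // both lib::rank and std::rank are now visible

template <class T>
constexpr int query(T const& t) {
  // return rank(t);   // error: unqualified lookup finds both lib::rank and std::rank
  return lib::rank(t);  // explicit qualification names the intended entity
}

int main() { return query(42) == 1 ? 0 : 1; }
```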
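
The double-escaping bullet corresponds to the hunks below that turn `%nclusterid`, `%clusterid`, `%cluster_ctaid`, `%cluster_nctaid`, and `%cluster_ctarank` into their `%%`-prefixed forms: in GCC/Clang extended asm, `%` inside the template string introduces an operand reference (`%0`, `%1`, ...), so a literal `%`, as PTX special registers require, must be written `%%`. A minimal inline-asm sketch compiled as a `.cu` file, using the generally available `%laneid` register rather than one of the registers touched by this commit:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

__device__ unsigned lane_id() {
  unsigned id;
  // "%0" is the output operand; "%%laneid" emits the literal PTX
  // special register "%laneid" into the generated assembly.
  asm volatile("mov.u32 %0, %%laneid;" : "=r"(id));
  return id;
}

__global__ void print_lanes() {
  printf("thread %u -> lane %u\n", threadIdx.x, lane_id());
}

int main() {
  print_lanes<<<1, 8>>>();
  cudaDeviceSynchronize();
  return 0;
}
```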

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Author: Christian Sigg, 2023-12-08 20:42:12 +01:00 (committed by GitHub)
parent f4a0216601
commit e1483d5fa0
46 changed files with 308 additions and 273 deletions


@ -186,15 +186,15 @@ main(int argc, char const* argv[]) {
using ElementEpilogue = float;
// The following constexpr values set the max number of modes in each MNKL mode
constexpr int MaxRank_M = rank(RowModeStridesA{}); // Max row modes
constexpr int MaxRank_N = rank(ColModeStridesB{}); // Max column modes
constexpr int MaxRank_K = rank(RedModeStridesA{}); // Max contraction modes
constexpr int MaxRank_L = rank(BatModeStridesA{}); // Max batch modes
static_assert(rank(RowModeStridesA{}) == rank(RowModeStridesC{}));
static_assert(rank(ColModeStridesB{}) == rank(RowModeStridesC{}));
static_assert(rank(RedModeStridesA{}) == rank(RedModeStridesB{}));
static_assert(rank(BatModeStridesA{}) == rank(BatModeStridesC{}));
static_assert(rank(BatModeStridesB{}) == rank(BatModeStridesC{}));
constexpr int MaxRank_M = cute::rank(RowModeStridesA{}); // Max row modes
constexpr int MaxRank_N = cute::rank(ColModeStridesB{}); // Max column modes
constexpr int MaxRank_K = cute::rank(RedModeStridesA{}); // Max contraction modes
constexpr int MaxRank_L = cute::rank(BatModeStridesA{}); // Max batch modes
static_assert(cute::rank(RowModeStridesA{}) == cute::rank(RowModeStridesC{}));
static_assert(cute::rank(ColModeStridesB{}) == cute::rank(RowModeStridesC{}));
static_assert(cute::rank(RedModeStridesA{}) == cute::rank(RedModeStridesB{}));
static_assert(cute::rank(BatModeStridesA{}) == cute::rank(BatModeStridesC{}));
static_assert(cute::rank(BatModeStridesB{}) == cute::rank(BatModeStridesC{}));
// Parse command line to get modes, extents, and strides
cutlass::GettCommandLine cmd;


@ -58,7 +58,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
@ -180,7 +180,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -288,10 +288,10 @@ public:
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)


@ -86,8 +86,8 @@ public:
static const int kOutputAlignment = ThreadEpilogueOp::kCount;
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
struct SharedStorage { };
@ -151,10 +151,10 @@ public:
using namespace cute;
using X = Underscore;
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
(void) smem_buf;
ThreadEpilogueOp epilogue_op{params.thread_params};


@ -197,14 +197,14 @@ template<class ... Shapes>
auto
select_mode_shape(Shapes const & ... shapes) {
auto permuted_shapes = filter_tuple(cute::make_tuple(shapes...), [](auto shape) {
if constexpr (rank(shape) > 1) {
if constexpr (cute::rank(shape) > 1) {
return cute::make_tuple(shape);
}
else {
return cute::make_tuple();
}
});
if constexpr (rank(permuted_shapes) == 0) {
if constexpr (cute::rank(permuted_shapes) == 0) {
return get<0>(cute::make_tuple(shapes...));
}
else {
@ -251,7 +251,7 @@ auto
select_tile_shape(TileSize size, Shape const& shape)
{
static_assert(is_static<TileSize>::value, "Tile size must be static");
if constexpr (rank(Shape{}) == 0) {
if constexpr (cute::rank(Shape{}) == 0) {
return cute::make_tuple(size);
}
else {


@ -78,7 +78,7 @@ reshape(Shape const& shape, TargetShape const& target_shape)
template<class Permute, bool Transpose, class Shape, class Stride>
constexpr auto
make_permute_layout(Layout<Shape,Stride> const& layout) {
static_assert(rank(Shape{}) == 3, "Only rank-3 layouts are supported");
static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
if constexpr (Transpose) {
// Deal with tensor B by transposing appropriately before and after computing the permute layout.
// Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
@ -135,7 +135,7 @@ using inverse_t = decltype(inverse(T{}));
template<class Permute, bool Transpose, class Shape, class Stride>
constexpr auto
make_original_layout(Layout<Shape,Stride> const& layout) {
static_assert(rank(Shape{}) == 3, "Only rank-3 layouts are supported");
static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
if constexpr (Transpose) {
// Deal with tensor B by transposing appropriately before and after computing the permute layout.
// Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].


@ -86,9 +86,9 @@ CUTE_DEVICE dim3 cluster_grid_dims()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
uint32_t x, y, z;
asm volatile("mov.u32 %0, %nclusterid.x;\n" : "=r"(x) : );
asm volatile("mov.u32 %0, %nclusterid.y;\n" : "=r"(y) : );
asm volatile("mov.u32 %0, %nclusterid.z;\n" : "=r"(z) : );
asm volatile("mov.u32 %0, %%nclusterid.x;\n" : "=r"(x) : );
asm volatile("mov.u32 %0, %%nclusterid.y;\n" : "=r"(y) : );
asm volatile("mov.u32 %0, %%nclusterid.z;\n" : "=r"(z) : );
return {x, y, z};
#elif defined(__CUDA_ARCH__)
// MSVC requires protecting use of gridDim with __CUDA_ARCH__.
@ -105,9 +105,9 @@ CUTE_DEVICE dim3 cluster_id_in_grid()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
uint32_t x, y, z;
asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(x) : );
asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(y) : );
asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(z) : );
asm volatile("mov.u32 %0, %%clusterid.x;\n" : "=r"(x) : );
asm volatile("mov.u32 %0, %%clusterid.y;\n" : "=r"(y) : );
asm volatile("mov.u32 %0, %%clusterid.z;\n" : "=r"(z) : );
return {x, y, z};
#elif defined(__CUDA_ARCH__)
// MSVC requires protecting use of blockIdx with __CUDA_ARCH__.
@ -124,9 +124,9 @@ CUTE_DEVICE dim3 block_id_in_cluster()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
uint32_t x, y, z;
asm volatile("mov.u32 %0, %cluster_ctaid.x;\n" : "=r"(x) : );
asm volatile("mov.u32 %0, %cluster_ctaid.y;\n" : "=r"(y) : );
asm volatile("mov.u32 %0, %cluster_ctaid.z;\n" : "=r"(z) : );
asm volatile("mov.u32 %0, %%cluster_ctaid.x;\n" : "=r"(x) : );
asm volatile("mov.u32 %0, %%cluster_ctaid.y;\n" : "=r"(y) : );
asm volatile("mov.u32 %0, %%cluster_ctaid.z;\n" : "=r"(z) : );
return {x, y, z};
#else
return {0,0,0};
@ -138,9 +138,9 @@ CUTE_DEVICE dim3 cluster_shape()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
uint32_t x, y, z;
asm volatile("mov.u32 %0, %cluster_nctaid.x;\n" : "=r"(x) : );
asm volatile("mov.u32 %0, %cluster_nctaid.y;\n" : "=r"(y) : );
asm volatile("mov.u32 %0, %cluster_nctaid.z;\n" : "=r"(z) : );
asm volatile("mov.u32 %0, %%cluster_nctaid.x;\n" : "=r"(x) : );
asm volatile("mov.u32 %0, %%cluster_nctaid.y;\n" : "=r"(y) : );
asm volatile("mov.u32 %0, %%cluster_nctaid.z;\n" : "=r"(z) : );
return {x, y, z};
#else
return {1,1,1};
@ -152,7 +152,7 @@ CUTLASS_DEVICE uint32_t block_rank_in_cluster()
{
#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
uint32_t rank;
asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(rank) :);
asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank) :);
return rank;
#else
return 0;


@ -30,7 +30,7 @@
**************************************************************************************************/
#pragma once
#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)
#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
# define CUTE_HOST_DEVICE __forceinline__ __host__ __device__
# define CUTE_DEVICE __forceinline__ __device__
# define CUTE_HOST __forceinline__ __host__
@ -46,10 +46,11 @@
# define CUTE_HOST_RTC CUTE_HOST
#endif
#if !defined(__CUDACC_RTC__) && (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA))
#if !defined(__CUDACC_RTC__) && !defined(__clang__) && \
(defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA))
# define CUTE_UNROLL #pragma unroll
# define CUTE_NO_UNROLL #pragma unroll 1
#elif defined(__CUDACC_RTC__)
#elif defined(__CUDACC_RTC__) || defined(__clang__)
# define CUTE_UNROLL _Pragma("unroll")
# define CUTE_NO_UNROLL _Pragma("unroll 1")
#else


@ -35,6 +35,7 @@
#pragma once
#include <cstdio>
#include <cuda_runtime_api.h>
#include "cutlass/cutlass.h"
#include "cutlass/trace.h"


@ -595,7 +595,7 @@ CollectiveBuilder<
cute::is_base_of_v<TmaWarpSpecializedCooperativeElementwiseBase, Schedule> >> {
private:
using FusionOp =
fusion::LinCombEltAct<Schedule::ActivationFunctor, ElementD, ElementCompute, ElementCompute, Schedule::Round>;
fusion::LinCombEltAct<Schedule::template ActivationFunctor, ElementD, ElementCompute, ElementCompute, Schedule::Round>;
using ImplSchedule =
cute::conditional_t<cute::is_base_of_v<TmaWarpSpecializedElementwiseBase, Schedule>,
TmaWarpSpecialized, TmaWarpSpecializedCooperative>;
@ -676,7 +676,7 @@ private:
using SmemCopyOpAux = decltype(detail::sm90_get_smem_store_op_for_accumulator<
GmemStrideTypeAux, typename Schedule::ElementT>());
using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux<
GmemLayoutTagD, Schedule::ActivationFunctor, ElementD, ElementCompute,
GmemLayoutTagD, Schedule::template ActivationFunctor, ElementD, ElementCompute,
typename Schedule::ElementT, typename Schedule::ElementBias, ElementCompute
>;
using FusionCallbacksAux = fusion::FusionCallbacks<
@ -684,7 +684,7 @@ private:
>;
using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct<
Schedule::ActivationFunctor, ElementD, ElementCompute,
Schedule::template ActivationFunctor, ElementD, ElementCompute,
typename Schedule::ElementBias, ElementCompute
>;
using FusionCallbacksNoAux = fusion::FusionCallbacks<


@ -81,8 +81,8 @@ public:
static const int kOutputAlignment = ThreadEpilogueOp::kCount;
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
struct SharedStorage { };
@ -163,10 +163,10 @@ public:
using namespace cute;
using X = Underscore;
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
// Separate out problem shape for convenience
auto M = get<0>(problem_shape_mnkl);


@ -204,12 +204,12 @@ public:
int thread_idx,
TensorStorage& shared_tensors)
{
constexpr int BLK_M_RANK = rank<0>(tile_shape_MNK);
constexpr int BLK_M_RANK = cute::rank<0>(tile_shape_MNK);
auto m_max_coord = unwrap(cute::transform(make_seq<BLK_M_RANK>{}, [&](auto i) {
return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl);
}));
constexpr int BLK_N_RANK = rank<1>(tile_shape_MNK);
constexpr int BLK_N_RANK = cute::rank<1>(tile_shape_MNK);
auto n_max_coord = unwrap(cute::transform(make_seq<BLK_N_RANK>{}, [&](auto i) {
return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl);
}));


@ -91,8 +91,8 @@ public:
using StrideD = StrideD_;
using ActivationFunctor = typename ThreadEpilogueOp::ActivationFunctor;
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static constexpr int kOutputAlignment = ThreadEpilogueOp::kCount;
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
@ -182,10 +182,10 @@ public:
using namespace cute;
using X = Underscore;
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
// Separate out problem shape for convenience
auto M = get<0>(problem_shape_mnkl);


@ -87,8 +87,8 @@ public:
static const int kOutputAlignment = ThreadEpilogueOp::kCount;
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
struct SharedStorage
{
@ -172,10 +172,10 @@ public:
using namespace cute;
using X = Underscore;
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
// synchronizing function for smem reads/writes
#if CUDA_BARRIER_ENABLED


@ -113,12 +113,12 @@ public:
using GmemTiledCopyD = SM90_TMA_STORE;
static_assert(!is_layout<EpilogueTile>::value && is_tuple<EpilogueTile>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
static_assert(rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]");
static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]");
static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M");
static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");
private:
using SmemElementC = cute::conditional_t<cute::is_void_v<ElementC>,ElementD,ElementC>; // prevents void ref breakages
@ -340,10 +340,10 @@ public:
auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;
// Tile residue
auto m_max_coord = unwrap(cute::transform(make_seq<rank<0>(tile_shape_MNK)>{}, [&](auto i) {
auto m_max_coord = unwrap(cute::transform(make_seq<cute::rank<0>(tile_shape_MNK)>{}, [&](auto i) {
return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl);
}));
auto n_max_coord = unwrap(cute::transform(make_seq<rank<1>(tile_shape_MNK)>{}, [&](auto i) {
auto n_max_coord = unwrap(cute::transform(make_seq<cute::rank<1>(tile_shape_MNK)>{}, [&](auto i) {
return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl);
}));
auto residue_mn = make_coord(m_max_coord, n_max_coord);
@ -456,11 +456,11 @@ public:
using ElementCompute = cute::conditional_t<cute::is_void_v<ElementCompute_>,ElementAccumulator,ElementCompute_>;
static_assert(is_rmem<AccEngine>::value, "Accumulator must be RF resident.");
static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)");
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(cute::rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)");
static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(is_static<TileShapeMNK>::value, "TileShapeMNK must be static");
static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3");
static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4");
static_assert(cute::rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3");
static_assert(cute::rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4");
// Indexing variables
auto [M, N, K, L] = problem_shape_mnkl;
@ -530,11 +530,11 @@ public:
Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)
// Coordinate tensors and residue for tile quantization
auto m_max_coord = unwrap(cute::transform(make_seq<rank<0>(CtaTileMNK{})>{}, [&](auto i) {
auto m_max_coord = unwrap(cute::transform(make_seq<cute::rank<0>(CtaTileMNK{})>{}, [&](auto i) {
auto c_m = get<0,i>(problem_shape_mnkl) - get<0,i>(CtaTileMNK{}) * get<0,i>(tile_coord_mnkl);
return cute::max(0, c_m);
}));
auto n_max_coord = unwrap(cute::transform(make_seq<rank<1>(CtaTileMNK{})>{}, [&](auto i) {
auto n_max_coord = unwrap(cute::transform(make_seq<cute::rank<1>(CtaTileMNK{})>{}, [&](auto i) {
auto c_n = get<1,i>(problem_shape_mnkl) - get<1,i>(CtaTileMNK{}) * get<1,i>(tile_coord_mnkl);
return cute::max(0, c_n);
}));
@ -559,7 +559,7 @@ public:
tRS_cD,
tRS_rC
};
auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks<RefSrc>(cst_args);
auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks<RefSrc>(cst_args);
bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed();
bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();


@ -695,7 +695,7 @@ template<
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90ScaledLinCombPerRowBiasEltAct =
Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
Sm90EVT<Sm90Compute<ActivationFn, ElementCompute, ElementCompute, RoundStyle>, // activation(Z)
// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias
Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
@ -829,7 +829,7 @@ using Sm90ScaledLinCombPerRowBiasEltActAmaxAux =
// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>,
// D = activation(Z) * scale_d, amax_d = max(abs(elements in D))
Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
Sm90EVT<Sm90ScalarReduction<detail::amax, atomic_maximum, ElementAmax, ElementCompute, RoundStyle>, // amax_d
Sm90EVT<Sm90Compute<ActivationFn, ElementCompute, ElementCompute, RoundStyle>, // activation(Z)
Sm90SplitTreeFetch // Z
@ -839,7 +839,7 @@ using Sm90ScaledLinCombPerRowBiasEltActAmaxAux =
>,
// Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux))
Sm90EVT<Sm90AuxStore<StagesD, EpilogueTile, ElementAux, RoundStyle, StrideAux, SmemLayoutAtom, CopyOpR2S, AlignmentAux>, // store(Aux)
Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementAux>::Op, ElementCompute, ElementCompute, RoundStyle>, // Z * scale_aux
Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementAux>::template Op, ElementCompute, ElementCompute, RoundStyle>, // Z * scale_aux
Sm90EVT<Sm90ScalarReduction<detail::amax, atomic_maximum, ElementAmax, ElementCompute, RoundStyle>, // amax_aux
Sm90SplitTreeFetch // Z
>,
@ -1021,7 +1021,7 @@ template<
using Sm90LinCombDeEltAct =
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>, // activation(beta * C + (alpha * acc), aux)
Sm90LinearCombination<ElementCompute, ElementCompute, ElementScalar, RoundStyle>, // beta * C + (alpha * acc)
Sm90AuxLoad<Stages, EpilogueTile, ElementAux, StrideAux, SmemLayoutAtom, CopyOpS2R, AlignmentAux>, // aux
Sm90AuxLoad<Stages, EpilogueTile, ElementAux, StrideAux, SmemLayoutAtom, CopyOpS2R, AlignmentAux> // aux
>;
template <


@ -237,6 +237,18 @@ struct Sm90TreeVisitor<
Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>
>;
using Params = typename Impl::Params;
using SharedStorage = typename Impl::SharedStorage;
CUTLASS_HOST_DEVICE
Sm90TreeVisitor() {}
CUTLASS_HOST_DEVICE
Sm90TreeVisitor(
Params const& params,
SharedStorage const& shared_storage)
: Impl(params, shared_storage) {}
CUTLASS_DEVICE bool
is_producer_load_needed() const {
auto const& bcast_op = get<0>(Impl::ops);
@ -252,8 +264,6 @@ struct Sm90TreeVisitor<
return bcast_op.scalar != 0 || added_op.is_C_load_needed();
}
using Impl::Sm90VisitorImpl;
template <class CallbacksImpl>
struct ConsumerStoreCallbacks : CallbacksImpl {
CUTLASS_DEVICE
@ -301,10 +311,9 @@ struct Sm90TreeVisitor<
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
return ConsumerStoreCallbacks(
is_C_load_needed(),
Impl::get_consumer_store_callbacks<ReferenceSrc>(args)
);
auto callbacks_tuple = Impl::template get_consumer_store_callbacks<ReferenceSrc>(args);
return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(
is_C_load_needed(), std::move(callbacks_tuple));
}
};
@ -475,7 +484,8 @@ struct Sm90ReLUAuxStore {
gAux, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tC_rAux = make_tensor<cutlass::uint1b_t>(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return ConsumerStoreCallbacks(cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params);
return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tC_gAux), decltype(args.tCcD), decltype(args.residue_mn)>(
cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params);
}
};
} // namespace detail
@ -532,7 +542,17 @@ struct Sm90TreeVisitor<
Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle>
>;
using Impl::Sm90VisitorImpl;
using Params = typename Impl::Params;
using SharedStorage = typename Impl::SharedStorage;
CUTLASS_HOST_DEVICE
Sm90TreeVisitor() {}
CUTLASS_HOST_DEVICE
Sm90TreeVisitor(
Params const& params,
SharedStorage const& shared_storage)
: Impl(params, shared_storage) {}
template <class CallbacksImpl>
struct ConsumerStoreCallbacks : CallbacksImpl {
@ -556,9 +576,8 @@ struct Sm90TreeVisitor<
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
return ConsumerStoreCallbacks(
Impl::get_consumer_store_callbacks<ReferenceSrc>(args)
);
auto callbacks_tuple = Impl::template get_consumer_store_callbacks<ReferenceSrc>(args);
return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
}
};
@ -654,7 +673,7 @@ struct Sm90AuxLoad<
CUTLASS_DEVICE void
begin() {
if constexpr (decltype(rank(tC_rAux))::value == 5) {
if constexpr (decltype(cute::rank(tC_rAux))::value == 5) {
if constexpr (EnableNullptr) {
if (params.ptr_aux == nullptr) {
return;
@ -669,7 +688,7 @@ struct Sm90AuxLoad<
CUTLASS_DEVICE void
previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
if constexpr (decltype(rank(tC_rAux))::value == 3) {
if constexpr (decltype(cute::rank(tC_rAux))::value == 3) {
if constexpr (EnableNullptr) {
if (params.ptr_aux == nullptr) {
return;
@ -686,7 +705,7 @@ struct Sm90AuxLoad<
CUTLASS_DEVICE auto
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
using ElementRegister = typename remove_cvref_t<RTensor>::value_type;
if constexpr (decltype(rank(tC_rAux))::value == 3) {
if constexpr (decltype(cute::rank(tC_rAux))::value == 3) {
return recast<Array<ElementRegister, FragmentSize>>(coalesce(tC_rAux))(epi_v);
}
else {
@ -727,7 +746,8 @@ struct Sm90AuxLoad<
}
}
return ConsumerStoreCallbacks(cute::move(tC_rAux), cute::move(tC_gAux), args.residue_mn, params);
return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tC_gAux), decltype(args.residue_mn)>(
cute::move(tC_rAux), cute::move(tC_gAux), args.residue_mn, params);
}
};


@ -280,7 +280,8 @@ struct Sm90AuxLoad {
Tensor bGS_gAux = thrblk_g2s.partition_S(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N)
Tensor bGS_sAux = thrblk_g2s.partition_D(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE)
return ProducerLoadCallbacks(cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr);
return ProducerLoadCallbacks<decltype(bGS_gAux), decltype(bGS_sAux)>(
cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr);
}
template <class RTensor, class TiledS2R, class STensorS2R>
@ -344,7 +345,8 @@ struct Sm90AuxLoad {
auto tSR_sAux = tiled_s2r.get_slice(args.thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE)
return ConsumerStoreCallbacks(cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr);
return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tiled_s2r), decltype(tSR_sAux)>(
cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr);
}
};


@ -268,7 +268,7 @@ struct Sm90AuxStore {
Tensor bSG_sAux = thrblk_s2g.partition_S(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE)
Tensor bSG_gAux = thrblk_s2g.partition_D(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N)
return ConsumerStoreCallbacks(
return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tiled_r2s), decltype(tRS_sAux), decltype(bSG_sAux), decltype(bSG_gAux)>(
cute::move(tC_rAux),
tiled_r2s,
cute::move(tRS_sAux),
@ -1109,12 +1109,11 @@ public:
Tensor gBuf_nl = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(m,_,_)); // (CTA_M,CTA_N,REST_N,L)
Layout sBuf_layout = blocked_product(gBuf_layout,make_layout(make_shape(_1{},_1{},size<1>(warp_layout_MN)))); // (CTA_M,CTA_N,WARPS_N)
return ConsumerStoreCallbacks(
make_tuple(bool_constant<ReferenceSrc>{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout,
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx),
params
);
auto args_tuple = make_tuple(
bool_constant<ReferenceSrc>{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout,
lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks<decltype(args_tuple)>(std::move(args_tuple), params);
}
};


@ -272,8 +272,18 @@ struct Sm90VisitorImplBase {
template <class... Ops>
struct Sm90VisitorImpl : Sm90VisitorImplBase<Ops...> {
using Sm90VisitorImplBase<Ops...>::Sm90VisitorImplBase;
using Sm90VisitorImplBase<Ops...>::ops;
using Impl = Sm90VisitorImplBase<Ops...>;
using Params = typename Impl::Params;
using SharedStorage = typename Impl::SharedStorage;
CUTLASS_HOST_DEVICE
Sm90VisitorImpl() {}
CUTLASS_HOST_DEVICE
Sm90VisitorImpl(Params const& params, SharedStorage const& shared_storage)
: Impl(params, shared_storage) {}
using Impl::ops;
//
// Queries for kernel runtime
@ -506,7 +516,18 @@ using namespace detail;
template <class NodeOp, class... ChildOps>
struct Sm90TreeVisitor : Sm90VisitorImpl<ChildOps..., NodeOp> {
using Sm90VisitorImpl<ChildOps..., NodeOp>::Sm90VisitorImpl;
using Impl = Sm90VisitorImpl<ChildOps..., NodeOp>;
using Params = typename Impl::Params;
using SharedStorage = typename Impl::SharedStorage;
CUTLASS_HOST_DEVICE
Sm90TreeVisitor() {}
CUTLASS_HOST_DEVICE
Sm90TreeVisitor(
Params const& params,
SharedStorage const& shared_storage)
: Impl(params, shared_storage) {}
template<class CallbacksImpl>
struct ConsumerStoreCallbacks : CallbacksImpl {
@ -538,10 +559,9 @@ struct Sm90TreeVisitor : Sm90VisitorImpl<ChildOps..., NodeOp> {
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
return ConsumerStoreCallbacks(
Sm90VisitorImpl<ChildOps..., NodeOp>::
get_consumer_store_callbacks<ReferenceSrc>(args)
);
auto callbacks_tuple = Sm90VisitorImpl<ChildOps..., NodeOp>::
template get_consumer_store_callbacks<ReferenceSrc>(args);
return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
}
};
@ -590,10 +610,9 @@ struct Sm90SplitTreeVisitor : Sm90VisitorImpl<InputTree, AuxOutTrees..., OutputT
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
return ConsumerStoreCallbacks(
Sm90VisitorImpl<InputTree, AuxOutTrees..., OutputTree>::
get_consumer_store_callbacks<ReferenceSrc>(args)
);
auto callbacks_tuple = Sm90VisitorImpl<InputTree, AuxOutTrees..., OutputTree>::
template get_consumer_store_callbacks<ReferenceSrc>(args);
return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
}
};
@ -609,7 +628,7 @@ template<
>
struct Sm90TopologicalVisitor : Sm90VisitorImpl<Ops...> {
static_assert(is_static_v<EdgeTuple>);
static_assert(rank(EdgeTuple{}) == sizeof...(Ops));
static_assert(cute::rank(EdgeTuple{}) == sizeof...(Ops));
static_assert(sizeof...(Ops) > 1);
using Sm90VisitorImpl<Ops...>::Sm90VisitorImpl;
@ -669,10 +688,9 @@ struct Sm90TopologicalVisitor : Sm90VisitorImpl<Ops...> {
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
return ConsumerStoreCallbacks(
Sm90VisitorImpl<Ops...>::
get_consumer_store_callbacks<ReferenceSrc>(args)
);
auto callbacks_tuple = Sm90VisitorImpl<Ops...>::
template get_consumer_store_callbacks<ReferenceSrc>(args);
return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
}
};


@ -232,7 +232,7 @@ template<
>
struct TopologicalVisitor2x : VisitorImpl2x<Ops...> {
static_assert(is_static_v<EdgeTuple>);
static_assert(rank(EdgeTuple{}) == sizeof...(Ops));
static_assert(cute::rank(EdgeTuple{}) == sizeof...(Ops));
static_assert(sizeof...(Ops) > 1);
using VisitorImpl2x<Ops...>::VisitorImpl2x;


@ -100,11 +100,11 @@ struct CollectiveMma<
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -173,9 +173,9 @@ struct CollectiveMma<
static_assert(is_gmem<TensorA>::value, "A tensor must be gmem resident.");
static_assert(is_gmem<TensorB>::value, "B tensor must be gmem resident.");
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 2,
static_assert(cute::rank(SmemLayoutA{}) == 2,
"MainloopTwoStage must not have a smem shape with a pipeline mode.");
static_assert(rank(SmemLayoutB{}) == 2,
static_assert(cute::rank(SmemLayoutB{}) == 2,
"MainloopTwoStage must not have a smem shape with a pipeline mode.");
// Construct shared memory tiles
@ -343,11 +343,11 @@ struct CollectiveMma<
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -414,9 +414,9 @@ struct CollectiveMma<
static_assert(is_gmem<TensorA>::value, "A tensor must be gmem resident.");
static_assert(is_gmem<TensorB>::value, "B tensor must be gmem resident.");
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 2,
static_assert(cute::rank(SmemLayoutA{}) == 2,
"MainloopTwoStage must not have a smem shape with a pipeline mode.");
static_assert(rank(SmemLayoutB{}) == 2,
static_assert(cute::rank(SmemLayoutB{}) == 2,
"MainloopTwoStage must not have a smem shape with a pipeline mode.");
// Construct shared memory tiles


@ -101,11 +101,11 @@ struct CollectiveMma<
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -174,9 +174,9 @@ struct CollectiveMma<
static_assert(is_gmem<TensorA>::value, "A tensor must be gmem resident.");
static_assert(is_gmem<TensorB>::value, "B tensor must be gmem resident.");
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3,
static_assert(cute::rank(SmemLayoutA{}) == 3,
"MainloopSm80CpAsync must have a pipeline mode in the smem layout.");
static_assert(rank(SmemLayoutB{}) == 3,
static_assert(cute::rank(SmemLayoutB{}) == 3,
"MainloopSm80CpAsync must have a pipeline mode in the smem layout.");
// Construct shared memory tiles
@ -390,11 +390,11 @@ struct CollectiveMma<
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -463,8 +463,8 @@ struct CollectiveMma<
static_assert(is_gmem<TensorA>::value, "A tensor must be gmem resident.");
static_assert(is_gmem<TensorB>::value, "B tensor must be gmem resident.");
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
// Construct shared memory tiles
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);


@ -138,11 +138,11 @@ struct CollectiveMma<
using PipelineState = typename MainloopPipeline::PipelineState;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -418,10 +418,10 @@ struct CollectiveMma<
{
using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
static_assert(!cute::is_void_v<InternalSmemCopyAtomA>,
"SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<InternalSmemCopyAtomB>,


@ -112,11 +112,11 @@ struct CollectiveMma<
using PipelineState = typename MainloopPipeline::PipelineState;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -346,8 +346,8 @@ struct CollectiveMma<
using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,


@ -141,11 +141,11 @@ struct CollectiveMma<
using PipelineParams = typename MainloopPipeline::Params;
static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -402,7 +402,7 @@ struct CollectiveMma<
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(tiled_tensors);
@ -502,10 +502,10 @@ struct CollectiveMma<
{
using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
static_assert(!cute::is_void_v<InternalSmemCopyAtomA>,
"SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<InternalSmemCopyAtomB>,


@ -183,11 +183,11 @@ public:
using PipelineParams = typename MainloopPipeline::Params;
static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -443,7 +443,7 @@ public:
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(tiled_tensors);
@ -541,10 +541,10 @@ public:
Params const& mainloop_params) {
using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
static_assert(!cute::is_void_v<InternalSmemCopyAtomA>,
"SM90 GMMA mainloops must specify a non-void copy atom for RF sourced instructions.");
static_assert(cute::is_void_v<InternalSmemCopyAtomB>,


@ -113,11 +113,11 @@ struct CollectiveMma<
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename cutlass::PipelineState<DispatchPolicy::Stages>;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -271,10 +271,10 @@ struct CollectiveMma<
using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
@ -288,7 +288,7 @@ struct CollectiveMma<
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y);


@ -114,11 +114,11 @@ struct CollectiveMma<
using PipelineParams = typename MainloopPipeline::Params;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -319,7 +319,7 @@ struct CollectiveMma<
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(tiled_tensors);
@ -423,8 +423,8 @@ struct CollectiveMma<
using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,


@ -115,11 +115,11 @@ struct CollectiveMma<
using PipelineParams = typename MainloopPipeline::Params;
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -317,7 +317,7 @@ struct CollectiveMma<
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(tiled_tensors);
@ -421,8 +421,8 @@ struct CollectiveMma<
using namespace cute;
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,


@ -665,7 +665,7 @@ protected:
int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM;
int m_end = params.block_mapping.problem_size.m();
return Mma::IteratorA(
return typename Mma::IteratorA(
params.params_A,
ptr_A,
{ m_end, tile_work.k_end },
@ -694,7 +694,7 @@ protected:
int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN;
int n_end = params.block_mapping.problem_size.n();
return Mma::IteratorB(
return typename Mma::IteratorB(
params.params_B,
ptr_B,
{ tile_work.k_end, n_end },


@ -60,7 +60,7 @@ public:
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
@ -142,7 +142,7 @@ public:
static bool
can_implement(Arguments const& args) {
return args.mode == GemmUniversalMode::kGemm or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
}
static int
@ -159,7 +159,7 @@ public:
static dim3
get_grid_shape(Params const& params) {
int batch_count = 1;
if constexpr (rank(ProblemShape{}) == 4) {
if constexpr (cute::rank(ProblemShape{}) == 4) {
batch_count = cute::size<3>(params.problem_shape);
}
@ -193,10 +193,10 @@ public:
auto L = get<3>(problem_shape_MNKL);
// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
// Get the appropriate blocks for this thread block -- potential for thread block locality
int thread_idx = int(threadIdx.x);


@ -80,7 +80,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
@ -169,7 +169,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -219,10 +219,10 @@ public:
#endif
// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
int thread_idx = int(threadIdx.x);
int warp_idx = canonical_warp_idx_sync();
@ -285,13 +285,13 @@ public:
params.mainloop
);
constexpr int BLK_M_RANK = rank<0>(blk_shape);
constexpr int BLK_M_RANK = cute::rank<0>(blk_shape);
bool m_oob = int(blockIdx.x) >= size<2>(gA_mkl);
auto m_max_coord = unwrap(cute::transform(make_seq<BLK_M_RANK>{}, [&](auto i) {
return m_oob ? 0 : get<i>(M) - get<0,i>(blk_shape) * get<i>(m_coord);
}));
constexpr int BLK_N_RANK = rank<1>(blk_shape);
constexpr int BLK_N_RANK = cute::rank<1>(blk_shape);
bool n_oob = int(blockIdx.y) >= size<2>(gB_nkl);
auto n_max_coord = unwrap(cute::transform(make_seq<BLK_N_RANK>{}, [&](auto i) {
return n_oob ? 0 : get<i>(N) - get<1,i>(blk_shape) * get<i>(n_coord);


@ -69,7 +69,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
@ -176,7 +176,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -318,10 +318,10 @@ public:
} ();
// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
@ -338,7 +338,7 @@ public:
// get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape);
static_assert(tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");
static_assert(cute::tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");
// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(tiled_tensors);
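The cute::tuple_size_v qualification above is the same family of fix: std::tuple_size_v (or the cuda::std one) can also be found at these call sites, and two variable templates with the same name from different namespaces make the unqualified use ambiguous. Hedged sketch with stand-in names:

    // variable_template_sketch.cpp -- hypothetical stand-ins
    #include <tuple>                                        // declares std::tuple_size_v

    namespace cute_like {
      template <class T> constexpr int tuple_size_v = 2;    // stand-in for cute::tuple_size_v
    }

    using namespace std;
    using namespace cute_like;

    using Pair = std::tuple<int, int>;
    // constexpr auto n = tuple_size_v<Pair>;               // ill-formed: two candidates match
    constexpr auto n = cute_like::tuple_size_v<Pair>;       // qualify the intended one
    static_assert(n == 2, "stand-in always reports 2");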


@ -69,7 +69,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
@ -219,7 +219,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -303,10 +303,10 @@ public:
static_assert(size<0>(TileShape{}) >= 128,
"Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */
enum class WarpGroupRole {
@ -423,7 +423,7 @@ public:
// get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape);
static_assert(tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");
static_assert(cute::tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");
// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(tiled_tensors);


@ -70,7 +70,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
@ -225,7 +225,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -305,10 +305,10 @@ public:
#endif
// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
enum class WarpGroupRole {
Producer = 0,
@ -427,7 +427,7 @@ public:
// get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape);
static_assert(tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");
static_assert(cute::tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");
// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(tiled_tensors);


@ -67,7 +67,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
@ -180,7 +180,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -289,10 +289,10 @@ public:
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();
// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)


@ -67,7 +67,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
@ -200,7 +200,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -256,10 +256,10 @@ public:
}
#endif
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
/* In the Cooperative kernel, one or multiple Consumers collaborate on the same tile */
enum class WarpGroupRole {


@ -69,7 +69,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
@ -212,7 +212,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -265,10 +265,10 @@ public:
#endif
// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
enum class WarpGroupRole {
Producer = 0,


@ -35,6 +35,7 @@
#pragma once
#include "cute/layout.hpp"
#include "cutlass/gemm_coord.h"
namespace cutlass {


@ -192,7 +192,7 @@ struct NumericConverter<int8_t, float, FloatRoundStyle::round_to_nearest> {
return static_cast<result_type>(intermediate);
}
CUTLASS_DEVICE
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
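The CUTLASS_DEVICE to CUTLASS_HOST_DEVICE change above widens the call operator to match the static convert() it forwards to, presumably so the converter is also usable from host-side reference code; clang's CUDA frontend checks host/device callability more eagerly than nvcc in some of these spots. A rough sketch of the pattern with plain CUDA qualifiers and a stand-in type:

    // converter_sketch.cu -- stand-in type, not the real cutlass::NumericConverter
    struct RoundingConverterLike {
      __host__ __device__
      static int convert(float s) { return static_cast<int>(s); }

      __host__ __device__   // was device-only; host callers could not use operator()
      int operator()(float s) const { return convert(s); }
    };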


@ -193,8 +193,8 @@ struct TestbedImpl {
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static constexpr uint32_t mma_promotion_interval = 4;
@ -523,9 +523,6 @@ struct TestbedImpl {
Gemm& gemm_op,
typename Gemm::Arguments& arguments,
cutlass::device_memory::allocation<uint8_t>& workspace) {
int M = cute::size<0>(problem_size);
int N = cute::size<1>(problem_size);
int K = cute::size<2>(problem_size);
int L = 1;
if constexpr(cute::rank(ProblemShapeType{}) == 4) {
L = cute::size<3>(problem_size);
@ -581,7 +578,7 @@ struct TestbedImpl {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
if (not profiling) {
this->sm_count = min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
this->sm_count = std::min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
hw_info.sm_count = this->sm_count;
}
else {
@ -1240,7 +1237,7 @@ struct Testbed3xFusionOperation {
hw_info.device_id = 0;
if (not profiling) {
impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
hw_info.sm_count = impl_.sm_count;
}
else {
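The min to std::min changes in this testbed (and in the testbeds further below) spell out the standard algorithm; in host test code the unqualified name can bind to CUDA's global ::min or fail to resolve altogether, depending on the host compiler and the headers in scope. A trivial hedged sketch of the pattern:

    // sm_count_sketch.cpp -- hypothetical helper mirroring the testbed pattern
    #include <algorithm>   // std::min

    inline int clamped_sm_count(int max_sm_count, int queried_sm_count) {
      return std::min(max_sm_count, queried_sm_count);   // explicit std:: avoids relying on ::min
    }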


@ -173,7 +173,7 @@ public:
HostScalarBroadcast(){}
template<typename ProblemShapeType, typename TestBedImpl>
HostScalarBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
:_scalar(ElementCompute(Value)), Base(check_relative_equality) {}
: Base(check_relative_equality), _scalar(ElementCompute(Value)) {}
template <class ElementAccumulator>
ElementCompute visit(
@ -232,7 +232,7 @@ public:
HostRowBroadcast(){}
template<typename ProblemShapeType>
HostRowBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
:impl_(impl), Base(check_relative_equality) {
: Base(check_relative_equality), impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
_N = cute::get<1>(problem_shape_MNKL);
_bias.resize(cutlass::Coord<1>(_N));
@ -300,7 +300,7 @@ public:
HostColBroadcast(){}
template<typename ProblemShapeType>
HostColBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
:impl_(impl), Base(check_relative_equality) {
: Base(check_relative_equality), impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
_M = cute::get<0>(problem_shape_MNKL);
_bias.resize(cutlass::Coord<1>(_M));
@ -382,7 +382,7 @@ public:
HostAuxLoad(){}
template<typename ProblemShapeType>
HostAuxLoad(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
:impl_(impl), Base(check_relative_equality){
: Base(check_relative_equality), impl_(impl){
auto problem_shape_NMKL = cute::append<4>(problem_size, 1);
auto [_M, _N, K, _L] = problem_shape_NMKL;
auto aux_coord = cutlass::make_Coord(_M * _L, _N);
@ -513,8 +513,8 @@ public:
HostUnaryCompute(){}
template <typename ProblemShapeType, typename TestBedImpl>
HostUnaryCompute(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
_child_0(problem_size, impl, check_relative_equality),
Base(check_relative_equality) { }
Base(check_relative_equality),
_child_0(problem_size, impl, check_relative_equality) { }
template <class ElementAccumulator>
ElementCompute visit(
@ -578,8 +578,8 @@ public:
HostAuxStore(){}
template <typename ProblemShapeType>
HostAuxStore(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
impl_(impl),
Base(check_relative_equality) {
Base(check_relative_equality),
impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
auto [_M, _N, K, _L] = problem_shape_MNKL;
auto aux_coord = cutlass::make_Coord(_M * _L, _N);
@ -677,8 +677,8 @@ public:
HostRowReduce(){}
template <typename ProblemShapeType>
HostRowReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
impl_(impl),
Base(check_relative_equality) {
Base(check_relative_equality),
impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
_N = cute::get<1>(problem_shape_MNKL);
_tensor_row_reduce.resize(cutlass::Coord<1>(_N));
@ -764,8 +764,8 @@ public:
HostColumnReduce(){}
template <typename ProblemShapeType>
HostColumnReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
impl_(impl),
Base(check_relative_equality) {
Base(check_relative_equality),
impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
_M = cute::get<0>(problem_shape_MNKL);
_tensor_column_reduce.resize(cutlass::Coord<1>(_M));
@ -850,9 +850,8 @@ public:
HostScalarReduce(){}
template <typename ProblemShapeType>
HostScalarReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
impl_(impl),
Base(check_relative_equality) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
Base(check_relative_equality),
impl_(impl) {
_tensor_scalar_reduce.resize(cutlass::Coord<1>(1));
_reference_scalar_reduce.resize(cutlass::Coord<1>(1));
_reduce_buffer.resize(cutlass::Coord<1>(1));
@ -1229,7 +1228,6 @@ public:
auto N = cute::get<1>(problem_shape_MNKL);
auto K = cute::get<2>(problem_shape_MNKL);
auto L = cute::get<3>(problem_shape_MNKL);
auto coord_0 = cutlass::make_Coord(0);
auto A = cute::make_tensor(impl_.tensor_A.host_data(),
cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a));
@ -1307,7 +1305,7 @@ public:
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
if (not profiling) {
impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
hw_info.sm_count = impl_.sm_count;
}
else {
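The constructor changes in this file reorder the mem-initializer lists to match declaration order: the base class is initialized before the members, and members in the order they are declared, so clang's -Wreorder warning (an error in warnings-as-errors builds) fires when the list is written the other way round. A small sketch with stand-in names:

    // reorder_sketch.cpp -- stand-in types, not the real host EVT references
    struct HostVisitorBase {
      explicit HostVisitorBase(bool check) : check_relative_equality(check) {}
      bool check_relative_equality;
    };

    struct HostRowBroadcastLike : HostVisitorBase {
      HostRowBroadcastLike(int n, bool check)
        : HostVisitorBase(check), N(n) {}     // base first, then members in declaration order
      // : N(n), HostVisitorBase(check) {}    // legal but reordered; clang warns with -Wreorder
      int N;
    };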


@ -158,7 +158,6 @@ struct Testbed3xTensorBroadcast {
bool use_bias)
{
auto [M, N, K, L] = problem_shape_MNKL;
auto coord_0 = cutlass::make_Coord(0);
impl_.tensor_D.sync_host();
EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_A.host_view()), 0);
@ -218,7 +217,6 @@ struct Testbed3xTensorBroadcast {
auto N = cute::get<1>(problem_shape_MNKL);
auto K = cute::get<2>(problem_shape_MNKL);
auto L = cute::get<3>(problem_shape_MNKL);
auto coord_0 = cutlass::make_Coord(0);
auto A = cute::make_tensor(impl_.tensor_A.host_data(),
cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a));
@ -338,7 +336,7 @@ struct Testbed3xTensorBroadcast {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
if (not profiling) {
impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
hw_info.sm_count = impl_.sm_count;
}
else {


@ -163,7 +163,7 @@ public:
using EVTModule = HEVT<
HostAuxStore<Gemm, true>,
HEVT<
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>, // activation(Z) * scaled_d
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::template Op>, // activation(Z) * scaled_d
HEVT<
HostCompute<Gemm, ActivationFn>, // activation(Z)
HEVT<
@ -174,11 +174,11 @@ public:
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha
HostAccumulator<Gemm>,
HostColBroadcast<Gemm, ElementD>,
HostColBroadcast<Gemm, ElementD>
>
>
>,
HostScalarBroadcast<Gemm, 1>, // scale_d
HostScalarBroadcast<Gemm, 1> // scale_d
>
>;
};
@ -211,26 +211,26 @@ public:
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha
HostAccumulator<Gemm>,
HostColBroadcast<Gemm, ElementD>,
HostColBroadcast<Gemm, ElementD>
>
>,
// D = activation(Z) * scaled_d, amax_d = max(abs(elements in D))
HEVT<
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>,
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::template Op>,
HEVT<
HostScalarReduce<Gemm, amax, float>,
HEVT<
HostCompute<Gemm, ActivationFn>, //activation(Z) * scaled_d
HostAccumulator<Gemm>, // Z
HostAccumulator<Gemm> // Z
>
>,
HostScalarBroadcast<Gemm, 1>, // scale_d
HostScalarBroadcast<Gemm, 1> // scale_d
>,
// Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux))
HEVT<
HostAuxStore<Gemm, false, ElementD, cutlass::layout::RowMajor>,
HEVT<
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>,
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::template Op>,
HEVT<
HostScalarReduce<Gemm, amax, float>,
HostAccumulator<Gemm>
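Two kinds of fixes above in this testbed: ScaleOutOp<ElementD>::Op is a member template of a dependent type, so naming it as a template argument needs the template keyword, and the trailing commas before the closing > of the HEVT argument lists are dropped because a trailing comma in a template argument list is not valid C++ (some front ends tolerate it, clang does not). A hedged sketch of the template-keyword part with stand-in names:

    // dependent_template_sketch.cpp -- stand-in names, not the real EVT helpers
    template <class ElementD>
    struct ScaleOutOpLike {
      template <class T> struct Op {};                 // member template of a dependent type
    };

    template <template <class> class F>
    struct HostComputeLike {};

    template <class ElementD>
    using ScaledCompute =
        HostComputeLike<ScaleOutOpLike<ElementD>::template Op>;   // 'template' keyword required:
                                                                   // Op is a dependent template name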


@ -126,7 +126,7 @@ gett(
cudaStream_t stream = 0) {
using namespace cute;
static_assert(rank(ProblemShapeMNKL{}) == 4);
static_assert(cute::rank(ProblemShapeMNKL{}) == 4);
auto M = get<0>(problem_shape_mnkl);
auto N = get<1>(problem_shape_mnkl);
auto K = get<2>(problem_shape_mnkl);


@ -431,11 +431,11 @@ void Gemm3x(
{
using namespace cute;
static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename MainloopParams::LayoutB{}));
static_assert(rank(typename EpilogueParams::LayoutC{}) == rank(typename EpilogueParams::LayoutD{}));
static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename EpilogueParams::LayoutC{}));
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));
if constexpr (rank(typename MainloopParams::LayoutA{}) == 2) {
if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) {
Layout layout_A = make_layout_rank3(mainloop_params.A);
Layout layout_B = make_layout_rank3(mainloop_params.B);
Layout layout_C = make_layout_rank3(epilogue_params.C);