Collection of changes to fix clang build. (#1200)
* Remove unused variables
* Qualify calls to make_fragment_? from templated base class. Fixes clang build error.
* Add missing `#include <cstdio>`
* Various changes to fix clang compile errors.
* More changes to fix clang build. Remaining issues:
  - `params` initializer of `CollectiveEpilogue`.
  - `ops` initializer of `Sm90VisitorImplBase`.
  - `__usAtomicCAS` needs to be added to clang upstream.
* Fix remaining clang build issues.
* Qualify `cute::rank()` calls.
* Qualify some more calls that are otherwise ambiguous between `cute` and `std` namespace.
* Double-escape special registers in inline asm.
* Small change.

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Parent: f4a0216601
Commit: e1483d5fa0
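Note (illustrative, not part of the diff): most hunks below qualify rank() and a few related calls explicitly because, once both the cute and std namespaces are visible, clang treats the unqualified name as ambiguous (std::rank is a type trait declared in <type_traits>). A minimal sketch with hypothetical stand-in types, not the real CuTe definitions:

    #include <type_traits>   // declares the std::rank<T> trait

    namespace cute {
    // Hypothetical stand-in for cute::rank(), which reports the rank of a shape/layout.
    template <class Shape>
    constexpr int rank(Shape const&) { return 3; }
    }  // namespace cute

    using namespace cute;    // CUTLASS translation units commonly have cute in scope
    using namespace std;     // ...and sometimes std as well, which is when `rank` collides

    struct SomeStride {};

    constexpr int stride_rank() {
      SomeStride s{};
      // return rank(s);     // clang: `rank` is ambiguous (std::rank vs cute::rank)
      return cute::rank(s);  // explicit qualification, as applied throughout this commit
    }
    static_assert(stride_rank() == 3, "qualified call resolves to cute::rank");

The exact trigger differs from file to file in the CUTLASS sources, but the shape of the problem and of the fix is the same.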
@@ -186,15 +186,15 @@ main(int argc, char const* argv[]) {
 using ElementEpilogue = float;

 // The following constexpr values set the max number of modes in each MNKL mode
-constexpr int MaxRank_M = rank(RowModeStridesA{}); // Max row modes
-constexpr int MaxRank_N = rank(ColModeStridesB{}); // Max column modes
-constexpr int MaxRank_K = rank(RedModeStridesA{}); // Max contraction modes
-constexpr int MaxRank_L = rank(BatModeStridesA{}); // Max batch modes
-static_assert(rank(RowModeStridesA{}) == rank(RowModeStridesC{}));
-static_assert(rank(ColModeStridesB{}) == rank(RowModeStridesC{}));
-static_assert(rank(RedModeStridesA{}) == rank(RedModeStridesB{}));
-static_assert(rank(BatModeStridesA{}) == rank(BatModeStridesC{}));
-static_assert(rank(BatModeStridesB{}) == rank(BatModeStridesC{}));
+constexpr int MaxRank_M = cute::rank(RowModeStridesA{}); // Max row modes
+constexpr int MaxRank_N = cute::rank(ColModeStridesB{}); // Max column modes
+constexpr int MaxRank_K = cute::rank(RedModeStridesA{}); // Max contraction modes
+constexpr int MaxRank_L = cute::rank(BatModeStridesA{}); // Max batch modes
+static_assert(cute::rank(RowModeStridesA{}) == cute::rank(RowModeStridesC{}));
+static_assert(cute::rank(ColModeStridesB{}) == cute::rank(RowModeStridesC{}));
+static_assert(cute::rank(RedModeStridesA{}) == cute::rank(RedModeStridesB{}));
+static_assert(cute::rank(BatModeStridesA{}) == cute::rank(BatModeStridesC{}));
+static_assert(cute::rank(BatModeStridesB{}) == cute::rank(BatModeStridesC{}));

 // Parse command line to get modes, extents, and strides
 cutlass::GettCommandLine cmd;
@@ -58,7 +58,7 @@ public:
 // Type Aliases
 //
 using ProblemShape = ProblemShape_;
-static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
+static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
 "ProblemShape{} should be <M,N,K> or <M,N,K,L>");

 // Mainloop derived types
@@ -180,7 +180,7 @@ public:
 bool
 can_implement(Arguments const& args) {
 bool implementable = (args.mode == GemmUniversalMode::kGemm) or
-(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
+(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
 if (!implementable) {
 CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
 return implementable;
@@ -288,10 +288,10 @@ public:
 PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();

 // Preconditions
-static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
-static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
-static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
-static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
+static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
+static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
+static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
+static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");

 // Separate out problem shape for convenience
 // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)
@@ -86,8 +86,8 @@ public:
 static const int kOutputAlignment = ThreadEpilogueOp::kCount;
 using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;

-static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");

 struct SharedStorage { };
@@ -151,10 +151,10 @@ public:
 using namespace cute;
 using X = Underscore;

-static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
 static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
-static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
-static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
+static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
+static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");

 (void) smem_buf;
 ThreadEpilogueOp epilogue_op{params.thread_params};
@@ -197,14 +197,14 @@ template<class ... Shapes>
 auto
 select_mode_shape(Shapes const & ... shapes) {
 auto permuted_shapes = filter_tuple(cute::make_tuple(shapes...), [](auto shape) {
-if constexpr (rank(shape) > 1) {
+if constexpr (cute::rank(shape) > 1) {
 return cute::make_tuple(shape);
 }
 else {
 return cute::make_tuple();
 }
 });
-if constexpr (rank(permuted_shapes) == 0) {
+if constexpr (cute::rank(permuted_shapes) == 0) {
 return get<0>(cute::make_tuple(shapes...));
 }
 else {
@@ -251,7 +251,7 @@ auto
 select_tile_shape(TileSize size, Shape const& shape)
 {
 static_assert(is_static<TileSize>::value, "Tile size must be static");
-if constexpr (rank(Shape{}) == 0) {
+if constexpr (cute::rank(Shape{}) == 0) {
 return cute::make_tuple(size);
 }
 else {
@@ -78,7 +78,7 @@ reshape(Shape const& shape, TargetShape const& target_shape)
 template<class Permute, bool Transpose, class Shape, class Stride>
 constexpr auto
 make_permute_layout(Layout<Shape,Stride> const& layout) {
-static_assert(rank(Shape{}) == 3, "Only rank-3 layouts are supported");
+static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
 if constexpr (Transpose) {
 // Deal with tensor B by transposing appropriately before and after computing the permute layout.
 // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
@@ -135,7 +135,7 @@ using inverse_t = decltype(inverse(T{}));
 template<class Permute, bool Transpose, class Shape, class Stride>
 constexpr auto
 make_original_layout(Layout<Shape,Stride> const& layout) {
-static_assert(rank(Shape{}) == 3, "Only rank-3 layouts are supported");
+static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported");
 if constexpr (Transpose) {
 // Deal with tensor B by transposing appropriately before and after computing the permute layout.
 // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch].
@@ -86,9 +86,9 @@ CUTE_DEVICE dim3 cluster_grid_dims()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
 uint32_t x, y, z;
-asm volatile("mov.u32 %0, %nclusterid.x;\n" : "=r"(x) : );
-asm volatile("mov.u32 %0, %nclusterid.y;\n" : "=r"(y) : );
-asm volatile("mov.u32 %0, %nclusterid.z;\n" : "=r"(z) : );
+asm volatile("mov.u32 %0, %%nclusterid.x;\n" : "=r"(x) : );
+asm volatile("mov.u32 %0, %%nclusterid.y;\n" : "=r"(y) : );
+asm volatile("mov.u32 %0, %%nclusterid.z;\n" : "=r"(z) : );
 return {x, y, z};
 #elif defined(__CUDA_ARCH__)
 // MSVC requires protecting use of gridDim with __CUDA_ARCH__.
@@ -105,9 +105,9 @@ CUTE_DEVICE dim3 cluster_id_in_grid()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
 uint32_t x, y, z;
-asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(x) : );
-asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(y) : );
-asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(z) : );
+asm volatile("mov.u32 %0, %%clusterid.x;\n" : "=r"(x) : );
+asm volatile("mov.u32 %0, %%clusterid.y;\n" : "=r"(y) : );
+asm volatile("mov.u32 %0, %%clusterid.z;\n" : "=r"(z) : );
 return {x, y, z};
 #elif defined(__CUDA_ARCH__)
 // MSVC requires protecting use of blockIdx with __CUDA_ARCH__.
@@ -124,9 +124,9 @@ CUTE_DEVICE dim3 block_id_in_cluster()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
 uint32_t x, y, z;
-asm volatile("mov.u32 %0, %cluster_ctaid.x;\n" : "=r"(x) : );
-asm volatile("mov.u32 %0, %cluster_ctaid.y;\n" : "=r"(y) : );
-asm volatile("mov.u32 %0, %cluster_ctaid.z;\n" : "=r"(z) : );
+asm volatile("mov.u32 %0, %%cluster_ctaid.x;\n" : "=r"(x) : );
+asm volatile("mov.u32 %0, %%cluster_ctaid.y;\n" : "=r"(y) : );
+asm volatile("mov.u32 %0, %%cluster_ctaid.z;\n" : "=r"(z) : );
 return {x, y, z};
 #else
 return {0,0,0};
@@ -138,9 +138,9 @@ CUTE_DEVICE dim3 cluster_shape()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
 uint32_t x, y, z;
-asm volatile("mov.u32 %0, %cluster_nctaid.x;\n" : "=r"(x) : );
-asm volatile("mov.u32 %0, %cluster_nctaid.y;\n" : "=r"(y) : );
-asm volatile("mov.u32 %0, %cluster_nctaid.z;\n" : "=r"(z) : );
+asm volatile("mov.u32 %0, %%cluster_nctaid.x;\n" : "=r"(x) : );
+asm volatile("mov.u32 %0, %%cluster_nctaid.y;\n" : "=r"(y) : );
+asm volatile("mov.u32 %0, %%cluster_nctaid.z;\n" : "=r"(z) : );
 return {x, y, z};
 #else
 return {1,1,1};
@@ -152,7 +152,7 @@ CUTLASS_DEVICE uint32_t block_rank_in_cluster()
 {
 #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
 uint32_t rank;
-asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(rank) :);
+asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank) :);
 return rank;
 #else
 return 0;
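Note (illustrative, not part of the diff): in GNU-style extended inline assembly that has operands, a literal % in the template must be written as %%, because a single % introduces an operand reference such as %0; nvcc let the single % before a special-register name through, while clang rejects it as an invalid operand escape. A self-contained sketch of the corrected pattern (the helper name is hypothetical):

    #include <cstdint>

    // Reads the %cluster_ctarank special register; the template escapes it as
    // %%cluster_ctarank so that %0 remains the only operand reference to substitute.
    __device__ uint32_t my_block_rank_in_cluster() {
      uint32_t rank;
      asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank) :);
      return rank;
    }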
@@ -30,7 +30,7 @@
 **************************************************************************************************/
 #pragma once

-#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)
+#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
 # define CUTE_HOST_DEVICE __forceinline__ __host__ __device__
 # define CUTE_DEVICE __forceinline__ __device__
 # define CUTE_HOST __forceinline__ __host__
@@ -46,10 +46,11 @@
 # define CUTE_HOST_RTC CUTE_HOST
 #endif

-#if !defined(__CUDACC_RTC__) && (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA))
+#if !defined(__CUDACC_RTC__) && !defined(__clang__) && \
+(defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA))
 # define CUTE_UNROLL #pragma unroll
 # define CUTE_NO_UNROLL #pragma unroll 1
-#elif defined(__CUDACC_RTC__)
+#elif defined(__CUDACC_RTC__) || defined(__clang__)
 # define CUTE_UNROLL _Pragma("unroll")
 # define CUTE_NO_UNROLL _Pragma("unroll 1")
 #else
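Note (illustrative, not part of the diff): the branch added above routes clang to the _Pragma form, since a #pragma directive cannot be produced from an ordinary macro expansion; _Pragma("...") is the operator form that can appear inside a macro, and clang understands #pragma unroll / _Pragma("unroll") for CUDA device code. A small sketch with a hypothetical macro name:

    // Hypothetical macro mirroring the CUTE_UNROLL definitions above.
    #if defined(__clang__) || defined(__CUDACC_RTC__)
    #  define MY_UNROLL _Pragma("unroll")
    #else
    #  define MY_UNROLL
    #endif

    __device__ float sum4(float const* v) {
      float acc = 0.f;
      MY_UNROLL
      for (int i = 0; i < 4; ++i) {  // the pragma asks the compiler to fully unroll this loop
        acc += v[i];
      }
      return acc;
    }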
@@ -35,6 +35,7 @@

 #pragma once

+#include <cstdio>
 #include <cuda_runtime_api.h>
 #include "cutlass/cutlass.h"
 #include "cutlass/trace.h"
@@ -595,7 +595,7 @@ CollectiveBuilder<
 cute::is_base_of_v<TmaWarpSpecializedCooperativeElementwiseBase, Schedule> >> {
 private:
 using FusionOp =
-fusion::LinCombEltAct<Schedule::ActivationFunctor, ElementD, ElementCompute, ElementCompute, Schedule::Round>;
+fusion::LinCombEltAct<Schedule::template ActivationFunctor, ElementD, ElementCompute, ElementCompute, Schedule::Round>;
 using ImplSchedule =
 cute::conditional_t<cute::is_base_of_v<TmaWarpSpecializedElementwiseBase, Schedule>,
 TmaWarpSpecialized, TmaWarpSpecializedCooperative>;
@@ -676,7 +676,7 @@ private:
 using SmemCopyOpAux = decltype(detail::sm90_get_smem_store_op_for_accumulator<
 GmemStrideTypeAux, typename Schedule::ElementT>());
 using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux<
-GmemLayoutTagD, Schedule::ActivationFunctor, ElementD, ElementCompute,
+GmemLayoutTagD, Schedule::template ActivationFunctor, ElementD, ElementCompute,
 typename Schedule::ElementT, typename Schedule::ElementBias, ElementCompute
 >;
 using FusionCallbacksAux = fusion::FusionCallbacks<
@@ -684,7 +684,7 @@ private:
 >;

 using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct<
-Schedule::ActivationFunctor, ElementD, ElementCompute,
+Schedule::template ActivationFunctor, ElementD, ElementCompute,
 typename Schedule::ElementBias, ElementCompute
 >;
 using FusionCallbacksNoAux = fusion::FusionCallbacks<
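Note (illustrative, not part of the diff): the Schedule::template ActivationFunctor spelling above, and the .template get_consumer_store_callbacks<...>(...) spelling in later hunks, use the template disambiguator that C++ requires when a member template is named through a dependent type or object; nvcc accepts the shorter form, clang enforces the rule. A stand-alone sketch with hypothetical names:

    // Stand-ins for Schedule and for an object exposing a member function template.
    struct MySchedule {
      template <class T>
      struct ActivationFunctor { using type = T; };
    };

    template <class Sched>
    struct UsesSchedule {
      // Sched is dependent, so its member template needs the `template` keyword:
      template <class T>
      using Act = typename Sched::template ActivationFunctor<T>::type;
    };

    template <class Callbacks>
    auto make_callbacks(Callbacks& cb) {
      // Same rule for a member function template reached through a dependent object:
      return cb.template get_consumer_store_callbacks<true>();
    }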
@@ -81,8 +81,8 @@ public:
 static const int kOutputAlignment = ThreadEpilogueOp::kCount;
 using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;

-static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");

 struct SharedStorage { };
@@ -163,10 +163,10 @@ public:
 using namespace cute;
 using X = Underscore;

-static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
 static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
-static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
-static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
+static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
+static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");

 // Separate out problem shape for convenience
 auto M = get<0>(problem_shape_mnkl);
@@ -204,12 +204,12 @@ public:
 int thread_idx,
 TensorStorage& shared_tensors)
 {
-constexpr int BLK_M_RANK = rank<0>(tile_shape_MNK);
+constexpr int BLK_M_RANK = cute::rank<0>(tile_shape_MNK);
 auto m_max_coord = unwrap(cute::transform(make_seq<BLK_M_RANK>{}, [&](auto i) {
 return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl);
 }));

-constexpr int BLK_N_RANK = rank<1>(tile_shape_MNK);
+constexpr int BLK_N_RANK = cute::rank<1>(tile_shape_MNK);
 auto n_max_coord = unwrap(cute::transform(make_seq<BLK_N_RANK>{}, [&](auto i) {
 return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl);
 }));
@@ -91,8 +91,8 @@ public:
 using StrideD = StrideD_;
 using ActivationFunctor = typename ThreadEpilogueOp::ActivationFunctor;

-static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");

 static constexpr int kOutputAlignment = ThreadEpilogueOp::kCount;
 using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
@@ -182,10 +182,10 @@ public:
 using namespace cute;
 using X = Underscore;

-static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
 static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
-static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
-static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
+static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
+static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");

 // Separate out problem shape for convenience
 auto M = get<0>(problem_shape_mnkl);
@@ -87,8 +87,8 @@ public:
 static const int kOutputAlignment = ThreadEpilogueOp::kCount;
 using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;

-static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");

 struct SharedStorage
 {
@@ -172,10 +172,10 @@ public:
 using namespace cute;
 using X = Underscore;

-static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
 static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
-static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
-static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
+static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
+static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");

 // synchronizing function for smem reads/writes
 #if CUDA_BARRIER_ENABLED
@@ -113,12 +113,12 @@ public:
 using GmemTiledCopyD = SM90_TMA_STORE;

 static_assert(!is_layout<EpilogueTile>::value && is_tuple<EpilogueTile>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
-static_assert(rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
-static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]");
+static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
+static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]");
 static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M");
 static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N");
-static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
-static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
+static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");

 private:
 using SmemElementC = cute::conditional_t<cute::is_void_v<ElementC>,ElementD,ElementC>; // prevents void ref breakages
@@ -340,10 +340,10 @@ public:
 auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;

 // Tile residue
-auto m_max_coord = unwrap(cute::transform(make_seq<rank<0>(tile_shape_MNK)>{}, [&](auto i) {
+auto m_max_coord = unwrap(cute::transform(make_seq<cute::rank<0>(tile_shape_MNK)>{}, [&](auto i) {
 return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl);
 }));
-auto n_max_coord = unwrap(cute::transform(make_seq<rank<1>(tile_shape_MNK)>{}, [&](auto i) {
+auto n_max_coord = unwrap(cute::transform(make_seq<cute::rank<1>(tile_shape_MNK)>{}, [&](auto i) {
 return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl);
 }));
 auto residue_mn = make_coord(m_max_coord, n_max_coord);
@@ -456,11 +456,11 @@ public:
 using ElementCompute = cute::conditional_t<cute::is_void_v<ElementCompute_>,ElementAccumulator,ElementCompute_>;

 static_assert(is_rmem<AccEngine>::value, "Accumulator must be RF resident.");
-static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)");
-static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
+static_assert(cute::rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)");
+static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
 static_assert(is_static<TileShapeMNK>::value, "TileShapeMNK must be static");
-static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3");
-static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4");
+static_assert(cute::rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3");
+static_assert(cute::rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4");

 // Indexing variables
 auto [M, N, K, L] = problem_shape_mnkl;
@@ -530,11 +530,11 @@ public:
 Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)

 // Coordinate tensors and residue for tile quantization
-auto m_max_coord = unwrap(cute::transform(make_seq<rank<0>(CtaTileMNK{})>{}, [&](auto i) {
+auto m_max_coord = unwrap(cute::transform(make_seq<cute::rank<0>(CtaTileMNK{})>{}, [&](auto i) {
 auto c_m = get<0,i>(problem_shape_mnkl) - get<0,i>(CtaTileMNK{}) * get<0,i>(tile_coord_mnkl);
 return cute::max(0, c_m);
 }));
-auto n_max_coord = unwrap(cute::transform(make_seq<rank<1>(CtaTileMNK{})>{}, [&](auto i) {
+auto n_max_coord = unwrap(cute::transform(make_seq<cute::rank<1>(CtaTileMNK{})>{}, [&](auto i) {
 auto c_n = get<1,i>(problem_shape_mnkl) - get<1,i>(CtaTileMNK{}) * get<1,i>(tile_coord_mnkl);
 return cute::max(0, c_n);
 }));
@@ -559,7 +559,7 @@ public:
 tRS_cD,
 tRS_rC
 };
-auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks<RefSrc>(cst_args);
+auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks<RefSrc>(cst_args);
 bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed();
 bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();
@@ -695,7 +695,7 @@ template<
 FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
 >
 using Sm90ScaledLinCombPerRowBiasEltAct =
-Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
+Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
 Sm90EVT<Sm90Compute<ActivationFn, ElementCompute, ElementCompute, RoundStyle>, // activation(Z)
 // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias
 Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
@@ -829,7 +829,7 @@ using Sm90ScaledLinCombPerRowBiasEltActAmaxAux =
 // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
 Sm90ScaledLinCombPerRowBias<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>,
 // D = activation(Z) * scale_d, amax_d = max(abs(elements in D))
-Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
+Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementOutput>::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d
 Sm90EVT<Sm90ScalarReduction<detail::amax, atomic_maximum, ElementAmax, ElementCompute, RoundStyle>, // amax_d
 Sm90EVT<Sm90Compute<ActivationFn, ElementCompute, ElementCompute, RoundStyle>, // activation(Z)
 Sm90SplitTreeFetch // Z
@@ -839,7 +839,7 @@ using Sm90ScaledLinCombPerRowBiasEltActAmaxAux =
 >,
 // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux))
 Sm90EVT<Sm90AuxStore<StagesD, EpilogueTile, ElementAux, RoundStyle, StrideAux, SmemLayoutAtom, CopyOpR2S, AlignmentAux>, // store(Aux)
-Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementAux>::Op, ElementCompute, ElementCompute, RoundStyle>, // Z * scale_aux
+Sm90EVT<Sm90Compute<detail::ScaleOutOp<ElementAux>::template Op, ElementCompute, ElementCompute, RoundStyle>, // Z * scale_aux
 Sm90EVT<Sm90ScalarReduction<detail::amax, atomic_maximum, ElementAmax, ElementCompute, RoundStyle>, // amax_aux
 Sm90SplitTreeFetch // Z
 >,
@@ -1021,7 +1021,7 @@ template<
 using Sm90LinCombDeEltAct =
 Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>, // activation(beta * C + (alpha * acc), aux)
 Sm90LinearCombination<ElementCompute, ElementCompute, ElementScalar, RoundStyle>, // beta * C + (alpha * acc)
-Sm90AuxLoad<Stages, EpilogueTile, ElementAux, StrideAux, SmemLayoutAtom, CopyOpS2R, AlignmentAux>, // aux
+Sm90AuxLoad<Stages, EpilogueTile, ElementAux, StrideAux, SmemLayoutAtom, CopyOpS2R, AlignmentAux> // aux
 >;

 template <
@@ -237,6 +237,18 @@ struct Sm90TreeVisitor<
 Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>
 >;

+using Params = typename Impl::Params;
+using SharedStorage = typename Impl::SharedStorage;
+
+CUTLASS_HOST_DEVICE
+Sm90TreeVisitor() {}
+
+CUTLASS_HOST_DEVICE
+Sm90TreeVisitor(
+Params const& params,
+SharedStorage const& shared_storage)
+: Impl(params, shared_storage) {}
+
 CUTLASS_DEVICE bool
 is_producer_load_needed() const {
 auto const& bcast_op = get<0>(Impl::ops);
@@ -252,8 +264,6 @@ struct Sm90TreeVisitor<
 return bcast_op.scalar != 0 || added_op.is_C_load_needed();
 }

-using Impl::Sm90VisitorImpl;
-
 template <class CallbacksImpl>
 struct ConsumerStoreCallbacks : CallbacksImpl {
 CUTLASS_DEVICE
@@ -301,10 +311,9 @@ struct Sm90TreeVisitor<
 >
 CUTLASS_DEVICE auto
 get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
-return ConsumerStoreCallbacks(
-is_C_load_needed(),
-Impl::get_consumer_store_callbacks<ReferenceSrc>(args)
-);
+auto callbacks_tuple = Impl::template get_consumer_store_callbacks<ReferenceSrc>(args);
+return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(
+is_C_load_needed(), std::move(callbacks_tuple));
 }
 };
@@ -475,7 +484,8 @@ struct Sm90ReLUAuxStore {
 gAux, args.epi_tile, args.tiled_copy, args.thread_idx);
 Tensor tC_rAux = make_tensor<cutlass::uint1b_t>(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

-return ConsumerStoreCallbacks(cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params);
+return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tC_gAux), decltype(args.tCcD), decltype(args.residue_mn)>(
+cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params);
 }
 };
 } // namespace detail
@@ -532,7 +542,17 @@ struct Sm90TreeVisitor<
 Sm90Compute<Activation, ElementOutput, ElementCompute, RoundStyle>
 >;

-using Impl::Sm90VisitorImpl;
+using Params = typename Impl::Params;
+using SharedStorage = typename Impl::SharedStorage;
+
+CUTLASS_HOST_DEVICE
+Sm90TreeVisitor() {}
+
+CUTLASS_HOST_DEVICE
+Sm90TreeVisitor(
+Params const& params,
+SharedStorage const& shared_storage)
+: Impl(params, shared_storage) {}

 template <class CallbacksImpl>
 struct ConsumerStoreCallbacks : CallbacksImpl {
@@ -556,9 +576,8 @@ struct Sm90TreeVisitor<
 >
 CUTLASS_DEVICE auto
 get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
-return ConsumerStoreCallbacks(
-Impl::get_consumer_store_callbacks<ReferenceSrc>(args)
-);
+auto callbacks_tuple = Impl::template get_consumer_store_callbacks<ReferenceSrc>(args);
+return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
 }

 };
@@ -654,7 +673,7 @@ struct Sm90AuxLoad<

 CUTLASS_DEVICE void
 begin() {
-if constexpr (decltype(rank(tC_rAux))::value == 5) {
+if constexpr (decltype(cute::rank(tC_rAux))::value == 5) {
 if constexpr (EnableNullptr) {
 if (params.ptr_aux == nullptr) {
 return;
@@ -669,7 +688,7 @@ struct Sm90AuxLoad<

 CUTLASS_DEVICE void
 previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
-if constexpr (decltype(rank(tC_rAux))::value == 3) {
+if constexpr (decltype(cute::rank(tC_rAux))::value == 3) {
 if constexpr (EnableNullptr) {
 if (params.ptr_aux == nullptr) {
 return;
@@ -686,7 +705,7 @@ struct Sm90AuxLoad<
 CUTLASS_DEVICE auto
 visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
 using ElementRegister = typename remove_cvref_t<RTensor>::value_type;
-if constexpr (decltype(rank(tC_rAux))::value == 3) {
+if constexpr (decltype(cute::rank(tC_rAux))::value == 3) {
 return recast<Array<ElementRegister, FragmentSize>>(coalesce(tC_rAux))(epi_v);
 }
 else {
@@ -727,7 +746,8 @@ struct Sm90AuxLoad<
 }
 }

-return ConsumerStoreCallbacks(cute::move(tC_rAux), cute::move(tC_gAux), args.residue_mn, params);
+return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tC_gAux), decltype(args.residue_mn)>(
+cute::move(tC_rAux), cute::move(tC_gAux), args.residue_mn, params);
 }
 };
@@ -280,7 +280,8 @@ struct Sm90AuxLoad {
 Tensor bGS_gAux = thrblk_g2s.partition_S(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N)
 Tensor bGS_sAux = thrblk_g2s.partition_D(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE)

-return ProducerLoadCallbacks(cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr);
+return ProducerLoadCallbacks<decltype(bGS_gAux), decltype(bGS_sAux)>(
+cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr);
 }

 template <class RTensor, class TiledS2R, class STensorS2R>
@@ -344,7 +345,8 @@ struct Sm90AuxLoad {
 auto tSR_sAux = tiled_s2r.get_slice(args.thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE)

-return ConsumerStoreCallbacks(cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr);
+return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tiled_s2r), decltype(tSR_sAux)>(
+cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr);
 }
 };
@@ -268,7 +268,7 @@ struct Sm90AuxStore {
 Tensor bSG_sAux = thrblk_s2g.partition_S(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE)
 Tensor bSG_gAux = thrblk_s2g.partition_D(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N)

-return ConsumerStoreCallbacks(
+return ConsumerStoreCallbacks<decltype(tC_rAux), decltype(tiled_r2s), decltype(tRS_sAux), decltype(bSG_sAux), decltype(bSG_gAux)>(
 cute::move(tC_rAux),
 tiled_r2s,
 cute::move(tRS_sAux),
@@ -1109,12 +1109,11 @@ public:
 Tensor gBuf_nl = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(m,_,_)); // (CTA_M,CTA_N,REST_N,L)
 Layout sBuf_layout = blocked_product(gBuf_layout,make_layout(make_shape(_1{},_1{},size<1>(warp_layout_MN)))); // (CTA_M,CTA_N,WARPS_N)

-return ConsumerStoreCallbacks(
-make_tuple(bool_constant<ReferenceSrc>{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout,
-lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
-args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx),
-params
-);
+auto args_tuple = make_tuple(
+bool_constant<ReferenceSrc>{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout,
+lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
+args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx);
+return ConsumerStoreCallbacks<decltype(args_tuple)>(std::move(args_tuple), params);
 }
 };
@@ -272,8 +272,18 @@ struct Sm90VisitorImplBase {
 template <class... Ops>
 struct Sm90VisitorImpl : Sm90VisitorImplBase<Ops...> {

-using Sm90VisitorImplBase<Ops...>::Sm90VisitorImplBase;
-using Sm90VisitorImplBase<Ops...>::ops;
+using Impl = Sm90VisitorImplBase<Ops...>;
+using Params = typename Impl::Params;
+using SharedStorage = typename Impl::SharedStorage;
+
+CUTLASS_HOST_DEVICE
+Sm90VisitorImpl() {}
+
+CUTLASS_HOST_DEVICE
+Sm90VisitorImpl(Params const& params, SharedStorage const& shared_storage)
+: Impl(params, shared_storage) {}
+
+using Impl::ops;

 //
 // Queries for kernel runtime
@@ -506,7 +516,18 @@ using namespace detail;
 template <class NodeOp, class... ChildOps>
 struct Sm90TreeVisitor : Sm90VisitorImpl<ChildOps..., NodeOp> {

-using Sm90VisitorImpl<ChildOps..., NodeOp>::Sm90VisitorImpl;
+using Impl = Sm90VisitorImpl<ChildOps..., NodeOp>;
+using Params = typename Impl::Params;
+using SharedStorage = typename Impl::SharedStorage;
+
+CUTLASS_HOST_DEVICE
+Sm90TreeVisitor() {}
+
+CUTLASS_HOST_DEVICE
+Sm90TreeVisitor(
+Params const& params,
+SharedStorage const& shared_storage)
+: Impl(params, shared_storage) {}

 template<class CallbacksImpl>
 struct ConsumerStoreCallbacks : CallbacksImpl {
@@ -538,10 +559,9 @@ struct Sm90TreeVisitor : Sm90VisitorImpl<ChildOps..., NodeOp> {
 >
 CUTLASS_DEVICE auto
 get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
-return ConsumerStoreCallbacks(
-Sm90VisitorImpl<ChildOps..., NodeOp>::
-get_consumer_store_callbacks<ReferenceSrc>(args)
-);
+auto callbacks_tuple = Sm90VisitorImpl<ChildOps..., NodeOp>::
+template get_consumer_store_callbacks<ReferenceSrc>(args);
+return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
 }

 };
@@ -590,10 +610,9 @@ struct Sm90SplitTreeVisitor : Sm90VisitorImpl<InputTree, AuxOutTrees..., OutputT
 >
 CUTLASS_DEVICE auto
 get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
-return ConsumerStoreCallbacks(
-Sm90VisitorImpl<InputTree, AuxOutTrees..., OutputTree>::
-get_consumer_store_callbacks<ReferenceSrc>(args)
-);
+auto callbacks_tuple = Sm90VisitorImpl<InputTree, AuxOutTrees..., OutputTree>::
+template get_consumer_store_callbacks<ReferenceSrc>(args);
+return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
 }

 };
@@ -609,7 +628,7 @@ template<
 >
 struct Sm90TopologicalVisitor : Sm90VisitorImpl<Ops...> {
 static_assert(is_static_v<EdgeTuple>);
-static_assert(rank(EdgeTuple{}) == sizeof...(Ops));
+static_assert(cute::rank(EdgeTuple{}) == sizeof...(Ops));
 static_assert(sizeof...(Ops) > 1);

 using Sm90VisitorImpl<Ops...>::Sm90VisitorImpl;
@@ -669,10 +688,9 @@ struct Sm90TopologicalVisitor : Sm90VisitorImpl<Ops...> {
 >
 CUTLASS_DEVICE auto
 get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
-return ConsumerStoreCallbacks(
-Sm90VisitorImpl<Ops...>::
-get_consumer_store_callbacks<ReferenceSrc>(args)
-);
+auto callbacks_tuple = Sm90VisitorImpl<Ops...>::
+template get_consumer_store_callbacks<ReferenceSrc>(args);
+return ConsumerStoreCallbacks<decltype(callbacks_tuple)>(std::move(callbacks_tuple));
 }

 };
@@ -232,7 +232,7 @@ template<
 >
 struct TopologicalVisitor2x : VisitorImpl2x<Ops...> {
 static_assert(is_static_v<EdgeTuple>);
-static_assert(rank(EdgeTuple{}) == sizeof...(Ops));
+static_assert(cute::rank(EdgeTuple{}) == sizeof...(Ops));
 static_assert(sizeof...(Ops) > 1);

 using VisitorImpl2x<Ops...>::VisitorImpl2x;
@@ -100,11 +100,11 @@ struct CollectiveMma<
 using TransformB = TransformB_;
 using ArchTag = typename DispatchPolicy::ArchTag;

-static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

-static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -173,9 +173,9 @@ struct CollectiveMma<
 static_assert(is_gmem<TensorA>::value, "A tensor must be gmem resident.");
 static_assert(is_gmem<TensorB>::value, "B tensor must be gmem resident.");
 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutA{}) == 2,
+static_assert(cute::rank(SmemLayoutA{}) == 2,
 "MainloopTwoStage must not have a smem shape with a pipeline mode.");
-static_assert(rank(SmemLayoutB{}) == 2,
+static_assert(cute::rank(SmemLayoutB{}) == 2,
 "MainloopTwoStage must not have a smem shape with a pipeline mode.");

 // Construct shared memory tiles
@@ -343,11 +343,11 @@ struct CollectiveMma<
 using TransformB = TransformB_;
 using ArchTag = typename DispatchPolicy::ArchTag;

-static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

-static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -414,9 +414,9 @@ struct CollectiveMma<
 static_assert(is_gmem<TensorA>::value, "A tensor must be gmem resident.");
 static_assert(is_gmem<TensorB>::value, "B tensor must be gmem resident.");
 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutA{}) == 2,
+static_assert(cute::rank(SmemLayoutA{}) == 2,
 "MainloopTwoStage must not have a smem shape with a pipeline mode.");
-static_assert(rank(SmemLayoutB{}) == 2,
+static_assert(cute::rank(SmemLayoutB{}) == 2,
 "MainloopTwoStage must not have a smem shape with a pipeline mode.");

 // Construct shared memory tiles
@@ -101,11 +101,11 @@ struct CollectiveMma<
 using TransformB = TransformB_;
 using ArchTag = typename DispatchPolicy::ArchTag;

-static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

-static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -174,9 +174,9 @@ struct CollectiveMma<
 static_assert(is_gmem<TensorA>::value, "A tensor must be gmem resident.");
 static_assert(is_gmem<TensorB>::value, "B tensor must be gmem resident.");
 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutA{}) == 3,
+static_assert(cute::rank(SmemLayoutA{}) == 3,
 "MainloopSm80CpAsync must have a pipeline mode in the smem layout.");
-static_assert(rank(SmemLayoutB{}) == 3,
+static_assert(cute::rank(SmemLayoutB{}) == 3,
 "MainloopSm80CpAsync must have a pipeline mode in the smem layout.");

 // Construct shared memory tiles
@@ -390,11 +390,11 @@ struct CollectiveMma<
 using TransformB = TransformB_;
 using ArchTag = typename DispatchPolicy::ArchTag;

-static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

-static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -463,8 +463,8 @@ struct CollectiveMma<
 static_assert(is_gmem<TensorA>::value, "A tensor must be gmem resident.");
 static_assert(is_gmem<TensorB>::value, "B tensor must be gmem resident.");
 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
-static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");

 // Construct shared memory tiles
 SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
@@ -138,11 +138,11 @@ struct CollectiveMma<
 using PipelineState = typename MainloopPipeline::PipelineState;
 using PipelineParams = typename MainloopPipeline::Params;

-static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

-static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -418,10 +418,10 @@ struct CollectiveMma<
 {
 using namespace cute;
 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
-static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
-static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
-static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
+static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
+static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
 static_assert(!cute::is_void_v<InternalSmemCopyAtomA>,
 "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions.");
 static_assert(cute::is_void_v<InternalSmemCopyAtomB>,
@@ -112,11 +112,11 @@ struct CollectiveMma<
 using PipelineState = typename MainloopPipeline::PipelineState;
 using PipelineParams = typename MainloopPipeline::Params;

-static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

-static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -346,8 +346,8 @@ struct CollectiveMma<
 using namespace cute;

 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
-static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
 static_assert(cute::is_void_v<SmemCopyAtomA>,
 "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
 static_assert(cute::is_void_v<SmemCopyAtomB>,
@@ -141,11 +141,11 @@ struct CollectiveMma<

 using PipelineParams = typename MainloopPipeline::Params;

-static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

-static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -402,7 +402,7 @@ struct CollectiveMma<
 // Prepare the TMA loads for A and B
 //

-constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
+constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
 uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

 Tensor gA_mkl = get<0>(tiled_tensors);
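Note (illustrative, not part of the diff): the ClusterShape changes in this file and the typename DispatchPolicy::ClusterShape() spelling in a later hunk appear to work around clang's stricter handling of dependent names: DispatchPolicy is a template parameter, so constructing its nested type in an expression needs typename (or a non-dependent alias, which these mainloops already define). A hedged sketch with hypothetical types:

    // Hypothetical policy type; the real DispatchPolicy carries much more.
    template <int N>
    struct MyDispatchPolicy {
      struct ClusterShape { static constexpr int x = N; };
    };

    template <class DispatchPolicy>
    struct MyMainloop {
      using ClusterShape = typename DispatchPolicy::ClusterShape;  // non-dependent alias

      static constexpr int cluster_x() {
        // Writing DispatchPolicy::ClusterShape() here without `typename` is what clang
        // rejects; the alias (or an explicit `typename DispatchPolicy::ClusterShape()`)
        // is accepted by both compilers.
        return ClusterShape().x;
      }
    };

    static_assert(MyMainloop<MyDispatchPolicy<2>>::cluster_x() == 2, "");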
@@ -502,10 +502,10 @@ struct CollectiveMma<
 {
 using namespace cute;
 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
-static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
-static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
-static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
+static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
+static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
 static_assert(!cute::is_void_v<InternalSmemCopyAtomA>,
 "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions.");
 static_assert(cute::is_void_v<InternalSmemCopyAtomB>,
@@ -183,11 +183,11 @@ public:

 using PipelineParams = typename MainloopPipeline::Params;

-static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

-static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -443,7 +443,7 @@ public:
 // Prepare the TMA loads for A and B
 //

-constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
+constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
 uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

 Tensor gA_mkl = get<0>(tiled_tensors);
@@ -541,10 +541,10 @@ public:
 Params const& mainloop_params) {
 using namespace cute;
 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
-static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
-static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
-static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
+static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2.");
+static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2.");
 static_assert(!cute::is_void_v<InternalSmemCopyAtomA>,
 "SM90 GMMA mainloops must specify a non-void copy atom for RF sourced instructions.");
 static_assert(cute::is_void_v<InternalSmemCopyAtomB>,
@@ -113,11 +113,11 @@ struct CollectiveMma<
 using PipelineParams = typename MainloopPipeline::Params;
 using PipelineState = typename cutlass::PipelineState<DispatchPolicy::Stages>;

-static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

-static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@@ -271,10 +271,10 @@ struct CollectiveMma<
 using namespace cute;

 static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
-static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2.");
-static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2.");
-static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
-static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2.");
+static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2.");
+static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
+static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
 static_assert(cute::is_void_v<SmemCopyAtomA>,
 "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
 static_assert(cute::is_void_v<SmemCopyAtomB>,
@@ -288,7 +288,7 @@ struct CollectiveMma<
 // Prepare the TMA loads for A and B
 //

-constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
+constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
 uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

 auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y);
@@ -114,11 +114,11 @@ struct CollectiveMma<

 using PipelineParams = typename MainloopPipeline::Params;

-static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

-static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
+static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
 static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
 static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
@ -319,7 +319,7 @@ struct CollectiveMma<
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
|
||||
constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(tiled_tensors);
|
||||
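The cluster_shape_x hunks deal with a dependent-name issue rather than an ambiguity: DispatchPolicy::ClusterShape is a dependent type inside these templates, so clang wants either the collective's own ClusterShape alias or an explicit typename. A rough standalone sketch (ToyPolicy is hypothetical, not the real DispatchPolicy):

#include <tuple>

struct ToyPolicy {
  using ClusterShape = std::tuple<int, int, int>;  // stand-in for a cute::Shape
};

template <class DispatchPolicy>
constexpr int cluster_shape_x() {
  // Without `typename`, clang refuses to parse the dependent ClusterShape as a
  // type in this functional cast; naming a local ClusterShape alias works too.
  return std::get<0>(typename DispatchPolicy::ClusterShape{});
}

int main() { return cluster_shape_x<ToyPolicy>(); }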
@ -423,8 +423,8 @@ struct CollectiveMma<
using namespace cute;

static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,

@ -115,11 +115,11 @@ struct CollectiveMma<

using PipelineParams = typename MainloopPipeline::Params;

static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

@ -317,7 +317,7 @@ struct CollectiveMma<
// Prepare the TMA loads for A and B
//

constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape());
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

Tensor gA_mkl = get<0>(tiled_tensors);
@ -421,8 +421,8 @@ struct CollectiveMma<
using namespace cute;

static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,

@ -665,7 +665,7 @@ protected:

int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM;
int m_end = params.block_mapping.problem_size.m();
return Mma::IteratorA(
return typename Mma::IteratorA(
params.params_A,
ptr_A,
{ m_end, tile_work.k_end },
@ -694,7 +694,7 @@ protected:

int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN;
int n_end = params.block_mapping.problem_size.n();
return Mma::IteratorB(
return typename Mma::IteratorB(
params.params_B,
ptr_B,
{ tile_work.k_end, n_end },
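The two IteratorA/IteratorB hunks are the same fix: Mma::IteratorA is a dependent name, so the functional cast in the return statement needs typename before clang will treat it as a type. A self-contained sketch with a hypothetical ToyMma (the real Mma iterators take more constructor arguments):

struct ToyMma {
  struct IteratorA {
    explicit IteratorA(int offset) : offset_(offset) {}
    int offset_;
  };
};

template <class Mma>
auto make_iterator_a(int offset) {
  // `typename` marks Mma::IteratorA as a type; without it clang rejects the
  // expression during the first phase of template compilation.
  return typename Mma::IteratorA(offset);
}

int main() { return make_iterator_a<ToyMma>(0).offset_; }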
@ -60,7 +60,7 @@ public:
//
using ProblemShape = ProblemShape_;

static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");

// Mainloop derived types
@ -142,7 +142,7 @@ public:
static bool
can_implement(Arguments const& args) {
return args.mode == GemmUniversalMode::kGemm or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
}

static int
@ -159,7 +159,7 @@ public:
static dim3
get_grid_shape(Params const& params) {
int batch_count = 1;
if constexpr (rank(ProblemShape{}) == 4) {
if constexpr (cute::rank(ProblemShape{}) == 4) {
batch_count = cute::size<3>(params.problem_shape);
}

@ -193,10 +193,10 @@ public:
auto L = get<3>(problem_shape_MNKL);

// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");

// Get the appropriate blocks for this thread block -- potential for thread block locality
int thread_idx = int(threadIdx.x);

@ -80,7 +80,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");

// Mainloop derived types
@ -169,7 +169,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -219,10 +219,10 @@ public:
#endif

// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");

int thread_idx = int(threadIdx.x);
int warp_idx = canonical_warp_idx_sync();
@ -285,13 +285,13 @@ public:
params.mainloop
);

constexpr int BLK_M_RANK = rank<0>(blk_shape);
constexpr int BLK_M_RANK = cute::rank<0>(blk_shape);
bool m_oob = int(blockIdx.x) >= size<2>(gA_mkl);
auto m_max_coord = unwrap(cute::transform(make_seq<BLK_M_RANK>{}, [&](auto i) {
return m_oob ? 0 : get<i>(M) - get<0,i>(blk_shape) * get<i>(m_coord);
}));

constexpr int BLK_N_RANK = rank<1>(blk_shape);
constexpr int BLK_N_RANK = cute::rank<1>(blk_shape);
bool n_oob = int(blockIdx.y) >= size<2>(gB_nkl);
auto n_max_coord = unwrap(cute::transform(make_seq<BLK_N_RANK>{}, [&](auto i) {
return n_oob ? 0 : get<i>(N) - get<1,i>(blk_shape) * get<i>(n_coord);

@ -69,7 +69,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");

// Mainloop derived types
@ -176,7 +176,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -318,10 +318,10 @@ public:
} ();

// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");

// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
@ -338,7 +338,7 @@ public:
// get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape);
static_assert(tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");
static_assert(cute::tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");

// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(tiled_tensors);
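The tuple_size_v change follows the same pattern as the rank() changes: cute ships its own tuple_size/tuple_size_v, and once the cute and std versions are both reachable, clang reports the unqualified variable template as ambiguous. Toy illustration (the cute variable template below is a stand-in, not the real definition):

#include <cstddef>
#include <tuple>

namespace cute {
  // Stand-in for cute::tuple_size_v, defined here only for the illustration.
  template <class T>
  inline constexpr std::size_t tuple_size_v = std::tuple_size<T>::value;
}

using namespace cute;
using namespace std;

int main() {
  using TiledTensors = std::tuple<int, float>;
  // An unqualified tuple_size_v<TiledTensors> would be ambiguous here;
  // spelling out cute::tuple_size_v, as the hunk above does, picks one side.
  static_assert(cute::tuple_size_v<TiledTensors> >= 2, "expect at least A and B");
  return 0;
}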
@ -69,7 +69,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");

// Mainloop derived types
@ -219,7 +219,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -303,10 +303,10 @@ public:
static_assert(size<0>(TileShape{}) >= 128,
"Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");

static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");

/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */
enum class WarpGroupRole {
@ -423,7 +423,7 @@ public:
// get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape);
static_assert(tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");
static_assert(cute::tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");

// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(tiled_tensors);

@ -70,7 +70,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");

// Mainloop derived types
@ -225,7 +225,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -305,10 +305,10 @@ public:
#endif

// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");

enum class WarpGroupRole {
Producer = 0,
@ -427,7 +427,7 @@ public:
// get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l)
// get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape);
static_assert(tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");
static_assert(cute::tuple_size_v<decltype(tiled_tensors)> >= 2, "Output of tile_input_tensors must have at least two elements (A, B)");

// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(tiled_tensors);

@ -67,7 +67,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");

// Mainloop derived types
@ -180,7 +180,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -289,10 +289,10 @@ public:
PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>();

// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");

// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)

@ -67,7 +67,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");

// Mainloop derived types
@ -200,7 +200,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -256,10 +256,10 @@ public:
}
#endif

static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");

/* In the Cooperative kernel, one or multiple Consumers collaborate on the same tile */
enum class WarpGroupRole {

@ -69,7 +69,7 @@ public:
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");

// Mainloop derived types
@ -212,7 +212,7 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
(args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n");
return implementable;
@ -265,10 +265,10 @@ public:
#endif

// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");

enum class WarpGroupRole {
Producer = 0,

@ -35,6 +35,7 @@

#pragma once

#include "cute/layout.hpp"
#include "cutlass/gemm_coord.h"

namespace cutlass {

@ -192,7 +192,7 @@ struct NumericConverter<int8_t, float, FloatRoundStyle::round_to_nearest> {
return static_cast<result_type>(intermediate);
}

CUTLASS_DEVICE
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
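Promoting operator() from CUTLASS_DEVICE to CUTLASS_HOST_DEVICE matches the already host-and-device convert() it forwards to; clang's CUDA front end checks host/device callability strictly, so a device-only call operator would break host-side users of the converter. Rough shape of the pattern, assuming the usual expansion of the CUTLASS macros into __device__ / __host__ __device__:

// Sketch only; Converter stands in for NumericConverter<int8_t, float, ...>.
struct Converter {
  using result_type = int;
  using source_type = float;

  __host__ __device__
  static result_type convert(source_type s) { return static_cast<result_type>(s); }

  __host__ __device__            // was effectively __device__-only before the fix
  result_type operator()(source_type const& s) const { return convert(s); }
};

int main() { return Converter{}(0.25f); }  // host-side call is now legal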
@ -193,8 +193,8 @@ struct TestbedImpl {

using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;

static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");

static constexpr uint32_t mma_promotion_interval = 4;

@ -523,9 +523,6 @@ struct TestbedImpl {
Gemm& gemm_op,
typename Gemm::Arguments& arguments,
cutlass::device_memory::allocation<uint8_t>& workspace) {
int M = cute::size<0>(problem_size);
int N = cute::size<1>(problem_size);
int K = cute::size<2>(problem_size);
int L = 1;
if constexpr(cute::rank(ProblemShapeType{}) == 4) {
L = cute::size<3>(problem_size);
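This hunk (and the later ones that drop coord_0) removes three lines; given the commit description, these appear to be locals that were computed but never read, which clang flags via its unused-variable warnings and -Werror turns into hard errors. Where a value is kept on purpose, [[maybe_unused]] is the usual alternative, e.g.:

int run_smoke_test(int problem_m) {
  // Deliberately retained value; the attribute silences -Wunused-variable.
  [[maybe_unused]] int m = problem_m;
  return 0;
}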
@ -581,7 +578,7 @@ struct TestbedImpl {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
if (not profiling) {
this->sm_count = min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
this->sm_count = std::min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
hw_info.sm_count = this->sm_count;
}
else {
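This testbed code runs on the host, where an unqualified min() is not guaranteed to exist (under nvcc it happened to resolve through the CUDA headers); std::min from <algorithm> is the portable spelling, used here and in the matching hunks below. For example:

#include <algorithm>

int clamp_sm_count(int max_sm_count, int queried_sm_count) {
  // std::min lives in namespace std; relying on a global ::min is non-portable.
  return std::min(max_sm_count, queried_sm_count);
}

int main() { return clamp_sm_count(16, 132) == 16 ? 0 : 1; }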
@ -1240,7 +1237,7 @@ struct Testbed3xFusionOperation {

hw_info.device_id = 0;
if (not profiling) {
impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
hw_info.sm_count = impl_.sm_count;
}
else {

@ -173,7 +173,7 @@ public:
HostScalarBroadcast(){}
template<typename ProblemShapeType, typename TestBedImpl>
HostScalarBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
:_scalar(ElementCompute(Value)), Base(check_relative_equality) {}
: Base(check_relative_equality), _scalar(ElementCompute(Value)) {}

template <class ElementAccumulator>
ElementCompute visit(
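The constructor hunks in this group all make the same change: the base class is listed before the data members in the mem-initializer list. Bases are always initialized before members regardless of how the list is written, and clang's -Wreorder-ctor (an error under -Werror) insists that the list match that order. Toy version of the pattern, with hypothetical Base/HostOp names:

struct Base {
  explicit Base(bool check_relative_equality) : check_(check_relative_equality) {}
  bool check_;
};

struct HostOp : Base {
  // Listing Base first matches the real initialization order (base, then members),
  // which keeps clang's -Wreorder-ctor quiet.
  HostOp(bool check_relative_equality, float scalar)
      : Base(check_relative_equality), scalar_(scalar) {}
  float scalar_;
};

int main() { return HostOp(true, 0.0f).check_ ? 0 : 1; }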
@ -232,7 +232,7 @@ public:
HostRowBroadcast(){}
template<typename ProblemShapeType>
HostRowBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
:impl_(impl), Base(check_relative_equality) {
: Base(check_relative_equality), impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
_N = cute::get<1>(problem_shape_MNKL);
_bias.resize(cutlass::Coord<1>(_N));
@ -300,7 +300,7 @@ public:
HostColBroadcast(){}
template<typename ProblemShapeType>
HostColBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
:impl_(impl), Base(check_relative_equality) {
: Base(check_relative_equality), impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
_M = cute::get<0>(problem_shape_MNKL);
_bias.resize(cutlass::Coord<1>(_M));
@ -382,7 +382,7 @@ public:
HostAuxLoad(){}
template<typename ProblemShapeType>
HostAuxLoad(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
:impl_(impl), Base(check_relative_equality){
: Base(check_relative_equality), impl_(impl){
auto problem_shape_NMKL = cute::append<4>(problem_size, 1);
auto [_M, _N, K, _L] = problem_shape_NMKL;
auto aux_coord = cutlass::make_Coord(_M * _L, _N);
@ -513,8 +513,8 @@ public:
HostUnaryCompute(){}
template <typename ProblemShapeType, typename TestBedImpl>
HostUnaryCompute(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
_child_0(problem_size, impl, check_relative_equality),
Base(check_relative_equality) { }
Base(check_relative_equality),
_child_0(problem_size, impl, check_relative_equality) { }

template <class ElementAccumulator>
ElementCompute visit(
@ -578,8 +578,8 @@ public:
HostAuxStore(){}
template <typename ProblemShapeType>
HostAuxStore(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
impl_(impl),
Base(check_relative_equality) {
Base(check_relative_equality),
impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
auto [_M, _N, K, _L] = problem_shape_MNKL;
auto aux_coord = cutlass::make_Coord(_M * _L, _N);
@ -677,8 +677,8 @@ public:
HostRowReduce(){}
template <typename ProblemShapeType>
HostRowReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
impl_(impl),
Base(check_relative_equality) {
Base(check_relative_equality),
impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
_N = cute::get<1>(problem_shape_MNKL);
_tensor_row_reduce.resize(cutlass::Coord<1>(_N));
@ -764,8 +764,8 @@ public:
HostColumnReduce(){}
template <typename ProblemShapeType>
HostColumnReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
impl_(impl),
Base(check_relative_equality) {
Base(check_relative_equality),
impl_(impl) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
_M = cute::get<0>(problem_shape_MNKL);
_tensor_column_reduce.resize(cutlass::Coord<1>(_M));
@ -850,9 +850,8 @@ public:
HostScalarReduce(){}
template <typename ProblemShapeType>
HostScalarReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
impl_(impl),
Base(check_relative_equality) {
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
Base(check_relative_equality),
impl_(impl) {
_tensor_scalar_reduce.resize(cutlass::Coord<1>(1));
_reference_scalar_reduce.resize(cutlass::Coord<1>(1));
_reduce_buffer.resize(cutlass::Coord<1>(1));
@ -1229,7 +1228,6 @@ public:
auto N = cute::get<1>(problem_shape_MNKL);
auto K = cute::get<2>(problem_shape_MNKL);
auto L = cute::get<3>(problem_shape_MNKL);
auto coord_0 = cutlass::make_Coord(0);

auto A = cute::make_tensor(impl_.tensor_A.host_data(),
cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a));
@ -1307,7 +1305,7 @@ public:
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
if (not profiling) {
impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
hw_info.sm_count = impl_.sm_count;
}
else {

@ -158,7 +158,6 @@ struct Testbed3xTensorBroadcast {
bool use_bias)
{
auto [M, N, K, L] = problem_shape_MNKL;
auto coord_0 = cutlass::make_Coord(0);

impl_.tensor_D.sync_host();
EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_A.host_view()), 0);
@ -218,7 +217,6 @@ struct Testbed3xTensorBroadcast {
auto N = cute::get<1>(problem_shape_MNKL);
auto K = cute::get<2>(problem_shape_MNKL);
auto L = cute::get<3>(problem_shape_MNKL);
auto coord_0 = cutlass::make_Coord(0);

auto A = cute::make_tensor(impl_.tensor_A.host_data(),
cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a));
@ -338,7 +336,7 @@ struct Testbed3xTensorBroadcast {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
if (not profiling) {
impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
hw_info.sm_count = impl_.sm_count;
}
else {

@ -163,7 +163,7 @@ public:
using EVTModule = HEVT<
HostAuxStore<Gemm, true>,
HEVT<
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>, // activation(Z) * scaled_d
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::template Op>, // activation(Z) * scaled_d
HEVT<
HostCompute<Gemm, ActivationFn>, // activation(Z)
HEVT<
@ -174,11 +174,11 @@ public:
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha
HostAccumulator<Gemm>,
HostColBroadcast<Gemm, ElementD>,
HostColBroadcast<Gemm, ElementD>
>
>
>,
HostScalarBroadcast<Gemm, 1>, // scale_d
HostScalarBroadcast<Gemm, 1> // scale_d
>
>;
};
@ -211,26 +211,26 @@ public:
HostCompute<Gemm, cutlass::homogeneous_multiply_add>,
HostScalarBroadcast<Gemm, 1, 3>, // scale_a * scale_b * alpha
HostAccumulator<Gemm>,
HostColBroadcast<Gemm, ElementD>,
HostColBroadcast<Gemm, ElementD>
>
>,
// D = activation(Z) * scaled_d, amax_d = max(abs(elements in D))
HEVT<
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>,
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::template Op>,
HEVT<
HostScalarReduce<Gemm, amax, float>,
HEVT<
HostCompute<Gemm, ActivationFn>, //activation(Z) * scaled_d
HostAccumulator<Gemm>, // Z
HostAccumulator<Gemm> // Z
>
>,
HostScalarBroadcast<Gemm, 1>, // scale_d
HostScalarBroadcast<Gemm, 1> // scale_d
>,
// Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux))
HEVT<
HostAuxStore<Gemm, false, ElementD, cutlass::layout::RowMajor>,
HEVT<
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::Op>,
HostCompute<Gemm, cutlass::epilogue::fusion::detail::ScaleOutOp<ElementD>::template Op>,
HEVT<
HostScalarReduce<Gemm, amax, float>,
HostAccumulator<Gemm>
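ScaleOutOp<ElementD>::Op is a member template of a type that depends on ElementD, so naming it as a template-template argument needs the template disambiguator for clang; that is all the ::template edits in this file change. Minimal illustration with toy templates (not the CUTLASS ones):

template <class T>
struct ScaleOut {
  template <class U>
  struct Op { U value; };
};

template <template <class> class OpT>
struct HostCompute {
  OpT<float> op;
};

template <class ElementD>
struct FusionModule {
  // `template` tells the compiler that Op names a member template of the
  // dependent type ScaleOut<ElementD>.
  using Compute = HostCompute<ScaleOut<ElementD>::template Op>;
};

int main() {
  FusionModule<double>::Compute c{};
  return static_cast<int>(c.op.value);
}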
@ -126,7 +126,7 @@ gett(
cudaStream_t stream = 0) {
using namespace cute;

static_assert(rank(ProblemShapeMNKL{}) == 4);
static_assert(cute::rank(ProblemShapeMNKL{}) == 4);
auto M = get<0>(problem_shape_mnkl);
auto N = get<1>(problem_shape_mnkl);
auto K = get<2>(problem_shape_mnkl);

@ -431,11 +431,11 @@ void Gemm3x(
{
using namespace cute;

static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename MainloopParams::LayoutB{}));
static_assert(rank(typename EpilogueParams::LayoutC{}) == rank(typename EpilogueParams::LayoutD{}));
static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename EpilogueParams::LayoutC{}));
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{}));
static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{}));
static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{}));

if constexpr (rank(typename MainloopParams::LayoutA{}) == 2) {
if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) {
Layout layout_A = make_layout_rank3(mainloop_params.A);
Layout layout_B = make_layout_rank3(mainloop_params.B);
Layout layout_C = make_layout_rank3(epilogue_params.C);