Updates for CUTLASS 3.4.1 (#1346)

* Updates for CUTLASS 3.4.1

* minor epi change
ANIKET SHIVAM 2024-02-15 12:48:34 -08:00 committed by GitHub
parent 47a3ebbea9
commit bbe579a9e3
49 changed files with 800 additions and 451 deletions

View File

@ -1,5 +1,11 @@
# NVIDIA CUTLASS Changelog
## [3.4](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12)
## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14)
- Statically available [CUTLASS Version macros](/include/cutlass/version.h) that allow users to handle API changes between CUTLASS releases.
- Improvements for Hopper [Group-GEMMs](/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm).
- Updates and bugfixes from the community (thanks!).
## [3.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12)
* Expanded [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
* Performance improvements to [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm)
* Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
@ -8,7 +14,6 @@
* NamedBarriers usability improvements and the list of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) have been officially released.
* Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved.
## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31)
* [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
* [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operand B {fp16, bf16} x {s8, u8}, and upcast on operand A {s8, u8} x {fp16, bf16}.

View File

@ -40,7 +40,25 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")
project(CUTLASS VERSION 3.4.0 LANGUAGES CXX)
# To reduce duplicate version locations, parse the version out of the
# main version.h file and reuse it here.
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/include/cutlass/version.h VERSION_FILE_CONTENTS)
string(REGEX MATCH "#define CUTLASS_MAJOR ([0-9]+)" _CUTLASS_VERSION_MAJOR "${VERSION_FILE_CONTENTS}")
set(_CUTLASS_VERSION_MAJOR ${CMAKE_MATCH_1})
string(REGEX MATCH "#define CUTLASS_MINOR ([0-9]+)" _CUTLASS_VERSION_MINOR "${VERSION_FILE_CONTENTS}")
set(_CUTLASS_VERSION_MINOR ${CMAKE_MATCH_1})
string(REGEX MATCH "#define CUTLASS_PATCH ([0-9]+)" _CUTLASS_VERSION_PATCH "${VERSION_FILE_CONTENTS}")
set(_CUTLASS_VERSION_PATCH ${CMAKE_MATCH_1})
message(STATUS "CUTLASS ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_CUTLASS_VERSION_PATCH}")
## CUTLASS PROJECT #############################################################
project(CUTLASS VERSION ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_CUTLASS_VERSION_PATCH} LANGUAGES CXX)
################################################################################
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
if (CUDA_VERSION VERSION_LESS 11.3)
@ -178,6 +196,9 @@ if(WIN32)
set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUTLASS_VERSIONS_GENERATED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DCUTLASS_VERSIONS_GENERATED")
if (WIN32)
# Enable more warnings. Add "-Xcompiler=/WX" to enable warnings as errors.
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3)
@ -589,8 +610,8 @@ if (NOT DEFINED CUTLASS_REVISION)
endif()
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in
${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h
${CMAKE_CURRENT_SOURCE_DIR}/cmake/version_extended.h.in
${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version_extended.h
@ONLY)
target_include_directories(

View File

@ -2,7 +2,8 @@
## 2023
- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi and Jay Shah. _arXiv_, December 2023.
- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi, Jay Shah. _arXiv_, December 2023.
- ["A Speed Odyssey for Deployable Quantization of LLMs"](https://arxiv.org/abs/2311.09550). Qingyuan Li, Ran Meng, Yiduo Li, Bo Zhang, Liang Li, Yifan Lu, Xiangxiang Chu, Yerui Sun, Yuchen Xie. _arXiv_, November 2023.

View File

@ -2,7 +2,7 @@
# CUTLASS 3.4
_CUTLASS 3.4 - January 2024_
_CUTLASS 3.4 - February 2024_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
@ -43,13 +43,18 @@ In addition to GEMMs, CUTLASS implements high-performance convolution via the im
# What's New in CUTLASS 3.4
CUTLASS 3.4.1 is an update to CUTLASS adding:
- Statically available [CUTLASS Version macros](/include/cutlass/version.h) that allow users to handle API changes between CUTLASS releases (see the sketch below).
- Improvements for Hopper [Group-GEMM](/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMM](/examples/56_hopper_ptr_array_batched_gemm).
- Updates and bugfixes from the community (thanks!).
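A minimal sketch of how the new version macros can gate user-side code across releases, assuming `include/cutlass/version.h` keeps the `CUTLASS_VERSION = major*100 + minor*10 + patch` encoding used by the `cmake/version.h.in` removed in this change:

```cpp
// Minimal sketch: report the CUTLASS release and gate code on it at compile time.
#include <cstdio>
#include "cutlass/version.h"  // defines CUTLASS_MAJOR/MINOR/PATCH and CUTLASS_VERSION

int main() {
  std::printf("Built against CUTLASS %d.%d.%d\n",
              CUTLASS_MAJOR, CUTLASS_MINOR, CUTLASS_PATCH);
#if CUTLASS_VERSION >= 341   // 3.4.1 and newer
  std::printf("3.4.1+ APIs available\n");
#else
  std::printf("pre-3.4.1 release\n");
#endif
  return 0;
}
```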
CUTLASS 3.4.0 is an update to CUTLASS adding:
- Improved [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) supporting {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors tuned for optimal performance on Hopper H100.
- Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) utilizing TMA and Hopper H100 tensor cores is now available. (Requires CUDA 12.3 or above)
- Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm), commonly used in the optimization of Mixture-of-Experts models, is now available on Hopper GPUs taking advantage of TMA and Hopper H100 tensor cores. (Requires CUDA 12.3 or above)
- [Ampere Sparse GEMM](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
- Impovements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library.
- Improvements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library.
- Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved.
Minimum requirements:
@ -93,8 +98,8 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA
# Compatibility
CUTLASS requires a C++17 host compiler and
performs best when built with the [**CUDA 12.2.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit-archive).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2 and CUDA 12.3.1
performs best when built with the [**CUDA 12.3.2 Toolkit**](https://developer.nvidia.com/cuda-downloads).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2, CUDA 12.3.1 and CUDA 12.3.2.
## Operating Systems
We have tested the following environments.

View File

@ -1,38 +0,0 @@
#include <cstdint>
#include <string>
#define CUTLASS_MAJOR @CUTLASS_VERSION_MAJOR@
#define CUTLASS_MINOR @CUTLASS_VERSION_MINOR@
#define CUTLASS_PATCH @CUTLASS_VERSION_PATCH@
#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
namespace cutlass {
inline uint32_t getVersion() {
return CUTLASS_VERSION;
}
inline uint32_t getVersionMajor() {
return CUTLASS_MAJOR;
}
inline uint32_t getVersionMinor() {
return CUTLASS_MINOR;
}
inline uint32_t getVersionPatch() {
return CUTLASS_PATCH;
}
inline uint32_t getVersionBuild() {
return CUTLASS_BUILD + 0;
}
inline std::string getVersionString() {
std::string version = "@CUTLASS_VERSION@";
if (getVersionBuild()) {
version += "." + std::to_string(getVersionBuild());
}
return version;
}
inline std::string getGitRevision() {
return "@CUTLASS_REVISION@";
}
} // namespace cutlass

View File

@ -0,0 +1,34 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@
#define CUTLASS_REVISION "@CUTLASS_REVISION@"

View File

@ -31,4 +31,5 @@
cutlass_example_add_executable(
02_dump_reg_shmem
dump_reg_shmem.cu
DISABLE_TESTS ON
)

View File

@ -70,7 +70,7 @@
using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
@ -98,8 +98,8 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelArrayTmaWarpSpecializedCooperative; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedArray; // Epilogue to launch
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
@ -169,7 +169,7 @@ cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
@ -245,7 +245,7 @@ struct Result
bool passed = false;
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@ -468,7 +468,7 @@ int run(Options &options)
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -510,7 +510,7 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
run<Gemm>(options);
#endif

View File

@ -27,17 +27,17 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
# Only the correctness check will be run by these commands.
set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=1) # Square problem sizes
set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=1) # Square problem sizes
set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=0) # Square problem sizes
set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=0) # Square problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=1) # Default problem sizes
set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=1) # Default problem sizes
set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Default problem sizes
set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=0) # Default problem sizes
set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Default problem sizes w/ Epilogue Op test
set(TEST_EPILOGUE_OP_LARGE_BATCH --alpha=1.5 -l=500 --iterations=1) # Default problem sizes w/ Epilogue Op test
set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=0) # Small-k problem sizes
set(TEST_SMALLK_LARGE_BATCH --m=1024 --n=512 --k=64 --l=500 --iterations=0) # Small-k problem sizes
set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=1) # Small-k problem sizes
set(TEST_SMALLK_LARGE_BATCH --m=1024 --n=512 --k=64 --l=500 --iterations=1) # Small-k problem sizes
cutlass_example_add_executable(
56_hopper_ptr_array_batched_gemm
@ -47,6 +47,8 @@ cutlass_example_add_executable(
TEST_SQUARE_LARGE_BATCH
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_BATCH
TEST_EPILOGUE_OP
TEST_EPILOGUE_OP_LARGE_BATCH
TEST_SMALLK
TEST_SMALLK_LARGE_BATCH
)

View File

@ -44,6 +44,7 @@
The above example command sizes all 10 groups at the given m, n, k values.
Skipping any of the problem dimensions randomizes that dimension across the different groups.
The same applies to alpha and beta values, which are randomized across the different groups.
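For instance, an invocation along these lines (flag names follow the Options struct in this example; the binary path assumes the usual CUTLASS examples build layout) fixes m and k while n, alpha, and beta are randomized per group:

  $ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --m=1024 --k=512 --groups=8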
To run this example for a set of problems using the benchmark option:
@ -62,6 +63,7 @@
#include <fstream>
#include <sstream>
#include <vector>
#include <float.h>
#include "cutlass/cutlass.h"
@ -91,9 +93,9 @@ using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand
using ElementC = float; // Element type for C and D matrix operands
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
@ -101,40 +103,40 @@ using ElementC = float; // Element type
// A matrix configuration
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedGroup; // Epilogue to launch
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC *, AlignmentC,
ElementC, LayoutC *, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementA, LayoutA *, AlignmentA,
ElementB, LayoutB *, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
@ -161,10 +163,10 @@ using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementAccumulator,
ElementAccumulator>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
// Host-side allocations
std::vector<int64_t> offset_A;
@ -177,6 +179,9 @@ std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<ElementAccumulator> alpha_host;
std::vector<ElementAccumulator> beta_host;
// Device-side allocations
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
@ -197,7 +202,13 @@ cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
// Note, this is an array of pointers to alpha and beta scaling values per group
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
@ -208,8 +219,8 @@ struct Options {
bool help = false;
float alpha = 1.0f;
float beta = 0.0f;
float alpha = FLT_MAX;
float beta = FLT_MAX;
int iterations = 10;
int m = 1024, n = 2048, k = 512, groups = 10;
std::string benchmark_path;
@ -230,8 +241,8 @@ struct Options {
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("groups", groups);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX);
cmd.get_cmd_line_argument("beta", beta, FLT_MAX);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("benchmark", benchmark_path);
@ -248,10 +259,7 @@ struct Options {
}
void randomize_problems(cutlass::CommandLine &cmd) {
int cmd_line_m = -1;
int cmd_line_n = -1;
int cmd_line_k = -1;
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
cmd.get_cmd_line_argument("m", cmd_line_m);
cmd.get_cmd_line_argument("n", cmd_line_n);
cmd.get_cmd_line_argument("k", cmd_line_k);
@ -259,19 +267,15 @@ struct Options {
problem_sizes_host.reserve(groups);
for (int i = groups; i > 0; i--) {
int m = cmd_line_m;
int n = cmd_line_n;
int k = cmd_line_k;
if (m < 1) {
m = ((rand() % 512) + 1);
}
if (n < 1) {
n = ((rand() % 512) + 1);
}
if (k < 1) {
k = alignment * ((rand() % 64) + 1);
}
@ -317,6 +321,7 @@ struct Options {
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
}
}
groups = static_cast<int>(problem_sizes_host.size());
return true;
}
@ -351,7 +356,9 @@ struct Options {
uint64_t fmas = uint64_t();
for (auto const & problem : problem_sizes_host) {
fmas += cute::size(problem);
fmas += static_cast<uint64_t>(get<0>(problem)) *
static_cast<uint64_t>(get<1>(problem)) *
static_cast<uint64_t>(get<2>(problem));
}
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * uint64_t(fmas);
@ -370,7 +377,7 @@ struct Result
bool passed = false;
};
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@ -435,6 +442,7 @@ void allocate(const Options &options) {
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{})));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{})));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{})));
}
block_A.reset(total_elements_A);
@ -442,6 +450,8 @@ void allocate(const Options &options) {
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_ref_D.reset(total_elements_D);
block_alpha.reset(options.groups);
block_beta.reset(options.groups);
}
/// Initialize operands to be used in the GEMM and reference GEMM
@ -460,12 +470,18 @@ void initialize(const Options &options) {
std::vector<ElementB *> ptr_B_host(options.groups);
std::vector<ElementC *> ptr_C_host(options.groups);
std::vector<ElementC *> ptr_D_host(options.groups);
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
for (int32_t i = 0; i < options.groups; ++i) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
ptr_beta_host.at(i) = block_beta.get() + i;
}
ptr_A.reset(options.groups);
@ -492,13 +508,20 @@ void initialize(const Options &options) {
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
alpha_device.reset(options.groups);
alpha_device.copy_from_host(ptr_alpha_host.data());
beta_device.reset(options.groups);
beta_device.copy_from_host(ptr_beta_host.data());
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
block_alpha.copy_from_host(alpha_host.data());
block_beta.copy_from_host(beta_host.data());
}
/// Populates a Gemm::Arguments structure from the given commandline options
typename Gemm::Arguments args_from_options(const Options &options)
typename Gemm::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
{
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
@ -506,13 +529,36 @@ typename Gemm::Arguments args_from_options(const Options &options)
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{{options.alpha, options.beta}, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
typename Gemm::EpilogueOutputOp::Params params;
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
params = typename Gemm::EpilogueOutputOp::Params(
ElementAccumulator(options.alpha), ElementAccumulator(options.beta));
}
else {
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
params = typename Gemm::EpilogueOutputOp::Params(alpha_device.get(), beta_device.get());
}
typename Gemm::Arguments arguments;
if (host_problem_shapes_available) {
arguments = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
else {
arguments = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
return arguments;
}
@ -539,10 +585,10 @@ bool verify(const Options &options) {
// Launch device reference gemm kernel
gemm_reference(
{M, N, K},
ElementAccumulator(options.alpha),
ElementAccumulator(alpha_host.at(i)),
ref_A,
ref_B,
ElementAccumulator(options.beta),
ElementAccumulator(beta_host.at(i)),
ref_C,
ref_D);
@ -560,7 +606,7 @@ bool verify(const Options &options) {
/// Execute a given example GEMM computation
template <typename Gemm>
int run(Options &options)
int run(Options &options, bool host_problem_shapes_available = true)
{
allocate(options);
initialize(options);
@ -569,7 +615,7 @@ int run(Options &options)
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
auto arguments = args_from_options(options);
auto arguments = args_from_options(options, host_problem_shapes_available);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
@ -612,12 +658,12 @@ int run(Options &options)
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
std::cout << " Problem Sizes: " << std::endl;
for (auto const & problem : options.problem_sizes_host) {
std::cout << " " << problem << std::endl;
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
for (int32_t i = 0; i < options.groups; ++i) {
std::cout << " " << options.problem_sizes_host.at(i);
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
std::cout << " Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl;
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS : " << result.gflops << std::endl;
}
@ -625,7 +671,7 @@ int run(Options &options)
return 0;
}
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -667,8 +713,9 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
run<Gemm>(options);
run<Gemm>(options, false /*host_problem_shapes_available*/);
#endif
return 0;

View File

@ -35,9 +35,15 @@ set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0)
set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Random problem sizes
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0) # Fixed problem sizes
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) # Fixed problem sizes
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=500 --iterations=10) # Random problem sizes
@ -49,8 +55,12 @@ cutlass_example_add_executable(
TEST_RANDOM_LARGE_GROUP
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_GROUP
TEST_EPILOGUE_OP
TEST_EPILOGUE_OP_LARGE_GROUP
TEST_FIXED
TEST_FIXED_LARGE_GROUP
TEST_SMALL
TEST_SMALL_LARGE_GROUP
TEST_RANDOM_PERF
TEST_RANDOM_PERF_LARGE_GROUP
)

View File

@ -265,7 +265,7 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor
asm volatile (
"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;"
:: "l"(smem_int64_desc), "r"(prob_shape[2]));
// Strides must be a multiple of 16. Also, stride for the innermost dimension is implicitly 1
// Strides must be a multiple of 16. Also, stride for the innermost dimension is implicitly 1
asm volatile (
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;"
:: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4));

View File

@ -391,7 +391,7 @@ struct TiledMMA : MMA_Atom
} else {
return cute::max(core_size, perm_size);
}
CUTE_GCC_UNREACHABLE;
}

View File

@ -125,6 +125,9 @@ using CUTE_STL_NAMESPACE::invoke_result_t;
using CUTE_STL_NAMESPACE::common_type;
using CUTE_STL_NAMESPACE::common_type_t;
using CUTE_STL_NAMESPACE::remove_pointer;
using CUTE_STL_NAMESPACE::remove_pointer_t;
// <utility>
using CUTE_STL_NAMESPACE::declval;

View File

@ -64,6 +64,10 @@
#endif
#endif
#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3)))
#define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
#endif
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {

View File

@ -80,6 +80,40 @@ struct TagToStrideB<layout::ColumnMajor> {
using tag = layout::ColumnMajor;
};
// For each cutlass::layout *, provides its corresponding cute stride types, 64b by default
// Used by pointer array and grouped gemm
// Maps to modes [M, K, L]
template <>
struct TagToStrideA<layout::RowMajor *> {
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>;
using type = UnderlyingType*;
using tag = layout::RowMajor;
};
// Maps to modes [M, K, L]
template <>
struct TagToStrideA<layout::ColumnMajor *> {
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>;
using type = UnderlyingType*;
using tag = layout::ColumnMajor;
};
// Maps to modes [N, K, L]
template <>
struct TagToStrideB<layout::RowMajor *> {
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>;
using type = UnderlyingType*;
using tag = layout::RowMajor;
};
// Maps to modes [N, K, L]
template <>
struct TagToStrideB<layout::ColumnMajor *> {
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>;
using type = UnderlyingType*;
using tag = layout::ColumnMajor;
};
// Maps to modes [M, N, L]
template <class LayoutTag>
struct TagToStrideC : TagToStrideA<LayoutTag> { };
@ -101,7 +135,7 @@ template<int ModeIndex, class Stride>
constexpr bool
is_major(Stride = {}) {
// Account for stride types with and without batch mode and batch modes with static zero stride
return cute::is_constant<1, decltype(cute::front(cute::get<ModeIndex>(Stride{})))>::value;
return cute::is_constant<1, decltype(cute::front(cute::get<ModeIndex>(cute::remove_pointer_t<Stride>{})))>::value;
}
// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices
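// A minimal usage sketch (assuming this header, cutlass/detail/layout.hpp, is included):
// the pointer layout tags resolve to a pointer to the usual 64-bit cute stride, and
// is_major() sees through the pointer via cute::remove_pointer_t.
using PerGroupStrideA = typename cutlass::detail::TagToStrideA<cutlass::layout::RowMajor *>::type;
static_assert(cute::is_same_v<PerGroupStrideA,
                              cute::Stride<int64_t, cute::Int<1>, int64_t> *>,
              "pointer to per-group stride");
static_assert(cutlass::detail::is_major<1, PerGroupStrideA>(), "row-major A is K-major");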

View File

@ -268,7 +268,7 @@ struct Sm90TmaBuilderImpl {
// Passing void C disables source load + smem allocation
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,ElementD,ElementC_>; // prevents void ref breakages
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,GmemLayoutTagD,GmemLayoutTagC_>;
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
@ -434,8 +434,7 @@ struct CollectiveBuilder<
Schedule,
fusion::LinearCombination<ElementD,ElementCompute,ElementC_,ElementCompute,RoundStyle>,
cute::enable_if_t<cute::is_same_v<Schedule, NoSmemWarpSpecialized> ||
cute::is_same_v<Schedule, NoSmemWarpSpecializedArray> ||
cute::is_same_v<Schedule, NoSmemWarpSpecializedGroup> >> {
cute::is_same_v<Schedule, PtrArrayNoSmemWarpSpecialized> >> {
// Passing void C disables source load
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,

View File

@ -63,6 +63,7 @@ public:
// Type Aliases
//
using EpilogueSchedule = EpilogueSchedule_;
using DispatchPolicy = EpilogueSchedule_;
// derived types of output thread level operator
using ThreadEpilogueOp = ThreadEpilogueOp_;

View File

@ -73,12 +73,10 @@ public:
using ElementScalar = ElementCompute;
using ElementC = typename ThreadEpilogueOp::ElementC;
using StrideC = StrideC_;
using UnderlyingStrideC = cute::remove_pointer_t<StrideC>;
using ElementD = typename ThreadEpilogueOp::ElementD;
using StrideD = StrideD_;
using StridesC = cute::conditional_t<cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>,
StrideC const*, StrideC>;
using StridesD = cute::conditional_t<cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>,
StrideD const*, StrideD>;
using UnderlyingStrideD = cute::remove_pointer_t<StrideD>;
using GmemTiledCopyC = void;
using GmemTiledCopyD = void;
@ -86,10 +84,9 @@ public:
static const int kOutputAlignment = ThreadEpilogueOp::kCount;
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
static_assert(cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup> ||
cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedArray>, "Incompatible epilogue schedule.");
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(cute::is_same_v<EpilogueSchedule, PtrArrayNoSmemWarpSpecialized>, "Incompatible epilogue schedule.");
static_assert(rank(UnderlyingStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
static_assert(rank(UnderlyingStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
struct SharedStorage { };
@ -97,9 +94,9 @@ public:
struct Arguments {
typename ThreadEpilogueOp::Params thread{};
ElementC const** ptr_C = nullptr;
StridesC dC{};
StrideC dC{};
ElementD** ptr_D = nullptr;
StridesD dD{};
StrideD dD{};
};
// Device side epilogue params
@ -140,12 +137,13 @@ public:
CUTLASS_HOST_DEVICE
DefaultEpilogueArray(Params const& params_)
: params(params_), epilogue_op(params_.thread) { }
: params(params_) { }
CUTLASS_DEVICE
bool
is_source_needed() {
return epilogue_op.is_source_needed();
// For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta.
return true;
}
template<
@ -185,10 +183,23 @@ public:
// Slice to get the tile this CTA is responsible for
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
StrideC stride_c;
StrideD stride_d;
if constexpr (cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>) {
stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC[l_coord]);
// If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups.
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups,
// we get the correct alpha/beta values for the current batch/group using group index.
ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord);
if (epilogue_op.is_source_needed() && params.dC == nullptr) {
// Beta value is non-zero while pointer to C is a nullptr
assert(0);
}
UnderlyingStrideC stride_c;
UnderlyingStrideD stride_d;
if constexpr (!cute::is_same_v<UnderlyingStrideC, StrideC>) {
// If grouped gemm
if (epilogue_op.is_source_needed()) {
stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC[l_coord]);
}
stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD[l_coord]);
}
else {
@ -197,7 +208,11 @@ public:
}
// Represent the full output tensor
Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C[l_coord]), make_shape(M,N,mock_L), stride_c); // (m,n,l)
ElementC const* ptr_C_l = nullptr;
if (epilogue_op.is_source_needed()) {
ptr_C_l = params.ptr_C[l_coord];
}
Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l)
Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l)
Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
@ -242,7 +257,6 @@ public:
private:
Params params;
ThreadEpilogueOp epilogue_op;
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -148,12 +148,12 @@ private:
constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{});
using EmptyType = cute::tuple<>;
using SmemCStorage = cute::conditional_t<is_source_supported and (not ReuseSmemC),
using SmemCStorage = cute::conditional_t<is_source_supported and (not ReuseSmemC),
array_aligned<SmemElementC, size(SmemLayoutC{}), SmemAlignmentC>,
EmptyType>;
using SmemDStorage = cute::conditional_t<is_destination_supported,
using SmemDStorage = cute::conditional_t<is_destination_supported,
array_aligned<SmemElementD, size(SmemLayoutD{}), SmemAlignmentD>,
EmptyType>;
@ -189,6 +189,7 @@ public:
struct SharedStorage {
using TensorStorage = TensorStorageImpl;
TensorStorage tensors;
using PipelineStorage = typename LoadPipeline::SharedStorage;
@ -249,12 +250,12 @@ public:
Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M_C,N,L), args.dC));
tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutC{}(_,_,0));
}
typename Params::TMA_D tma_store_d;
if constexpr (is_destination_supported) {
Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD));
tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutD{}(_,_,0));
}
}
return {
FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace),
@ -385,13 +386,13 @@ public:
// Apply epilogue subtile, get matching smem tensor
SmemElementC* ptr_sC = nullptr;
if constexpr (is_source_supported) {
if constexpr (ReuseSmemC) {
ptr_sC = reinterpret_cast<SmemElementC*>(shared_tensors.smem_D().data());
} else {
ptr_sC = shared_tensors.smem_C().data();
}
}
}
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C)
@ -559,7 +560,7 @@ public:
// Vectorized fragment view
constexpr int FragmentSize = DispatchPolicy::FragmentSize;
Tensor tRS_rAcc_frg = recast<Array<ElementAccumulator, FragmentSize>>(tRS_rAcc);
Tensor tRS_rD_frg = recast<Array<SmemElementD , FragmentSize>>(tRS_rD);
Tensor tRS_rD_frg = recast<Array<SmemElementD , FragmentSize>>(tRS_rD);
CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly");
// (t)hread-partition for (s)mem to (r)egister copy (tSR_)

View File

@ -46,8 +46,7 @@ namespace cutlass::epilogue {
//////////////////////////////////////////////////////////////////////////////
struct NoSmemWarpSpecialized {};
struct NoSmemWarpSpecializedArray {};
struct NoSmemWarpSpecializedGroup {};
struct PtrArrayNoSmemWarpSpecialized {};
struct TmaWarpSpecialized {};
struct TmaWarpSpecializedCooperative {};
// DEPRECATED schedules, will be removed in next release

View File

@ -1247,6 +1247,7 @@ struct FusionCallbacks<
};
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <class FusionOpOrCallbacks, class = cute::void_t<>>
struct get_element_aux {
@ -1257,7 +1258,7 @@ template <class FusionOpOrCallbacks>
struct get_element_aux<FusionOpOrCallbacks, cute::void_t<typename FusionOpOrCallbacks::ElementAux>> {
using type = typename FusionOpOrCallbacks::ElementAux;
};
template <class NodeOp, class... ChildOps>
struct get_element_aux<Sm90TreeVisitor<NodeOp, ChildOps...>, cute::void_t<>> {
using type = typename get_element_aux<NodeOp>::type;
@ -1270,7 +1271,7 @@ struct get_element_aux<FusionCallbacks<Ts...>, cute::void_t<typename FusionCallb
public:
using type = typename get_element_aux<Operation>::type;
};
}
} // namespace cutlass::epilogue::fusion::detail
template <class Callbacks>
using get_element_aux_t = typename detail::get_element_aux<Callbacks>::type;

View File

@ -88,43 +88,72 @@ public:
/// Host-constructable parameters structure
struct Params
{
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
ElementCompute const* const* alpha_ptr_array; ///< array of pointers to accumulator scalar per group/batch
ElementCompute const* const* beta_ptr_array; ///< array of pointers to source scalar per group/batch
CUTLASS_HOST_DEVICE
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }
beta_ptr(nullptr),
alpha_ptr_array(nullptr),
beta_ptr_array(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta
):
alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { }
alpha(alpha), beta(beta),
alpha_ptr(nullptr), beta_ptr(nullptr),
alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha
):
alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { }
alpha(alpha), beta(0),
alpha_ptr(nullptr), beta_ptr(nullptr),
alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr
):
alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { }
alpha(0), beta(0),
alpha_ptr(alpha_ptr), beta_ptr(beta_ptr),
alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr
):
alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { }
alpha(0), beta(0),
alpha_ptr(alpha_ptr), beta_ptr(nullptr),
alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute const* const* alpha_ptr_array,
ElementCompute const* const* beta_ptr_array
):
alpha(0), beta(0),
alpha_ptr(nullptr), beta_ptr(nullptr),
alpha_ptr_array(alpha_ptr_array), beta_ptr_array(beta_ptr_array) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute const* const* alpha_ptr_array
):
alpha(0), beta(0),
alpha_ptr(nullptr), beta_ptr(nullptr),
alpha_ptr_array(alpha_ptr_array), beta_ptr_array(nullptr) { }
};
private:
@ -140,9 +169,25 @@ public:
/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombination(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
LinearCombination(Params const &params, int group_idx = 0) {
if (params.alpha_ptr_array != nullptr && params.alpha_ptr_array[group_idx] != nullptr) {
alpha_ = *(params.alpha_ptr_array[group_idx]);
}
else if (params.alpha_ptr != nullptr) {
alpha_ = *params.alpha_ptr;
}
else {
alpha_ = params.alpha;
}
if (params.beta_ptr_array != nullptr && params.beta_ptr_array[group_idx] != nullptr) {
beta_ = *(params.beta_ptr_array[group_idx]);
}
else if (params.beta_ptr != nullptr) {
beta_ = *params.beta_ptr;
}
else {
beta_ = params.beta;
}
}
/// Returns true if source is needed

View File

@ -185,8 +185,7 @@ struct CollectiveBuilder<
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperative>) &&
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>) &&
not detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>()>
> {
static_assert(is_static<TileShape_MNK>::value);
@ -197,8 +196,7 @@ struct CollectiveBuilder<
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperative>);
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n");
@ -515,8 +513,7 @@ struct CollectiveBuilder<
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum>>
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>
> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
@ -534,8 +531,7 @@ struct CollectiveBuilder<
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum>);
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
IsArrayOfPointersGemm,
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

View File

@ -93,8 +93,10 @@ struct CollectiveMma<
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using UnderlyingStrideA = cute::remove_pointer_t<StrideA>;
using ElementB = ElementB_;
using StrideB = StrideB_;
using UnderlyingStrideB = cute::remove_pointer_t<StrideB>;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
@ -149,14 +151,14 @@ struct CollectiveMma<
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(
GmemTiledCopyA{},
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(UnderlyingStrideA{}, int32_t(0)), UnderlyingStrideA{}),
SmemLayoutA{}(_,_,cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(
GmemTiledCopyB{},
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(UnderlyingStrideB{}, int32_t(0)), UnderlyingStrideB{}),
SmemLayoutB{}(_,_,cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
@ -179,16 +181,14 @@ struct CollectiveMma<
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
static constexpr bool IsGroupedGemmKernel = cute::is_base_of_v<KernelGroupTmaWarpSpecializedCooperative, KernelSchedule>;
using StridesA = cute::conditional_t<IsGroupedGemmKernel, StrideA const*, StrideA>;
using StridesB = cute::conditional_t<IsGroupedGemmKernel, StrideB const*, StrideB>;
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<UnderlyingStrideA, StrideA>;
// Host side kernel arguments
struct Arguments {
ElementA const** ptr_A;
StridesA dA;
StrideA dA;
ElementB const** ptr_B;
StridesB dB;
StrideB dB;
};
// Device side kernel params
@ -197,9 +197,9 @@ struct CollectiveMma<
TMA_B tma_load_b;
void* tensormaps;
InternalElementA const** ptr_A;
StridesA dA;
StrideA dA;
InternalElementB const** ptr_B;
StridesB dB;
StrideB dB;
};
//
@ -212,30 +212,36 @@ struct CollectiveMma<
ProblemShape problem_shapes,
Arguments const& args,
void* workspace) {
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(0), 1);
auto [M,N,K,L] = problem_shape_MNKL;
// These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc.
// These will be replaced with correct values before the initial tma load.
auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1));
auto init_M = get<0>(init_shape);
auto init_N = get<1>(init_shape);
auto init_K = get<2>(init_shape);
// Batches/Groups are managed by using appropriate pointers to input matrices
const uint32_t mock_L = 1;
// These tensor pointers are only used to create tensormap/tma desc.
// This address to the tensor will be replaced with correct address before the initial tma load
InternalElementA const* ptr_A_first_batch = reinterpret_cast<InternalElementA const*>(args.ptr_A);
InternalElementB const* ptr_B_first_batch = reinterpret_cast<InternalElementB const*>(args.ptr_B);
cudaError_t cuda_error = cudaGetLastError(); // clear previous error
StrideA stride_a;
StrideB stride_b;
UnderlyingStrideA stride_a;
UnderlyingStrideB stride_b;
if constexpr (IsGroupedGemmKernel) {
// Strides for Grouped Gemm will be replaced prior to the first access regardless
stride_a = StrideA{};
stride_b = StrideB{};
// Strides for Grouped Gemm will be replaced prior to the first access regardless.
stride_a = UnderlyingStrideA{};
stride_b = UnderlyingStrideB{};
}
else {
// Tensor shapes for Ptr-Array are initialized correctly only here.
auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0);
init_M = get<0>(problem_shape_MNK);
init_N = get<1>(problem_shape_MNK);
init_K = get<2>(problem_shape_MNK);
stride_a = args.dA;
stride_b = args.dB;
}
Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(M,K,mock_L), stride_a));
Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(N,K,mock_L), stride_b));
Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,mock_L), stride_a));
Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,mock_L), stride_b));
TMA_A tma_load_a = make_tma_copy(
GmemTiledCopyA{},
tensor_a,
@ -287,12 +293,14 @@ struct CollectiveMma<
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
bool implementable = true;
// Check alignment for all problem sizes
for (int i = 0; i < problem_shapes.groups(); i++) {
auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1);
auto [M,N,K,L] = problem_shape_MNKL;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
if (problem_shapes.is_host_problem_shape_available()) {
// Check alignment for all problem sizes
for (int i = 0; i < problem_shapes.groups(); i++) {
auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1);
auto [M,N,K,L] = problem_shape_MNKL;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), UnderlyingStrideA{});
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), UnderlyingStrideB{});
}
}
if (!implementable) {
@ -676,6 +684,14 @@ struct CollectiveMma<
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b,
prob_shape_B, prob_stride_B);
// Convert strides to byte strides
for (uint64_t& stride : prob_stride_A) {
stride = (stride * sizeof_bits_v<InternalElementA>) / 8;
}
for (uint64_t& stride : prob_stride_B) {
stride = (stride * sizeof_bits_v<InternalElementB>) / 8;
}
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_A,
prob_shape_A,
prob_stride_A);

View File

@ -53,8 +53,7 @@ struct KernelTma { };
struct KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperative { };
struct KernelArrayTmaWarpSpecializedCooperative { };
struct KernelGroupTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedCooperative { };
//////////////////////////////////////////////////////////////////////////////
@ -67,8 +66,7 @@ struct KernelGroupTmaWarpSpecializedCooperative { };
struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { };
struct KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelArrayTmaWarpSpecializedCooperative { };
struct KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum : KernelGroupTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelPtrArrayTmaWarpSpecializedCooperative { };
// Policies to opt into mixed type GEMMs
struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { };
@ -233,7 +231,7 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelGroupTmaWarpSpecializedCooperative
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative
>
struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
constexpr static int Stages = Stages_;
@ -241,8 +239,7 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
using ArchTag = arch::Sm90;
using Schedule = KernelSchedule;
static_assert(
cute::is_base_of_v<KernelArrayTmaWarpSpecializedCooperative, KernelSchedule> ||
cute::is_base_of_v<KernelGroupTmaWarpSpecializedCooperative, KernelSchedule>,
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelSchedule>,
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies");
};

View File

@ -71,6 +71,12 @@ struct GroupProblemShape {
get_host_problem_shape(int32_t group_idx) const {
return host_problem_shapes[group_idx];
}
CUTLASS_HOST_DEVICE
bool
is_host_problem_shape_available() {
return host_problem_shapes != nullptr;
}
};
template <class ProblemShape_>
@ -104,6 +110,12 @@ public:
get_host_problem_shape(int32_t /* unused */ = 0) const {
return problem_shape_;
}
CUTLASS_HOST_DEVICE
bool
is_host_problem_shape_available() {
return true;
}
private:
UnderlyingProblemShape problem_shape_{};
};

View File

@ -62,8 +62,7 @@ class GemmUniversal<
CollectiveMainloop_,
CollectiveEpilogue_,
TileScheduler_,
cute::enable_if_t<cute::is_base_of_v<KernelArrayTmaWarpSpecializedCooperative, typename CollectiveMainloop_::DispatchPolicy::Schedule> ||
cute::is_base_of_v<KernelGroupTmaWarpSpecializedCooperative, typename CollectiveMainloop_::DispatchPolicy::Schedule>>
cute::enable_if_t<cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, typename CollectiveMainloop_::DispatchPolicy::Schedule>>
>
{
public:
@ -80,7 +79,9 @@ public:
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using UnderlyingStrideA = typename CollectiveMainloop::UnderlyingStrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using UnderlyingStrideB = typename CollectiveMainloop::UnderlyingStrideB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using Schedule = typename DispatchPolicy::Schedule;
@ -93,8 +94,10 @@ public:
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using UnderlyingStrideC = typename CollectiveEpilogue::UnderlyingStrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using UnderlyingStrideD = typename CollectiveEpilogue::UnderlyingStrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
@ -102,7 +105,7 @@ public:
static_assert(cute::is_void_v<TileScheduler_>,
"Ptr-Array Cooperative and Grouped Gemm Cooperative kernel only supports the default scheduler.");
static constexpr bool IsGroupedGemmKernel = cute::is_base_of_v<KernelGroupTmaWarpSpecializedCooperative, Schedule>;
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<UnderlyingStrideA, StrideA>;
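// Sketch of the detection logic (types shown are illustrative, not the exact aliases): for Grouped
// GEMM the mainloop's StrideA is a per-group pointer such as UnderlyingStrideA const*, so it differs
// from UnderlyingStrideA and the flag is true; for Ptr-Array GEMM a single stride is shared by every
// batch, the two aliases coincide, and the flag is false.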
using TileScheduler = cute::conditional_t<IsGroupedGemmKernel,
typename detail::TileSchedulerSelector<
@ -204,7 +207,7 @@ public:
void* scheduler_workspace = workspace_ptr;
workspace_offset += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, problem_shapes.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups);
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* epilogue_workspace = workspace_ptr + workspace_offset;
@ -244,14 +247,11 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = true;
if constexpr (cute::is_base_of_v<KernelArrayTmaWarpSpecializedCooperative, Schedule>) {
implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4);
} else if constexpr (IsGroupedGemmKernel) {
if constexpr (IsGroupedGemmKernel) {
// Group GEMM currently only supports rank-3 problem shapes
implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3);
}
else {
implementable = false;
} else {
implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4);
}
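// Summary of the two accepted configurations (shapes are examples, not additional API):
//   Grouped GEMM : mode == GemmUniversalMode::kGrouped, rank-3 per-group shapes {M, N, K}
//   Ptr-Array    : mode == GemmUniversalMode::kArray,   rank-4 problem shape  {M, N, K, L}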
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n");
@ -269,7 +269,7 @@ public:
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
workspace_size += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
@ -297,9 +297,9 @@ public:
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
status = TileScheduler::template initialize_workspace<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_offset += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess) {
return status;
@ -350,23 +350,20 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads.");
static_assert(size<0>(TileShape{}) >= 128,
"Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(UnderlyingStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(UnderlyingStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(UnderlyingStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(UnderlyingStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */
enum class WarpGroupRole {
@ -441,8 +438,6 @@ public:
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// Purpose of maintaining this pipeline state is to make sure TMA loads have finished before doing descriptor updates
typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
@ -554,7 +549,8 @@ public:
shared_storage.tensors.mainloop
);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state.advance(work_k_tile_count);
// Wait for the last TMA stage to complete loading, before issuing tensormap updates
mainloop_pipe_producer_state.advance(work_k_tile_count - 1);
// Signal for the epilogue load warp to begin
if (do_load_order_arrive) {
@ -570,8 +566,10 @@ public:
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(next_batch), Int<1>{});
}
// Wait for the last TMA stage to complete loading, before issuing tensormap updates
mainloop_pipe_tma_consumer_state.advance(work_k_tile_count-1);
// Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates
// Since this state is waiting for loads to finish, it must start in the inverted phase.
typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state =
{mainloop_pipe_producer_state.index(), !mainloop_pipe_producer_state.phase(), mainloop_pipe_producer_state.count()};
mainloop_pipeline.consumer_wait(mainloop_pipe_tma_consumer_state);
collective_mainloop.tensormaps_perform_update(
shared_storage.tensormaps.mainloop,
@ -585,13 +583,9 @@ public:
// Entire warp must do this (i.e., it is warp-aligned)
collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps);
curr_batch = next_batch;
// Advance the TMA consumer state for the last remaining stage that was being waited for above
mainloop_pipe_tma_consumer_state.advance(1);
}
else if (work_tile_info.is_valid()) { // case where batch/group didn't change between tiles
// Advance the TMA consumer state for all the stages to be in sync
mainloop_pipe_tma_consumer_state.advance(work_k_tile_count);
}
// Advance the producer state for the last remaining stage that was being waited for above
mainloop_pipe_producer_state.advance(1);
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
@ -720,6 +714,7 @@ public:
);
}
} // Consumer Warp Groups End
#endif
}
private:

View File

@ -211,13 +211,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
@ -311,6 +308,7 @@ public:
thread_idx,
smem_buf
);
#endif
}
};

View File

@ -219,13 +219,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
enum class WarpGroupRole {
Producer = 0,
@ -435,6 +432,7 @@ public:
epi_store_pipe_producer_state_next
);
}
#endif
}
};

View File

@ -298,13 +298,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads.");
@ -610,6 +607,7 @@ public:
);
}
} // Consumer Warp Groups End
#endif
}
private:

View File

@ -296,13 +296,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
@ -612,6 +609,7 @@ public:
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
} // Consumer Warp Groups End
#endif
}
};

View File

@ -223,13 +223,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
enum class WarpGroupRole {
Producer = 0,
@ -409,6 +406,7 @@ public:
shared_storage.tensors.epilogue
);
}
#endif
}
};

View File

@ -250,13 +250,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
@ -493,6 +490,7 @@ public:
);
}
} // Consumer Warp Groups End
#endif
}
private:

View File

@ -257,13 +257,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
@ -509,6 +506,7 @@ public:
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
} // Consumer Warp Groups End
#endif
}
};

View File

@ -55,7 +55,7 @@ private:
// Tracking current group, its starting linear idx and total tiles
struct GroupInfo {
uint64_t group = 0;
int group_idx = 0;
uint64_t start_linear_idx = 0;
uint64_t total_tiles = 0;
} current_group_info_;
@ -115,7 +115,7 @@ public:
GroupProblemShape problem_shapes,
TileShape tile_shape,
ClusterShape cluster_shape,
[[maybe_unused]] KernelHardwareInfo const& hw_info,
KernelHardwareInfo const& hw_info,
Arguments const& arguments,
[[maybe_unused]] void* workspace=nullptr,
[[maybe_unused]] const uint32_t epilogue_subtile = 1) {
@ -126,14 +126,16 @@ public:
dim3 problem_blocks = get_tiled_cta_shape_mnl(
problem_shapes.groups(),
reinterpret_cast<ProblemShape const*>(problem_shapes.host_problem_shapes),
problem_shapes,
hw_info,
tile_shape, cluster_shape);
Params params;
params.initialize(
problem_blocks,
problem_shapes.groups(),
reinterpret_cast<ProblemShape*>(problem_shapes.problem_shapes),
problem_shapes.problem_shapes,
problem_shapes.host_problem_shapes,
to_gemm_coord(tile_shape),
to_gemm_coord(cluster_shape),
hw_info,
@ -144,6 +146,64 @@ public:
return params;
}
// Given the inputs, computes the physical grid we should launch.
template<class TileShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(
GroupProblemShape problem_shapes,
TileShape tile_shape,
ClusterShape cluster_shape,
KernelHardwareInfo hw_info,
Arguments arguments,
bool truncate_by_problem_size=true) {
dim3 problem_blocks = get_tiled_cta_shape_mnl(
problem_shapes.groups(),
problem_shapes,
hw_info,
tile_shape, cluster_shape);
return Params::get_grid_shape(
problem_blocks,
to_gemm_coord(cluster_shape),
hw_info,
arguments.max_swizzle_size,
arguments.raster_order,
/* truncate_by_problem_size = */true
);
}
// Given the inputs, computes the total number of output blocks this problem will compute over
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
template<class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_tiled_cta_shape_mnl(int groups, GroupProblemShape problem_shapes, KernelHardwareInfo hw_info, BlockShape cta_shape, ClusterShape cluster_shape) {
uint32_t total_ctas = 0;
uint32_t cta_in_N_dim = 1; // We linearize the blocks across all the problems here
// If host problem shapes are not provided.
if (!problem_shapes.is_host_problem_shape_available()) {
total_ctas = hw_info.sm_count;
}
// If host problem shapes are provided, make a better-informed decision about whether a smaller grid can be launched.
else {
for (int group = 0; group < groups; group++) {
auto ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes.get_host_problem_shape(group)), cute::shape<0>(cta_shape)));
auto ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes.get_host_problem_shape(group)), cute::shape<1>(cta_shape)));
auto problem_blocks_m = round_up(ctas_along_m, cute::get<0>(cluster_shape));
auto problem_blocks_n = round_up(ctas_along_n, cute::get<1>(cluster_shape));
total_ctas += problem_blocks_m * problem_blocks_n;
}
}
return Params::get_tiled_cta_shape_mnl(
to_gemm_coord(cluster_shape),
total_ctas, cta_in_N_dim
);
}
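// Example of the fallback behaviour (tile counts are hypothetical): with host shapes available and
// per-group tile counts of 12, 20 and 8, total_ctas == 40; without host shapes, total_ctas falls
// back to hw_info.sm_count and the exact per-group tile counts are resolved on the device.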
CUTLASS_HOST_DEVICE
static bool
can_implement(Arguments const& args) {
@ -156,7 +216,7 @@ public:
// MSVC requires protecting use of CUDA-specific nonstandard syntax,
// like blockIdx and gridDim, with __CUDA_ARCH__.
#if defined(__CUDA_ARCH__)
if (params_.raster_order_ == RasterOrder::AlongN) {
if (scheduler_params.raster_order_ == RasterOrder::AlongN) {
current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x);
}
else {
@ -165,9 +225,19 @@ public:
total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z);
auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(params_.problem_shapes_[0]), params_.cta_shape_.m()));
auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(params_.problem_shapes_[0]), params_.cta_shape_.n()));
current_group_info_.total_tiles = cta_m * cta_n;
uint64_t ctas_along_m, ctas_along_n;
if (is_tuple<decltype(cute::shape<0>(params_.problem_shapes_[0]))>::value ||
is_tuple<decltype(cute::shape<1>(params_.problem_shapes_[0]))>::value) {
ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(params_.problem_shapes_[0]), scheduler_params.cta_shape_.m()));
ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(params_.problem_shapes_[0]), scheduler_params.cta_shape_.n()));
}
else {
ctas_along_m = scheduler_params.divmod_cta_shape_m_.divide(cute::shape<0>(params_.problem_shapes_[0]) + scheduler_params.divmod_cta_shape_m_.divisor - 1);
ctas_along_n = scheduler_params.divmod_cta_shape_n_.divide(cute::shape<1>(params_.problem_shapes_[0]) + scheduler_params.divmod_cta_shape_n_.divisor - 1);
}
auto problem_blocks_m = round_up(ctas_along_m, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.m());
auto problem_blocks_n = round_up(ctas_along_n, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.n());
current_group_info_.total_tiles = problem_blocks_m * problem_blocks_n;
#else
CUTLASS_ASSERT(false && "This line should never be reached");
#endif
@ -182,24 +252,22 @@ public:
CUTLASS_DEVICE
WorkTileInfo
get_current_work_for_linear_idx(uint64_t linear_idx) {
if (linear_idx >= scheduler_params.blocks_per_problem_) {
if (scheduler_params.pre_processed_problem_shapes && linear_idx >= scheduler_params.blocks_across_problem_) {
return WorkTileInfo::invalid_work_tile();
}
uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(linear_idx);
auto [work_idx_m, work_idx_n, new_group_info, valid_tile] = get_work_idx_m_and_n(blk_per_grid_dim,
current_group_info_,
scheduler_params.groups_,
scheduler_params.problem_shapes_,
scheduler_params.cta_shape_,
scheduler_params.divmod_cluster_shape_major_,
scheduler_params.divmod_cluster_shape_minor_,
scheduler_params.log_swizzle_size_,
scheduler_params.raster_order_);
current_group_info_ = new_group_info;
return {work_idx_m, work_idx_n, static_cast<int>(current_group_info_.group), valid_tile};
return get_work_idx_m_and_n(linear_idx,
current_group_info_,
scheduler_params.groups_,
scheduler_params.problem_shapes_,
scheduler_params.cta_shape_,
scheduler_params.cluster_shape_,
scheduler_params.divmod_cluster_shape_major_,
scheduler_params.divmod_cluster_shape_minor_,
scheduler_params.divmod_cta_shape_m_,
scheduler_params.divmod_cta_shape_n_,
scheduler_params.log_swizzle_size_,
scheduler_params.raster_order_);
}
CUTLASS_DEVICE
@ -208,34 +276,62 @@ public:
current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
}
// get work_idx_m, work_idx_n from blk_per_grid_dim while applying swizzle
// get work_idx_m, work_idx_n from linear_idx while applying swizzle
static CUTLASS_DEVICE
cute::tuple<int32_t, int32_t, struct GroupInfo, bool>
WorkTileInfo
get_work_idx_m_and_n(
uint64_t blk_per_grid_dim,
struct GroupInfo group_info,
uint64_t linear_idx,
struct GroupInfo& group_info,
int32_t total_problem_groups,
ProblemShape* problem_shapes,
GemmCoord cta_shape,
GemmCoord cluster_shape,
FastDivmodU64Pow2 const& divmod_cluster_shape_major,
FastDivmodU64Pow2 const& divmod_cluster_shape_minor,
FastDivmodU64 const& divmod_cta_shape_m,
FastDivmodU64 const& divmod_cta_shape_n,
int32_t log_swizzle_size,
RasterOrder raster_order) {
bool valid_tile = true;
int cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group]), cta_shape.m()));
int cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group]), cta_shape.n()));
uint64_t ctas_along_m, ctas_along_n;
if (is_tuple<decltype(cute::shape<0>(problem_shapes[group_info.group_idx]))>::value ||
is_tuple<decltype(cute::shape<1>(problem_shapes[group_info.group_idx]))>::value) {
ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group_idx]), cta_shape.m()));
ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group_idx]), cta_shape.n()));
}
else {
ctas_along_m = divmod_cta_shape_m.divide(cute::shape<0>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_m.divisor - 1);
ctas_along_n = divmod_cta_shape_n.divide(cute::shape<1>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_n.divisor - 1);
}
auto problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m());
auto problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n());
group_info.total_tiles = problem_blocks_m * problem_blocks_n;
while (group_info.start_linear_idx + group_info.total_tiles <= linear_idx) {
group_info.group_idx++;
if (group_info.group_idx >= total_problem_groups)
return WorkTileInfo::invalid_work_tile();
while (group_info.start_linear_idx + group_info.total_tiles <= blk_per_grid_dim) {
group_info.group++;
group_info.start_linear_idx += group_info.total_tiles;
cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group]), cta_shape.m()));
cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group]), cta_shape.n()));
group_info.total_tiles = cta_m * cta_n;
if (is_tuple<decltype(cute::shape<0>(problem_shapes[group_info.group_idx]))>::value ||
is_tuple<decltype(cute::shape<1>(problem_shapes[group_info.group_idx]))>::value) {
ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group_idx]), cta_shape.m()));
ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group_idx]), cta_shape.n()));
}
else {
ctas_along_m = divmod_cta_shape_m.divide(cute::shape<0>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_m.divisor - 1);
ctas_along_n = divmod_cta_shape_n.divide(cute::shape<1>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_n.divisor - 1);
}
problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m());
problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n());
group_info.total_tiles = problem_blocks_m * problem_blocks_n;
}
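// Worked example (hypothetical per-group totals {12, 20, 8}): for linear_idx == 25 the loop steps
// past group 0 (start 0, 12 tiles) and stops at group_idx == 1 with start_linear_idx == 12, so the
// group-local index consumed below is 25 - 12 = 13.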
uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0;
divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim - group_info.start_linear_idx);
uint64_t blk_per_grid_dim = divmod_cluster_shape_minor.divide(linear_idx - group_info.start_linear_idx);
divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim);
auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
if (raster_order == RasterOrder::AlongN) {
@ -252,8 +348,13 @@ public:
offset = cluster_id & ((1 << log_swizzle_size) - 1);
extra = cluster_id >> log_swizzle_size;
uint64_t curr_group_cluster_blk_major, remainder;
divmod_cluster_shape_major(curr_group_cluster_blk_major, remainder, cta_m);
uint64_t curr_group_cluster_blk_major;
if (raster_order == RasterOrder::AlongN) {
curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(problem_blocks_n);
}
else {
curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(problem_blocks_m);
}
cluster_idx_minor_div_swizzle = extra / curr_group_cluster_blk_major;
cluster_idx_major = extra % curr_group_cluster_blk_major;
@ -265,61 +366,14 @@ public:
cluster_major_offset);
if (raster_order == RasterOrder::AlongN) {
return {minor_work_idx, major_work_idx, group_info, valid_tile};
return {minor_work_idx, major_work_idx, group_info.group_idx, valid_tile};
}
else {
return {major_work_idx, minor_work_idx, group_info, valid_tile};
return {major_work_idx, minor_work_idx, group_info.group_idx, valid_tile};
}
}
// Given the inputs, computes the total number of output blocks this problem will compute over
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
template<class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_tiled_cta_shape_mnl(int groups, ProblemShape const* problem_shapes, BlockShape cta_shape, ClusterShape cluster_shape) {
uint32_t total_ctas = 0;
uint32_t cta_in_N_dim = 1; // We linearize the blocks across all the problems here
for (int group = 0; group < groups; group++) {
auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group]), cute::shape<0>(cta_shape)));
auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group]), cute::shape<1>(cta_shape)));
total_ctas += cta_m * cta_n;
}
return Params::get_tiled_cta_shape_mnl(
to_gemm_coord(cluster_shape),
total_ctas, cta_in_N_dim
);
}
// Given the inputs, computes the physical grid we should launch.
template<class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(
GroupProblemShape problem_shapes,
BlockShape cta_shape,
ClusterShape cluster_shape,
KernelHardwareInfo hw_info,
Arguments arguments,
bool truncate_by_problem_size=true) {
dim3 problem_blocks = get_tiled_cta_shape_mnl(
problem_shapes.groups(),
reinterpret_cast<ProblemShape const*>(problem_shapes.host_problem_shapes),
cta_shape, cluster_shape);
return Params::get_grid_shape(
problem_blocks,
to_gemm_coord(cluster_shape),
hw_info,
arguments.max_swizzle_size,
arguments.raster_order,
/* truncate_by_problem_size = */true
);
}
// Returns whether the block assigned this work should compute the epilogue for the corresponding
// output tile. For the basic tile scheduler, this is always true.
CUTLASS_HOST_DEVICE

View File

@ -1273,15 +1273,18 @@ struct PersistentTileSchedulerSm90GroupParams {
FastDivmodU64Pow2 divmod_cluster_shape_major_{};
FastDivmodU64Pow2 divmod_cluster_shape_minor_{};
FastDivmodU64 divmod_batch_{};
FastDivmodU64 divmod_cta_shape_m_{};
FastDivmodU64 divmod_cta_shape_n_{};
uint64_t blocks_per_problem_ = 0;
uint64_t blocks_across_problem_ = 0;
bool pre_processed_problem_shapes = true;
int32_t log_swizzle_size_ = 0;
RasterOrder raster_order_ = RasterOrder::AlongN;
int32_t groups_ = 0;
ProblemShape* problem_shapes_ = nullptr;
GemmCoord cta_shape_;
GemmCoord cluster_shape_;
// Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions.
// This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1,
@ -1291,6 +1294,7 @@ struct PersistentTileSchedulerSm90GroupParams {
dim3 problem_blocks,
int32_t groups,
ProblemShape* problem_shapes,
ProblemShape const* host_problem_shapes,
GemmCoord cta_shape,
GemmCoord cluster_shape,
KernelHardwareInfo const& hw_info,
@ -1317,11 +1321,12 @@ struct PersistentTileSchedulerSm90GroupParams {
groups_ = groups;
problem_shapes_ = problem_shapes;
cta_shape_ = cta_shape;
cluster_shape_ = cluster_shape;
blocks_per_problem_ = problem_blocks_m * problem_blocks_n * problem_blocks.z;
blocks_across_problem_ = problem_blocks.x * problem_blocks.y * problem_blocks.z;
pre_processed_problem_shapes = (host_problem_shapes == nullptr) ? false : true;
log_swizzle_size_ = log_swizzle_size;
raster_order_ = raster_order;
divmod_batch_ = FastDivmodU64(problem_blocks_m * problem_blocks_n);
if (raster_order == RasterOrder::AlongN) {
divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.n());
@ -1331,6 +1336,9 @@ struct PersistentTileSchedulerSm90GroupParams {
divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.m());
divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.n());
}
divmod_cta_shape_m_ = FastDivmodU64(cta_shape_.m());
divmod_cta_shape_n_ = FastDivmodU64(cta_shape_.n());
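// Editorial note: the precomputed divisors let device code evaluate ceil(extent / cta_dim) without a
// hardware divide via divmod_cta_shape_m_.divide(extent_m + divmod_cta_shape_m_.divisor - 1);
// e.g. extent_m == 300 with a 128-wide CTA gives (300 + 127) / 128 == 3.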
}
// Version of get_tiled_cta_shape_mnl that takes in as input the number of CTAs in the M and N dimensions.
@ -1344,8 +1352,8 @@ struct PersistentTileSchedulerSm90GroupParams {
auto problem_blocks_n = ((cta_n + cluster_shape.n() - 1) / cluster_shape.n()) * cluster_shape.n();
return {
static_cast<uint32_t>(problem_blocks_m),
static_cast<uint32_t>(problem_blocks_n),
static_cast<uint32_t>(cta_m),
static_cast<uint32_t>(cta_n),
static_cast<uint32_t>(1) // Only a single batch per group is currently supported
};
}

80 include/cutlass/version.h Normal file
View File

@ -0,0 +1,80 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cstdint>
#include <string>
#define CUTLASS_MAJOR 3
#define CUTLASS_MINOR 4
#define CUTLASS_PATCH 1
#ifdef CUTLASS_VERSIONS_GENERATED
#include "cutlass/version_extended.h"
#else
#define CUTLASS_BUILD 0
#define CUTLASS_REVISION ""
#endif
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
namespace cutlass {
inline constexpr uint32_t getVersion() {
return CUTLASS_VERSION;
}
inline constexpr uint32_t getVersionMajor() {
return CUTLASS_MAJOR;
}
inline constexpr uint32_t getVersionMinor() {
return CUTLASS_MINOR;
}
inline constexpr uint32_t getVersionPatch() {
return CUTLASS_PATCH;
}
inline constexpr uint32_t getVersionBuild() {
return CUTLASS_BUILD + 0;
}
inline std::string getVersionString() {
std::string version = "@CUTLASS_VERSION@";
if (getVersionBuild()) {
version += "." + std::to_string(getVersionBuild());
}
return version;
}
inline std::string getGitRevision() {
return "@CUTLASS_REVISION@";
}
} // namespace cutlass
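// Usage sketch (not part of version.h itself): client code can gate API usage on the release it is
// building against. For 3.4.1, CUTLASS_VERSION evaluates to 3*100 + 4*10 + 1 == 341.
//
//   #include "cutlass/version.h"
//   #if defined(CUTLASS_VERSION) && (CUTLASS_VERSION >= 341)
//     // use interfaces introduced in 3.4.1
//   #endif
//   static_assert(cutlass::getVersion() >= 341, "CUTLASS 3.4.1 or newer is required");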

View File

@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"
[project]
name = "nvidia-cutlass"
version = "3.4.0.0"
version = "3.4.1.0"
description = "CUTLASS"
readme = "README.md"
requires-python = ">=3.8"
license = {file = "LICENSE.txt"}
license = {text = "BSD-3-Clause"}
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",

View File

@ -40,7 +40,7 @@ import cutlass_library
def _cuda_install_path_from_nvcc() -> str:
import subprocess
# Attempt to detect CUDA_INSTALL_PATH based on location of NVCC
result = subprocess.run(['which', 'nvcc'], capture_output=True)
result = subprocess.run(['/usr/bin/which', 'nvcc'], capture_output=True)
if result.returncode != 0:
raise Exception(f'Unable to find nvcc via `which` utility.')
@ -121,7 +121,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry
this.__version__ = '3.4.0'
this.__version__ = '3.4.1'
from cutlass.backend import create_memory_pool
from cutlass.emit.pytorch import pytorch
@ -169,7 +169,7 @@ def initialize_cuda_context():
raise Exception("No CUDA devices found")
device_id = 0
this._device_id = device_id
this._device_id = int(device_id)
def device_id() -> int:

View File

@ -213,8 +213,12 @@ def get_mainloop_arguments_3x(
return _MainloopArgumentsTma
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, default_epilogue):
if not default_epilogue and hasattr(epilogue_functor, "epilogue_type_evt"):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type_evt
else:
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
if hasattr(epilogue_functor, "visitor"):
class _EpilogueArguments(ctypes.Structure):
_fields_ = [

View File

@ -157,19 +157,41 @@ class LinearCombination(EpilogueFunctorBase):
c_element_epilogue = dtype2ctype[self.element_epilogue]
element_epilogue = self.element_epilogue
class _EpilogueOutputOpParams(ctypes.Structure):
class _EpilogueOutputOpParamsEVT(ctypes.Structure):
"""
Epilogue params when using the EVT form of the default linear combination, which
does not currently use {alpha,beta}_ptr_array
"""
_fields_ = [
("alpha", c_element_epilogue),
("beta", c_element_epilogue),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p)
("beta_ptr", ctypes.c_void_p),
]
def __init__(self, alpha, beta, *args) -> None:
self.alpha = to_ctype_value(alpha, element_epilogue)
self.beta = to_ctype_value(beta, element_epilogue)
class _EpilogueOutputOpParams(ctypes.Structure):
_fields_ = [
("alpha", c_element_epilogue),
("beta", c_element_epilogue),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p),
("alpha_ptr_array", ctypes.c_void_p),
("beta_ptr_array", ctypes.c_void_p),
]
def __init__(self, alpha, beta, *args) -> None:
self.alpha = to_ctype_value(alpha, element_epilogue)
self.beta = to_ctype_value(beta, element_epilogue)
def to_evt_params(self) -> _EpilogueOutputOpParamsEVT:
return _EpilogueOutputOpParamsEVT(self.alpha, self.beta)
self.epilogue_type = _EpilogueOutputOpParams
self.epilogue_type_evt = _EpilogueOutputOpParamsEVT
def emit(self):
return super().emit(self.tag, self.template_arguments)

View File

@ -241,10 +241,10 @@ class EVTFrontendBase:
:param name: the name of the graph
"""
drawer = EVTGraphDrawer(self.dag_ir, name)
if drawer.dot_available:
try:
for name, graph in drawer.get_dot_graph():
graph.write_svg(f"./{name}.svg")
else:
except:
raise RuntimeError(
"'dot' is not found in path. GraphDrawer is disabled. "
"Please install it with 'sudo apt-get install graphviz'."

View File

@ -61,22 +61,6 @@ class EVTGraphDrawer:
self._dot_graphs = {}
self._dot_graphs[name] = self._to_dot(graph, name)
self.dot_available = self._check_dot_availability()
def _check_dot_availability(self):
"""
Check if graphviz is installed
"""
try:
# Run the 'dot' command and capture its output
result = subprocess.run(
["dot", "-V"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# Check if the command was successful and the output contains version information
if result.returncode == 0 and "dot - graphviz" in result.stderr:
return True
except FileNotFoundError:
pass
return False
def _get_node_style(self, node):
template = {

View File

@ -325,7 +325,7 @@ class GemmArguments2x(ArgumentBase):
def initialize(self):
launch_config = self.operation.rt_module.plan(self)
# Get the host and evice workspace
# Get the host and device workspace
device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
if device_workspace_size > 0:
@ -512,6 +512,18 @@ class GemmArguments3x(GemmArguments2x):
super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
def get_arguments(self):
mainloop_args = get_mainloop_arguments_3x(
self.operation.tile_description.kernel_schedule,
self.operation.A.element,
self.operation.B.element,
self.operation.A.alignment,
self.operation.B.alignment
)
scheduler_args = get_tile_scheduler_arguments_3x(self.operation.tile_description.tile_scheduler)
uses_default_epilogue = self.operation.rt_module.uses_default_epilogue()
argument_type, epilogue_args, epilogue_type, hw_info = get_gemm_arguments_3x(
mainloop_args, self.operation.epilogue_functor, scheduler_args, uses_default_epilogue)
problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count)
if self.batch_count > 1:
@ -539,9 +551,12 @@ class GemmArguments3x(GemmArguments2x):
)
# Set of mainloop arguments needed for this kernel
mainloop = self.operation.rt_module.mainloop_args.from_generic_mainloop_args(generic_args)
mainloop = mainloop_args.from_generic_mainloop_args(generic_args)
epilogue = self.operation.rt_module.epilogue_args(
if not uses_default_epilogue and hasattr(self.output_op, "to_evt_params"):
self.output_op = self.output_op.to_evt_params()
epilogue = epilogue_args(
self.output_op,
int(self.ptr_C),
stride_C,
@ -550,15 +565,15 @@ class GemmArguments3x(GemmArguments2x):
)
# Set hardware info
hw_info = self.operation.rt_module.hw_info(0, device_sm_count())
hw_info_ = hw_info(0, device_sm_count())
self.arguments = self.operation.argument_type(
self.arguments = argument_type(
int(self.gemm_mode),
problem_size_,
mainloop,
epilogue,
hw_info,
self.operation.rt_module.scheduler_args
hw_info_,
scheduler_args
)
return self.arguments
@ -1119,6 +1134,10 @@ extern "C" {
using GemmType = ${operation_name}_base;
bool ${operation_name}_uses_default_epilogue() {
return std::is_same_v<GemmType::CollectiveEpilogue::DispatchPolicy, cutlass::gemm::EpilogueDefault>;
}
// Get the workspace size
uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) {
return GemmType::get_workspace_size(*argument);
@ -1163,19 +1182,10 @@ extern "C" {
"get_grid_shape": dim3_,
"get_block_shape": dim3_,
"get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64,
"get_kernel_workspace_size": ctypes.c_uint64
"get_kernel_workspace_size": ctypes.c_uint64,
"uses_default_epilogue": ctypes.c_bool,
}
self.emitter = EmitGemmUniversalInstance3x("_type")
self.mainloop_args = get_mainloop_arguments_3x(
operation.tile_description.kernel_schedule,
operation.A.element,
operation.B.element,
operation.A.alignment,
operation.B.alignment
)
self.scheduler_args = get_tile_scheduler_arguments_3x(operation.tile_description.tile_scheduler)
self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x(
self.mainloop_args, operation.epilogue_functor, self.scheduler_args)
def get_device_workspace_size(self, arguments: GemmArguments3x):
return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()))

View File

@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='cutlass_library',
version='3.4.0',
version='3.4.1',
description='CUTLASS library generation scripts',
packages=['cutlass_library']
)

View File

@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='pycute',
version='3.4.0',
version='3.4.1',
description='Python implementation of CuTe',
packages=['pycute'],
)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@ -56,9 +56,6 @@
#include "gemm_testbed_3x_evt.hpp"
#include "sm90_evt_operations.hpp"
#define CUTLASS_ARCH_MMA_SM90_SUPPORTED
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
@ -132,7 +129,7 @@ bool testEVTAuxStoreWithoutD() {
D_block.reset(m * n);
aux_store_D_block.reset(m * n);
Gemm gemm_op_base;
auto stride_A = cutlass::make_cute_packed_stride(
typename GemmKernel::StrideA{}, cute::make_shape(m, k, cute::Int<1>{}));
auto stride_B = cutlass::make_cute_packed_stride(
@ -141,7 +138,7 @@ bool testEVTAuxStoreWithoutD() {
typename GemmKernel::StrideC{}, cute::make_shape(m, n, cute::Int<1>{}));
auto stride_D = cutlass::make_cute_packed_stride(
typename GemmKernel::StrideD{}, cute::make_shape(m, n, cute::Int<1>{}));
auto arguments_base = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
@ -178,12 +175,12 @@ bool testEVTAuxStoreWithoutD() {
/*hw_info=*/{},
/*scheduler_args=*/{}
};
constexpr float beta [[maybe_unused]] = 1.0;
constexpr float alpha [[maybe_unused]] = 1.0;
using ElementC = typename GemmWithoutD::ElementC;
if constexpr (not has_c) {
arguments_base.epilogue.thread = {
// binary op : alpha * acc
@ -282,7 +279,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
>;
using namespace cutlass::epilogue::fusion;
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
@ -324,10 +321,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
return *(GemmKernel *)(nullptr);
};
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@ -352,7 +349,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t
>;
using namespace cutlass::epilogue::fusion;
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
@ -394,10 +391,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
return *(GemmKernel *)(nullptr);
};
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@ -467,7 +464,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 12
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@ -492,7 +489,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
>;
using namespace cutlass::epilogue::fusion;
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
@ -534,10 +531,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
return *(GemmKernel *)(nullptr);
};
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@ -562,7 +559,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t
>;
using namespace cutlass::epilogue::fusion;
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
@ -604,10 +601,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
return *(GemmKernel *)(nullptr);
};
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@ -677,7 +674,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 12
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;

View File

@ -35,7 +35,6 @@
#pragma once
#include "cute/layout.hpp"
#include "cute/stride.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////