Updates for CUTLASS 3.4.1 (#1346)

* Updates for CUTLASS 3.4.1

* minor epi change
ANIKET SHIVAM 2024-02-15 12:48:34 -08:00 committed by GitHub
parent 47a3ebbea9
commit bbe579a9e3
49 changed files with 800 additions and 451 deletions

View File

@@ -1,5 +1,11 @@
# NVIDIA CUTLASS Changelog
-## [3.4](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12)
+## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14)
- Statically available [CUTLASS Version macros](/include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side.
- Improvements for Hopper [Group-GEMMs](/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm).
- Updates and bugfixes from the community (thanks!).
## [3.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12)
* Expanded [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
* Performance improvements to [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm)
* Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
@@ -8,7 +14,6 @@
* NamedBarriers usability improvement and list of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) has been officially released.
* Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved.
## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31)
* [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
* [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8}, and upcast on operandA {s8, u8} x {fp16, bf16}.

View File

@@ -40,7 +40,25 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")
-project(CUTLASS VERSION 3.4.0 LANGUAGES CXX)
+# To reduce duplicate version locations, parse the version out of the
# main versions.h file and reuse it here.
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/include/cutlass/version.h VERSION_FILE_CONTENTS)
string(REGEX MATCH "#define CUTLASS_MAJOR ([0-9]+)" _CUTLASS_VERSION_MAJOR "${VERSION_FILE_CONTENTS}")
set(_CUTLASS_VERSION_MAJOR ${CMAKE_MATCH_1})
string(REGEX MATCH "#define CUTLASS_MINOR ([0-9]+)" _CUTLASS_VERSION_MINOR "${VERSION_FILE_CONTENTS}")
set(_CUTLASS_VERSION_MINOR ${CMAKE_MATCH_1})
string(REGEX MATCH "#define CUTLASS_PATCH ([0-9]+)" _CUTLASS_VERSION_PATCH "${VERSION_FILE_CONTENTS}")
set(_CUTLASS_VERSION_PATCH ${CMAKE_MATCH_1})
message(STATUS "CUTLASS ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_CUTLASS_VERSION_PATCH}")
## CUTLASS PROJECT #############################################################
project(CUTLASS VERSION ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_CUTLASS_VERSION_PATCH} LANGUAGES CXX)
################################################################################
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
if (CUDA_VERSION VERSION_LESS 11.3)
@@ -178,6 +196,9 @@ if(WIN32)
set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUTLASS_VERSIONS_GENERATED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DCUTLASS_VERSIONS_GENERATED")
if (WIN32)
# Enable more warnings. Add "-Xcompiler=/WX" to enable warnings as errors.
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3)
@@ -589,8 +610,8 @@ if (NOT DEFINED CUTLASS_REVISION)
endif()
configure_file(
-${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in
+${CMAKE_CURRENT_SOURCE_DIR}/cmake/version_extended.h.in
-${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h
+${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version_extended.h
@ONLY)
target_include_directories(

View File

@@ -2,7 +2,8 @@
## 2023
-- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi and Jay Shah. _arXiv_, December 2023.
+- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi, Jay Shah. _arXiv_, December 2023.
- ["A Speed Odyssey for Deployable Quantization of LLMs"](https://arxiv.org/abs/2311.09550). Qingyuan Li, Ran Meng, Yiduo Li, Bo Zhang, Liang Li, Yifan Lu, Xiangxiang Chu, Yerui Sun, Yuchen Xie. _arXiv_, November 2023.

View File

@@ -2,7 +2,7 @@
# CUTLASS 3.4
-_CUTLASS 3.4 - January 2024_
+_CUTLASS 3.4 - February 2024_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
@@ -43,13 +43,18 @@ In addition to GEMMs, CUTLASS implements high-performance convolution via the im
# What's New in CUTLASS 3.4
CUTLASS 3.4.1 is an update to CUTLASS adding:
- Statically available [CUTLASS Version macros](/include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side.
- Improvements for Hopper [Group-GEMM](/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMM](/examples/56_hopper_ptr_array_batched_gemm).
- Updates and bugfixes from the community (thanks!).
CUTLASS 3.4.0 is an update to CUTLASS adding:
- Improved [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) supporting {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors tuned for optimal performance on Hopper H100.
- Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) utilizing TMA and Hopper H100 tensor cores now available. (Requires CUDA 12.3 or above)
- Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) - commonly used in optimization of Mixture-Of-Expert models, is now available on Hopper GPUs taking advantage of TMA and Hopper H100 tensor cores. (Requires CUDA 12.3 or above)
- [Ampere Sparse GEMM](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
-- Impovements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library.
+- Improvements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library.
- Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved.
Minimum requirements:
@@ -93,8 +98,8 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA
# Compatibility
CUTLASS requires a C++17 host compiler and
-performs best when built with the [**CUDA 12.2.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit-archive).
+performs best when built with the [**CUDA 12.3.2 Toolkit**](https://developer.nvidia.com/cuda-downloads).
-It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2 and CUDA 12.3.1
+It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2, CUDA 12.3.1 and CUDA 12.3.2.
## Operating Systems
We have tested the following environments.

View File

@@ -1,38 +0,0 @@
#include <cstdint>
#include <string>
#define CUTLASS_MAJOR @CUTLASS_VERSION_MAJOR@
#define CUTLASS_MINOR @CUTLASS_VERSION_MINOR@
#define CUTLASS_PATCH @CUTLASS_VERSION_PATCH@
#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
namespace cutlass {
inline uint32_t getVersion() {
return CUTLASS_VERSION;
}
inline uint32_t getVersionMajor() {
return CUTLASS_MAJOR;
}
inline uint32_t getVersionMinor() {
return CUTLASS_MINOR;
}
inline uint32_t getVersionPatch() {
return CUTLASS_PATCH;
}
inline uint32_t getVersionBuild() {
return CUTLASS_BUILD + 0;
}
inline std::string getVersionString() {
std::string version = "@CUTLASS_VERSION@";
if (getVersionBuild()) {
version += "." + std::to_string(getVersionBuild());
}
return version;
}
inline std::string getGitRevision() {
return "@CUTLASS_REVISION@";
}
} // namespace cutlass
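
A minimal sketch of how downstream code can consume these version accessors, assuming the same functions remain in the now-static include/cutlass/version.h that replaces this removed template (the 341 threshold below is illustrative, derived from the major*100 + minor*10 + patch packing shown above):

#include <iostream>
#include "cutlass/version.h"

// CUTLASS_VERSION packs major*100 + minor*10 + patch, so 3.4.1 -> 341.
#if !defined(CUTLASS_VERSION) || (CUTLASS_VERSION < 341)
#error "This sketch assumes CUTLASS 3.4.1 or newer."
#endif

int main() {
  // Runtime accessors mirror the compile-time macros.
  std::cout << "CUTLASS " << cutlass::getVersionString()
            << " (packed version " << cutlass::getVersion() << ")\n";
  return 0;
}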

View File

@@ -0,0 +1,34 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@
#define CUTLASS_REVISION "@CUTLASS_REVISION@"

View File

@@ -31,4 +31,5 @@
cutlass_example_add_executable(
02_dump_reg_shmem
dump_reg_shmem.cu
DISABLE_TESTS ON
)

View File

@@ -70,7 +70,7 @@
using namespace cute;
-#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
@@ -98,8 +98,8 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
-using KernelSchedule = cutlass::gemm::KernelArrayTmaWarpSpecializedCooperative; // Kernel to launch
+using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
-using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedArray; // Epilogue to launch
+using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
@@ -169,7 +169,7 @@ cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;
-#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
@@ -245,7 +245,7 @@ struct Result
bool passed = false;
};
-#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@@ -468,7 +468,7 @@ int run(Options &options)
return 0;
}
-#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -510,7 +510,7 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//
-#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
run<Gemm>(options);
#endif

View File

@@ -27,17 +27,17 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
-# Only the correctness check will be run by these commands.
-set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=0) # Square problem sizes
-set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=0) # Square problem sizes
-set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Default problem sizes
-set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=0) # Default problem sizes
-set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=0) # Small-k problem sizes
-set(TEST_SMALLK_LARGE_BATCH --m=1024 --n=512 --k=64 --l=500 --iterations=0) # Small-k problem sizes
+set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=1) # Square problem sizes
+set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=1) # Square problem sizes
+set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=1) # Default problem sizes
+set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=1) # Default problem sizes
+set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Default problem sizes w/ Epilogue Op test
+set(TEST_EPILOGUE_OP_LARGE_BATCH --alpha=1.5 -l=500 --iterations=1) # Default problem sizes w/ Epilogue Op test
+set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=1) # Small-k problem sizes
+set(TEST_SMALLK_LARGE_BATCH --m=1024 --n=512 --k=64 --l=500 --iterations=1) # Small-k problem sizes
cutlass_example_add_executable(
56_hopper_ptr_array_batched_gemm
@@ -47,6 +47,8 @@ cutlass_example_add_executable(
TEST_SQUARE_LARGE_BATCH
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_BATCH
TEST_EPILOGUE_OP
TEST_EPILOGUE_OP_LARGE_BATCH
TEST_SMALLK
TEST_SMALLK_LARGE_BATCH
)

View File

@@ -44,6 +44,7 @@
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
Skipping any of the problem dimensions randomizes it across the different groups.
Same applies for alpha and beta values that are randomized across the different groups.
To run this example for a set of problems using the benchmark option:
@@ -62,6 +63,7 @@
#include <fstream>
#include <sstream>
#include <vector>
#include <float.h>
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
@ -91,9 +93,9 @@ using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand
using ElementC = float; // Element type for C and D matrix operands using ElementC = cutlass::half_t; // Element type for C and D matrix operands
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations /// GEMM kernel configurations
@ -101,40 +103,40 @@ using ElementC = float; // Element type
// A matrix configuration // A matrix configuration
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration // B matrix configuration
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration // C/D matrix configuration
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes)
// Core kernel configurations // Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedGroup; // Epilogue to launch using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
-ElementC, LayoutC, AlignmentC,
+ElementC, LayoutC *, AlignmentC,
-ElementC, LayoutC, AlignmentC,
+ElementC, LayoutC *, AlignmentC,
EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
-ElementA, LayoutA, AlignmentA,
+ElementA, LayoutA *, AlignmentA,
-ElementB, LayoutB, AlignmentB,
+ElementB, LayoutB *, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
@@ -161,10 +163,10 @@ using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementAccumulator,
ElementAccumulator>;
-using StrideA = typename Gemm::GemmKernel::StrideA;
+using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
-using StrideB = typename Gemm::GemmKernel::StrideB;
+using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
-using StrideC = typename Gemm::GemmKernel::StrideC;
+using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
-using StrideD = typename Gemm::GemmKernel::StrideD;
+using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
// Host-side allocations
std::vector<int64_t> offset_A;
@@ -177,6 +179,9 @@ std::vector<StrideB> stride_B_host;
std::vector<StrideC> stride_C_host;
std::vector<StrideD> stride_D_host;
std::vector<ElementAccumulator> alpha_host;
std::vector<ElementAccumulator> beta_host;
// Device-side allocations
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
@@ -197,7 +202,13 @@ cutlass::DeviceAllocation<StrideB> stride_B;
cutlass::DeviceAllocation<StrideC> stride_C;
cutlass::DeviceAllocation<StrideD> stride_D;
-#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+// Note, this is an array of pointers to alpha and beta scaling values per group
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
@@ -208,8 +219,8 @@ struct Options {
bool help = false;
-float alpha = 1.0f;
+float alpha = FLT_MAX;
-float beta = 0.0f;
+float beta = FLT_MAX;
int iterations = 10;
int m = 1024, n = 2048, k = 512, groups = 10;
std::string benchmark_path;
@@ -230,8 +241,8 @@ struct Options {
cmd.get_cmd_line_argument("n", n);
cmd.get_cmd_line_argument("k", k);
cmd.get_cmd_line_argument("groups", groups);
-cmd.get_cmd_line_argument("alpha", alpha, 1.f);
+cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX);
-cmd.get_cmd_line_argument("beta", beta, 0.f);
+cmd.get_cmd_line_argument("beta", beta, FLT_MAX);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("benchmark", benchmark_path);
@@ -248,10 +259,7 @@ struct Options {
}
void randomize_problems(cutlass::CommandLine &cmd) {
-int cmd_line_m = -1;
-int cmd_line_n = -1;
-int cmd_line_k = -1;
+int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
cmd.get_cmd_line_argument("m", cmd_line_m);
cmd.get_cmd_line_argument("n", cmd_line_n); cmd.get_cmd_line_argument("n", cmd_line_n);
cmd.get_cmd_line_argument("k", cmd_line_k); cmd.get_cmd_line_argument("k", cmd_line_k);
@ -259,19 +267,15 @@ struct Options {
problem_sizes_host.reserve(groups); problem_sizes_host.reserve(groups);
for (int i = groups; i > 0; i--) { for (int i = groups; i > 0; i--) {
int m = cmd_line_m; int m = cmd_line_m;
int n = cmd_line_n; int n = cmd_line_n;
int k = cmd_line_k; int k = cmd_line_k;
if (m < 1) { if (m < 1) {
m = ((rand() % 512) + 1); m = ((rand() % 512) + 1);
} }
if (n < 1) { if (n < 1) {
n = ((rand() % 512) + 1); n = ((rand() % 512) + 1);
} }
if (k < 1) { if (k < 1) {
k = alignment * ((rand() % 64) + 1); k = alignment * ((rand() % 64) + 1);
} }
@ -317,6 +321,7 @@ struct Options {
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
} }
} }
groups = static_cast<int>(problem_sizes_host.size());
return true;
}
@@ -351,7 +356,9 @@ struct Options {
uint64_t fmas = uint64_t();
for (auto const & problem : problem_sizes_host) {
-fmas += cute::size(problem);
+fmas += static_cast<uint64_t>(get<0>(problem)) *
+static_cast<uint64_t>(get<1>(problem)) *
+static_cast<uint64_t>(get<2>(problem));
}
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * uint64_t(fmas);
@@ -370,7 +377,7 @@ struct Result
bool passed = false;
};
-#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation
@@ -435,6 +442,7 @@ void allocate(const Options &options) {
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{})));
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{})));
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{})));
}
block_A.reset(total_elements_A);
@@ -442,6 +450,8 @@ void allocate(const Options &options) {
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
block_ref_D.reset(total_elements_D);
block_alpha.reset(options.groups);
block_beta.reset(options.groups);
}
/// Initialize operands to be used in the GEMM and reference GEMM
@@ -460,12 +470,18 @@ void initialize(const Options &options) {
std::vector<ElementB *> ptr_B_host(options.groups);
std::vector<ElementC *> ptr_C_host(options.groups);
std::vector<ElementC *> ptr_D_host(options.groups);
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
for (int32_t i = 0; i < options.groups; ++i) {
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
ptr_alpha_host.at(i) = block_alpha.get() + i;
ptr_beta_host.at(i) = block_beta.get() + i;
}
ptr_A.reset(options.groups);
@@ -492,13 +508,20 @@ void initialize(const Options &options) {
stride_D.reset(options.groups);
stride_D.copy_from_host(stride_D_host.data());
alpha_device.reset(options.groups);
alpha_device.copy_from_host(ptr_alpha_host.data());
beta_device.reset(options.groups);
beta_device.copy_from_host(ptr_beta_host.data());
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
block_alpha.copy_from_host(alpha_host.data());
block_beta.copy_from_host(beta_host.data());
}
/// Populates a Gemm::Arguments structure from the given commandline options
-typename Gemm::Arguments args_from_options(const Options &options)
+typename Gemm::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
{
cutlass::KernelHardwareInfo hw_info;
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
@@ -506,13 +529,36 @@ typename Gemm::Arguments args_from_options(const Options &options)
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
-typename Gemm::Arguments arguments{
-cutlass::gemm::GemmUniversalMode::kGrouped,
-{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
-{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
-{{options.alpha, options.beta}, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
-hw_info
-};
+typename Gemm::EpilogueOutputOp::Params params;
+if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
+// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
+params = typename Gemm::EpilogueOutputOp::Params(
+ElementAccumulator(options.alpha), ElementAccumulator(options.beta));
+}
+else {
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
params = typename Gemm::EpilogueOutputOp::Params(alpha_device.get(), beta_device.get());
}
typename Gemm::Arguments arguments;
if (host_problem_shapes_available) {
arguments = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
else {
arguments = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGrouped,
{options.groups, problem_sizes.get(), nullptr},
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
hw_info
};
}
return arguments;
}
@@ -539,10 +585,10 @@ bool verify(const Options &options) {
// Launch device reference gemm kernel
gemm_reference(
{M, N, K},
-ElementAccumulator(options.alpha),
+ElementAccumulator(alpha_host.at(i)),
ref_A,
ref_B,
-ElementAccumulator(options.beta),
+ElementAccumulator(beta_host.at(i)),
ref_C,
ref_D);
@@ -560,7 +606,7 @@ bool verify(const Options &options) {
/// Execute a given example GEMM computation
template <typename Gemm>
-int run(Options &options)
+int run(Options &options, bool host_problem_shapes_available = true)
{
allocate(options);
initialize(options);
@@ -569,7 +615,7 @@ int run(Options &options)
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
-auto arguments = args_from_options(options);
+auto arguments = args_from_options(options, host_problem_shapes_available);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
@@ -612,12 +658,12 @@ int run(Options &options)
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
-std::cout << " Problem Sizes: " << std::endl;
-for (auto const & problem : options.problem_sizes_host) {
-std::cout << " " << problem << std::endl;
+std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
+for (int32_t i = 0; i < options.groups; ++i) {
+std::cout << " " << options.problem_sizes_host.at(i);
+std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
}
std::cout << " Groups : " << options.groups << std::endl;
-std::cout << " Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl;
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS : " << result.gflops << std::endl;
}
@@ -625,7 +671,7 @@ int run(Options &options)
return 0;
}
-#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -667,8 +713,9 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//
-#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
run<Gemm>(options);
run<Gemm>(options, false /*host_problem_shapes_available*/);
#endif
return 0; return 0;
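
A side note on the FLOP accounting change in the hunks above: each group's M*N*K product is now widened to 64 bits before accumulation. The following is a standalone sketch of the same arithmetic; the problem sizes and runtime are made up for illustration only:

#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Per-group GEMM extents {M, N, K}; illustrative values only.
  std::vector<std::array<int, 3>> problems = {{2048, 2048, 2048}, {1024, 512, 4096}};

  uint64_t fmas = 0;
  for (auto const &p : problems) {
    // Widen before multiplying: 2048 * 2048 * 2048 overflows a 32-bit int.
    fmas += static_cast<uint64_t>(p[0]) * static_cast<uint64_t>(p[1]) * static_cast<uint64_t>(p[2]);
  }
  uint64_t flop = 2 * fmas;        // two flops per multiply-add
  double avg_runtime_s = 0.005;    // pretend 5 ms average runtime
  std::printf("GFLOPS: %.2f\n", static_cast<double>(flop) / 1.0e9 / avg_runtime_s);
  return 0;
}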

View File

@@ -35,9 +35,15 @@ set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0)
set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes
set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Random problem sizes
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0) # Fixed problem sizes
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) # Fixed problem sizes
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=500 --iterations=10) # Random problem sizes
@@ -49,8 +55,12 @@ cutlass_example_add_executable(
TEST_RANDOM_LARGE_GROUP
TEST_EPILOGUE
TEST_EPILOGUE_LARGE_GROUP
TEST_EPILOGUE_OP
TEST_EPILOGUE_OP_LARGE_GROUP
TEST_FIXED
TEST_FIXED_LARGE_GROUP
TEST_SMALL
TEST_SMALL_LARGE_GROUP
TEST_RANDOM_PERF
TEST_RANDOM_PERF_LARGE_GROUP
)

View File

@@ -265,7 +265,7 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor
asm volatile (
"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;"
:: "l"(smem_int64_desc), "r"(prob_shape[2]));
// Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1 // Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1
asm volatile (
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;"
:: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4));

View File

@@ -391,7 +391,7 @@ struct TiledMMA : MMA_Atom
} else {
return cute::max(core_size, perm_size);
}
CUTE_GCC_UNREACHABLE; CUTE_GCC_UNREACHABLE;
}

View File

@@ -125,6 +125,9 @@ using CUTE_STL_NAMESPACE::invoke_result_t;
using CUTE_STL_NAMESPACE::common_type;
using CUTE_STL_NAMESPACE::common_type_t;
using CUTE_STL_NAMESPACE::remove_pointer;
using CUTE_STL_NAMESPACE::remove_pointer_t;
// <utility>
using CUTE_STL_NAMESPACE::declval;

View File

@@ -64,6 +64,10 @@
#endif
#endif
#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3)))
#define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
#endif
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {

View File

@@ -80,6 +80,40 @@ struct TagToStrideB<layout::ColumnMajor> {
using tag = layout::ColumnMajor;
};
// For each cutlass::layout *, provides its corresponding cute stride types, 64b by default
// Used by pointer array and grouped gemm
// Maps to modes [M, K, L]
template <>
struct TagToStrideA<layout::RowMajor *> {
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>;
using type = UnderlyingType*;
using tag = layout::RowMajor;
};
// Maps to modes [M, K, L]
template <>
struct TagToStrideA<layout::ColumnMajor *> {
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>;
using type = UnderlyingType*;
using tag = layout::ColumnMajor;
};
// Maps to modes [N, K, L]
template <>
struct TagToStrideB<layout::RowMajor *> {
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>;
using type = UnderlyingType*;
using tag = layout::RowMajor;
};
// Maps to modes [N, K, L]
template <>
struct TagToStrideB<layout::ColumnMajor *> {
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>;
using type = UnderlyingType*;
using tag = layout::ColumnMajor;
};
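
These pointer-tag specializations are what let the grouped and pointer-array collective builders be instantiated with a layout tag pointer (LayoutA * / LayoutB * in example 57 above). A compile-only sketch of the resulting types follows; the header path and the TagToStrideA_t alias are assumptions for illustration rather than taken from this hunk:

#include <type_traits>
#include "cutlass/detail/layout.hpp"   // assumed home of TagToStrideA/B and the _t alias
#include "cutlass/layout/matrix.h"

// Plain tag: one cute::Stride describing a single tensor.
using StrideA    = cutlass::detail::TagToStrideA_t<cutlass::layout::RowMajor>;
// Pointer tag: a pointer to per-group strides, as grouped/ptr-array GEMM kernels expect.
using StrideAPtr = cutlass::detail::TagToStrideA_t<cutlass::layout::RowMajor *>;

static_assert(!std::is_pointer_v<StrideA>, "scalar GEMM keeps a single stride");
static_assert(std::is_pointer_v<StrideAPtr>, "grouped GEMM carries one stride per group");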
// Maps to modes [M, N, L]
template <class LayoutTag>
struct TagToStrideC : TagToStrideA<LayoutTag> { };
@@ -101,7 +135,7 @@ template<int ModeIndex, class Stride>
constexpr bool
is_major(Stride = {}) {
// Account for stride types with and without batch mode and batch modes with static zero stride
-return cute::is_constant<1, decltype(cute::front(cute::get<ModeIndex>(Stride{})))>::value;
+return cute::is_constant<1, decltype(cute::front(cute::get<ModeIndex>(cute::remove_pointer_t<Stride>{})))>::value;
}
// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices

View File

@@ -268,7 +268,7 @@ struct Sm90TmaBuilderImpl {
// Passing void C disables source load + smem allocation
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,ElementD,ElementC_>; // prevents void ref breakages
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,GmemLayoutTagD,GmemLayoutTagC_>;
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
@@ -434,8 +434,7 @@ struct CollectiveBuilder<
Schedule,
fusion::LinearCombination<ElementD,ElementCompute,ElementC_,ElementCompute,RoundStyle>,
cute::enable_if_t<cute::is_same_v<Schedule, NoSmemWarpSpecialized> ||
-cute::is_same_v<Schedule, NoSmemWarpSpecializedArray> ||
-cute::is_same_v<Schedule, NoSmemWarpSpecializedGroup> >> {
+cute::is_same_v<Schedule, PtrArrayNoSmemWarpSpecialized> >> {
// Passing void C disables source load
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,

View File

@@ -63,6 +63,7 @@ public:
// Type Aliases
//
using EpilogueSchedule = EpilogueSchedule_;
using DispatchPolicy = EpilogueSchedule_;
// derived types of output thread level operator
using ThreadEpilogueOp = ThreadEpilogueOp_;

View File

@@ -73,12 +73,10 @@ public:
using ElementScalar = ElementCompute;
using ElementC = typename ThreadEpilogueOp::ElementC;
using StrideC = StrideC_;
using UnderlyingStrideC = cute::remove_pointer_t<StrideC>;
using ElementD = typename ThreadEpilogueOp::ElementD;
using StrideD = StrideD_;
-using StridesC = cute::conditional_t<cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>,
-StrideC const*, StrideC>;
-using StridesD = cute::conditional_t<cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>,
-StrideD const*, StrideD>;
+using UnderlyingStrideD = cute::remove_pointer_t<StrideD>;
using GmemTiledCopyC = void;
using GmemTiledCopyD = void;
@@ -86,10 +84,9 @@ public:
static const int kOutputAlignment = ThreadEpilogueOp::kCount;
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
-static_assert(cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup> ||
-cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedArray>, "Incompatible epilogue schedule.");
-static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
-static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(cute::is_same_v<EpilogueSchedule, PtrArrayNoSmemWarpSpecialized>, "Incompatible epilogue schedule.");
+static_assert(rank(UnderlyingStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
+static_assert(rank(UnderlyingStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
struct SharedStorage { };
@@ -97,9 +94,9 @@ public:
struct Arguments {
typename ThreadEpilogueOp::Params thread{};
ElementC const** ptr_C = nullptr;
-StridesC dC{};
+StrideC dC{};
ElementD** ptr_D = nullptr;
-StridesD dD{};
+StrideD dD{};
};
// Device side epilogue params
@@ -140,12 +137,13 @@ public:
CUTLASS_HOST_DEVICE
DefaultEpilogueArray(Params const& params_)
-: params(params_), epilogue_op(params_.thread) { }
+: params(params_) { }
CUTLASS_DEVICE
bool
is_source_needed() {
-return epilogue_op.is_source_needed();
+// For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta.
+return true;
}
template<
@@ -185,10 +183,23 @@ public:
// Slice to get the tile this CTA is responsible for
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
-StrideC stride_c;
-StrideD stride_d;
-if constexpr (cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>) {
-stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC[l_coord]);
+// If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups.
+// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups,
+// we get the correct alpha/beta values for the current batch/group using group index.
+ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord);
if (epilogue_op.is_source_needed() && params.dC == nullptr) {
// Beta value is non-zero while pointer to C is a nullptr
assert(0);
}
UnderlyingStrideC stride_c;
UnderlyingStrideD stride_d;
if constexpr (!cute::is_same_v<UnderlyingStrideC, StrideC>) {
// If grouped gemm
if (epilogue_op.is_source_needed()) {
stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC[l_coord]);
}
stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD[l_coord]);
}
else {
@@ -197,7 +208,11 @@ public:
}
// Represent the full output tensor
-Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C[l_coord]), make_shape(M,N,mock_L), stride_c); // (m,n,l)
+ElementC const* ptr_C_l = nullptr;
if (epilogue_op.is_source_needed()) {
ptr_C_l = params.ptr_C[l_coord];
}
Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l)
Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l)
Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
@@ -242,7 +257,6 @@ public:
private:
Params params;
-ThreadEpilogueOp epilogue_op;
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -148,12 +148,12 @@ private:
constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{});
using EmptyType = cute::tuple<>;
using SmemCStorage = cute::conditional_t<is_source_supported and (not ReuseSmemC),
array_aligned<SmemElementC, size(SmemLayoutC{}), SmemAlignmentC>,
EmptyType>;
using SmemDStorage = cute::conditional_t<is_destination_supported,
array_aligned<SmemElementD, size(SmemLayoutD{}), SmemAlignmentD>,
EmptyType>;
@@ -189,6 +189,7 @@ public:
struct SharedStorage {
using TensorStorage = TensorStorageImpl;
TensorStorage tensors;
using PipelineStorage = typename LoadPipeline::SharedStorage; using PipelineStorage = typename LoadPipeline::SharedStorage;
@ -249,12 +250,12 @@ public:
Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M_C,N,L), args.dC));
tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutC{}(_,_,0));
}
typename Params::TMA_D tma_store_d;
if constexpr (is_destination_supported) {
Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD));
tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutD{}(_,_,0));
}
return {
FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace),
@ -385,13 +386,13 @@ public:
// Apply epilogue subtile, get matching smem tensor
SmemElementC* ptr_sC = nullptr;
if constexpr (is_source_supported) {
if constexpr (ReuseSmemC) {
ptr_sC = reinterpret_cast<SmemElementC*>(shared_tensors.smem_D().data());
} else {
ptr_sC = shared_tensors.smem_C().data();
}
}
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C)
@ -559,7 +560,7 @@ public:
// Vectorized fragment view
constexpr int FragmentSize = DispatchPolicy::FragmentSize;
Tensor tRS_rAcc_frg = recast<Array<ElementAccumulator, FragmentSize>>(tRS_rAcc);
Tensor tRS_rD_frg = recast<Array<SmemElementD, FragmentSize>>(tRS_rD);
CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly");
// (t)hread-partition for (s)mem to (r)egister copy (tSR_)

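The SmemCStorage/SmemDStorage selection earlier in this file can be pictured with a small host-side analogue: when a buffer is not needed, an empty type is substituted so it contributes no real storage to the aggregate. This is only a sketch using std::conditional_t and std::array in place of cute::conditional_t and array_aligned.

// Editor's illustration only -- not part of the CUTLASS diff.
#include <array>
#include <tuple>
#include <type_traits>

template <class Element, int N, bool Needed>
using MaybeStorage = std::conditional_t<Needed, std::array<Element, N>, std::tuple<>>;

struct TensorStorageSketch {
  MaybeStorage<float, 128 * 32, /*Needed=*/false> smem_C;  // elided: source unsupported or its smem is reused for D
  MaybeStorage<float, 128 * 32, /*Needed=*/true>  smem_D;
};

static_assert(sizeof(TensorStorageSketch) < 2 * 128 * 32 * sizeof(float),
              "only the D buffer contributes real storage");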
View File

@ -46,8 +46,7 @@ namespace cutlass::epilogue {
//////////////////////////////////////////////////////////////////////////////
struct NoSmemWarpSpecialized {};
struct PtrArrayNoSmemWarpSpecialized {};
struct TmaWarpSpecialized {};
struct TmaWarpSpecializedCooperative {};
// DEPRECATED schedules, will be removed in next release

View File

@ -1247,6 +1247,7 @@ struct FusionCallbacks<
};
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <class FusionOpOrCallbacks, class = cute::void_t<>>
struct get_element_aux {
@ -1257,7 +1258,7 @@ template <class FusionOpOrCallbacks>
struct get_element_aux<FusionOpOrCallbacks, cute::void_t<typename FusionOpOrCallbacks::ElementAux>> {
using type = typename FusionOpOrCallbacks::ElementAux;
};
template <class NodeOp, class... ChildOps>
struct get_element_aux<Sm90TreeVisitor<NodeOp, ChildOps...>, cute::void_t<>> {
using type = typename get_element_aux<NodeOp>::type;
@ -1270,7 +1271,7 @@ struct get_element_aux<FusionCallbacks<Ts...>, cute::void_t<typename FusionCallb
public:
using type = typename get_element_aux<Operation>::type;
};
} // namespace cutlass::epilogue::fusion::detail
template <class Callbacks>
using get_element_aux_t = typename detail::get_element_aux<Callbacks>::type;

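get_element_aux above is the usual void_t detection idiom: the primary template is the fallback, and the partial specialization is chosen only when the queried type exposes a nested ElementAux. A minimal standalone version (hypothetical op types, std::void_t in place of cute::void_t):

// Editor's illustration only -- not part of the CUTLASS diff.
#include <type_traits>

template <class T, class = std::void_t<>>
struct get_element_aux_sketch {
  using type = void;  // fallback when T has no nested ElementAux
};

template <class T>
struct get_element_aux_sketch<T, std::void_t<typename T::ElementAux>> {
  using type = typename T::ElementAux;  // chosen only if T::ElementAux exists
};

struct OpWithAux    { using ElementAux = float; };
struct OpWithoutAux { };

static_assert(std::is_same_v<get_element_aux_sketch<OpWithAux>::type, float>);
static_assert(std::is_same_v<get_element_aux_sketch<OpWithoutAux>::type, void>);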
View File

@ -88,43 +88,72 @@ public:
/// Host-constructable parameters structure
struct Params
{
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
ElementCompute const* const* alpha_ptr_array; ///< array of pointers to accumulator scalar per group/batch
ElementCompute const* const* beta_ptr_array; ///< array of pointers to source scalar per group/batch
CUTLASS_HOST_DEVICE
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr),
alpha_ptr_array(nullptr),
beta_ptr_array(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta
):
alpha(alpha), beta(beta),
alpha_ptr(nullptr), beta_ptr(nullptr),
alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha
):
alpha(alpha), beta(0),
alpha_ptr(nullptr), beta_ptr(nullptr),
alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr
):
alpha(0), beta(0),
alpha_ptr(alpha_ptr), beta_ptr(beta_ptr),
alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr
):
alpha(0), beta(0),
alpha_ptr(alpha_ptr), beta_ptr(nullptr),
alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute const* const* alpha_ptr_array,
ElementCompute const* const* beta_ptr_array
):
alpha(0), beta(0),
alpha_ptr(nullptr), beta_ptr(nullptr),
alpha_ptr_array(alpha_ptr_array), beta_ptr_array(beta_ptr_array) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute const* const* alpha_ptr_array
):
alpha(0), beta(0),
alpha_ptr(nullptr), beta_ptr(nullptr),
alpha_ptr_array(alpha_ptr_array), beta_ptr_array(nullptr) { }
};
private:
@ -140,9 +169,25 @@ public:
/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombination(Params const &params, int group_idx = 0) {
if (params.alpha_ptr_array != nullptr && params.alpha_ptr_array[group_idx] != nullptr) {
alpha_ = *(params.alpha_ptr_array[group_idx]);
}
else if (params.alpha_ptr != nullptr) {
alpha_ = *params.alpha_ptr;
}
else {
alpha_ = params.alpha;
}
if (params.beta_ptr_array != nullptr && params.beta_ptr_array[group_idx] != nullptr) {
beta_ = *(params.beta_ptr_array[group_idx]);
}
else if (params.beta_ptr != nullptr) {
beta_ = *params.beta_ptr;
}
else {
beta_ = params.beta;
}
}
/// Returns true if source is needed

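The new constructor resolves alpha and beta with a fixed precedence: a non-null per-group entry in alpha_ptr_array/beta_ptr_array wins, then the single alpha_ptr/beta_ptr, then the plain scalar. A hedged standalone sketch of that resolution order, with a hypothetical helper name:

// Editor's illustration only -- not part of the CUTLASS diff.
template <class T>
T resolve_scale(T scalar, T const* ptr, T const* const* ptr_array, int group_idx) {
  if (ptr_array != nullptr && ptr_array[group_idx] != nullptr) {
    return *ptr_array[group_idx];  // per-group/batch value
  }
  if (ptr != nullptr) {
    return *ptr;                   // single value loaded from memory
  }
  return scalar;                   // plain scalar held in Params
}

In spirit, alpha_ and beta_ above are each initialized this way, with group_idx selecting the current group or batch.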
View File

@ -185,8 +185,7 @@ struct CollectiveBuilder<
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>) &&
not detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>()>
> {
static_assert(is_static<TileShape_MNK>::value);
@ -197,8 +196,7 @@ struct CollectiveBuilder<
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n");
@ -515,8 +513,7 @@ struct CollectiveBuilder<
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>
> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
@ -534,8 +531,7 @@ struct CollectiveBuilder<
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
IsArrayOfPointersGemm,
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

View File

@ -93,8 +93,10 @@ struct CollectiveMma<
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using UnderlyingStrideA = cute::remove_pointer_t<StrideA>;
using ElementB = ElementB_;
using StrideB = StrideB_;
using UnderlyingStrideB = cute::remove_pointer_t<StrideB>;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
@ -149,14 +151,14 @@ struct CollectiveMma<
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(
GmemTiledCopyA{},
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(UnderlyingStrideA{}, int32_t(0)), UnderlyingStrideA{}),
SmemLayoutA{}(_,_,cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(
GmemTiledCopyB{},
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(UnderlyingStrideB{}, int32_t(0)), UnderlyingStrideB{}),
SmemLayoutB{}(_,_,cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
@ -179,16 +181,14 @@ struct CollectiveMma<
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<UnderlyingStrideA, StrideA>;
// Host side kernel arguments
struct Arguments {
ElementA const** ptr_A;
StrideA dA;
ElementB const** ptr_B;
StrideB dB;
};
// Device side kernel params
@ -197,9 +197,9 @@ struct CollectiveMma<
TMA_B tma_load_b;
void* tensormaps;
InternalElementA const** ptr_A;
StrideA dA;
InternalElementB const** ptr_B;
StrideB dB;
};
//
@ -212,30 +212,36 @@ struct CollectiveMma<
ProblemShape problem_shapes,
Arguments const& args,
void* workspace) {
// These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc.
// These will be replaced with correct values before the initial tma load.
auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1));
auto init_M = get<0>(init_shape);
auto init_N = get<1>(init_shape);
auto init_K = get<2>(init_shape);
// Batches/Groups are managed by using appropriate pointers to input matrices
const uint32_t mock_L = 1;
InternalElementA const* ptr_A_first_batch = reinterpret_cast<InternalElementA const*>(args.ptr_A);
InternalElementB const* ptr_B_first_batch = reinterpret_cast<InternalElementB const*>(args.ptr_B);
cudaError_t cuda_error = cudaGetLastError(); // clear previous error
UnderlyingStrideA stride_a;
UnderlyingStrideB stride_b;
if constexpr (IsGroupedGemmKernel) {
// Strides for Grouped Gemm will be replaced prior to the first access regardless.
stride_a = UnderlyingStrideA{};
stride_b = UnderlyingStrideB{};
}
else {
// Tensor shapes for Ptr-Array are initialized correctly only here.
auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0);
init_M = get<0>(problem_shape_MNK);
init_N = get<1>(problem_shape_MNK);
init_K = get<2>(problem_shape_MNK);
stride_a = args.dA;
stride_b = args.dB;
}
Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,mock_L), stride_a));
Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,mock_L), stride_b));
TMA_A tma_load_a = make_tma_copy(
GmemTiledCopyA{},
tensor_a,
@ -287,12 +293,14 @@ struct CollectiveMma<
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
bool implementable = true;
if (problem_shapes.is_host_problem_shape_available()) {
// Check alignment for all problem sizes
for (int i = 0; i < problem_shapes.groups(); i++) {
auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1);
auto [M,N,K,L] = problem_shape_MNKL;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), UnderlyingStrideA{});
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), UnderlyingStrideB{});
}
}
if (!implementable) {
@ -676,6 +684,14 @@ struct CollectiveMma<
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b,
prob_shape_B, prob_stride_B);
// Convert strides to byte strides
for (uint64_t& stride : prob_stride_A) {
stride = (stride * sizeof_bits_v<InternalElementA>) / 8;
}
for (uint64_t& stride : prob_stride_B) {
stride = (stride * sizeof_bits_v<InternalElementB>) / 8;
}
cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_A,
prob_shape_A,
prob_stride_A);

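The stride fix-up added above converts element strides into byte strides before the TMA descriptor in shared memory is patched; the arithmetic is simply stride_in_elements * sizeof_bits / 8. A standalone sketch of the conversion, assuming a rank-3 stride and a compile-time element width:

// Editor's illustration only -- not part of the CUTLASS diff.
#include <array>
#include <cstdint>

template <int SizeofBits>
void element_strides_to_byte_strides(std::array<uint64_t, 3>& stride) {
  for (uint64_t& s : stride) {
    s = (s * SizeofBits) / 8;
  }
}

// Example: element strides {64, 1, 8192} with 16-bit elements become byte strides {128, 2, 16384}.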
View File

@ -53,8 +53,7 @@ struct KernelTma { };
struct KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedCooperative { };
//////////////////////////////////////////////////////////////////////////////
@ -67,8 +66,7 @@ struct KernelGroupTmaWarpSpecializedCooperative { };
struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelPtrArrayTmaWarpSpecializedCooperative { };
// Policies to opt into mixed type GEMMs
struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { };
@ -233,7 +231,7 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative
>
struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
constexpr static int Stages = Stages_;
@ -241,8 +239,7 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
using ArchTag = arch::Sm90;
using Schedule = KernelSchedule;
static_assert(
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelSchedule>,
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies");
};

View File

@ -71,6 +71,12 @@ struct GroupProblemShape {
get_host_problem_shape(int32_t group_idx) const {
return host_problem_shapes[group_idx];
}
CUTLASS_HOST_DEVICE
bool
is_host_problem_shape_available() {
return host_problem_shapes != nullptr;
}
};
template <class ProblemShape_>
@ -104,6 +110,12 @@ public:
get_host_problem_shape(int32_t /* unused */ = 0) const {
return problem_shape_;
}
CUTLASS_HOST_DEVICE
bool
is_host_problem_shape_available() {
return true;
}
private:
UnderlyingProblemShape problem_shape_{};
};

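is_host_problem_shape_available lets host-side checks degrade gracefully when only device-resident problem shapes exist. A standalone sketch of how a caller (such as the alignment check in the mainloop) might guard on it; the types here are hypothetical stand-ins:

// Editor's illustration only -- not part of the CUTLASS diff.
struct GroupShapesSketch {
  int groups_ = 0;
  int const* host_mnk_ = nullptr;  // hypothetical flattened host copy, 3 ints per group
  bool is_host_problem_shape_available() const { return host_mnk_ != nullptr; }
};

template <class CheckFn>
bool check_all_groups(GroupShapesSketch const& shapes, CheckFn&& check) {
  if (!shapes.is_host_problem_shape_available()) {
    return true;  // nothing to validate on the host; defer to device-side handling
  }
  bool ok = true;
  for (int g = 0; g < shapes.groups_; ++g) {
    ok = ok && check(shapes.host_mnk_[3 * g + 0], shapes.host_mnk_[3 * g + 1], shapes.host_mnk_[3 * g + 2]);
  }
  return ok;
}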
View File

@ -62,8 +62,7 @@ class GemmUniversal<
CollectiveMainloop_,
CollectiveEpilogue_,
TileScheduler_,
cute::enable_if_t<cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, typename CollectiveMainloop_::DispatchPolicy::Schedule>>
>
{
public:
@ -80,7 +79,9 @@ public:
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using UnderlyingStrideA = typename CollectiveMainloop::UnderlyingStrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using UnderlyingStrideB = typename CollectiveMainloop::UnderlyingStrideB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using Schedule = typename DispatchPolicy::Schedule;
@ -93,8 +94,10 @@ public:
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using UnderlyingStrideC = typename CollectiveEpilogue::UnderlyingStrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using UnderlyingStrideD = typename CollectiveEpilogue::UnderlyingStrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
@ -102,7 +105,7 @@ public:
static_assert(cute::is_void_v<TileScheduler_>,
"Ptr-Array Cooperative and Grouped Gemm Cooperative kernel only supports the default scheduler.");
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<UnderlyingStrideA, StrideA>;
using TileScheduler = cute::conditional_t<IsGroupedGemmKernel,
typename detail::TileSchedulerSelector<
@ -204,7 +207,7 @@ public:
void* scheduler_workspace = workspace_ptr;
workspace_offset += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* epilogue_workspace = workspace_ptr + workspace_offset;
@ -244,14 +247,11 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = true;
if constexpr (IsGroupedGemmKernel) {
// Group GEMM currently only supports rank-3 problem shapes
implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3);
} else {
implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4);
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n");
@ -269,7 +269,7 @@ public:
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
workspace_size += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
@ -297,9 +297,9 @@ public:
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
status = TileScheduler::template initialize_workspace<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_offset += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess) {
return status;
@ -350,23 +350,20 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads.");
static_assert(size<0>(TileShape{}) >= 128,
"Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");
static_assert(cute::rank(UnderlyingStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(UnderlyingStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(UnderlyingStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(UnderlyingStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */
enum class WarpGroupRole {
@ -441,8 +438,6 @@ public:
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
@ -554,7 +549,8 @@ public:
shared_storage.tensors.mainloop
);
// Update starting pipeline state for the next tile
// Wait for the last TMA stage to complete loading, before issuing tensormap updates
mainloop_pipe_producer_state.advance(work_k_tile_count - 1);
// Signal for the epilogue load warp to begin // Signal for the epilogue load warp to begin
if (do_load_order_arrive) {
@ -570,8 +566,10 @@ public:
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(next_batch), Int<1>{});
}
// Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates
// Since this state is waiting for loads to finish, it must start in the inverted phase.
typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state =
{mainloop_pipe_producer_state.index(), !mainloop_pipe_producer_state.phase(), mainloop_pipe_producer_state.count()};
mainloop_pipeline.consumer_wait(mainloop_pipe_tma_consumer_state);
collective_mainloop.tensormaps_perform_update(
shared_storage.tensormaps.mainloop,
@ -585,13 +583,9 @@ public:
// Entire warp must do this (ie its aligned)
collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps);
curr_batch = next_batch;
// Advance the TMA consumer state for the last remaining stage that was being waited for above
mainloop_pipe_tma_consumer_state.advance(1);
}
else if (work_tile_info.is_valid()) { // case where batch/group didn't change between tiles
// Advance the TMA consumer state for all the stages to be in sync
mainloop_pipe_tma_consumer_state.advance(work_k_tile_count);
} }
// Advance the producer state for the last remaining stage that was being waited for above
mainloop_pipe_producer_state.advance(1);
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
@ -720,6 +714,7 @@ public:
);
}
} // Consumer Warp Groups End
#endif
}
private:

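IsGroupedGemmKernel is now derived purely from the stride type: grouped GEMM passes one stride per group as a pointer (StrideA is UnderlyingStrideA const*), while ptr-array batched GEMM passes a single stride, so stripping the pointer changes the type only in the grouped case. A standalone sketch of that detection using std::remove_pointer_t; the stride alias here is just a placeholder:

// Editor's illustration only -- not part of the CUTLASS diff.
#include <type_traits>

template <class Stride>
inline constexpr bool is_grouped_stride_v =
    !std::is_same_v<std::remove_pointer_t<Stride>, Stride>;

using StrideMKL = int;  // placeholder for a rank-3 stride tuple

static_assert(!is_grouped_stride_v<StrideMKL>);        // ptr-array: one stride shared by all batches
static_assert(is_grouped_stride_v<StrideMKL const*>);  // grouped: per-group stride array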
View File

@ -211,13 +211,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
@ -311,6 +308,7 @@ public:
thread_idx,
smem_buf
);
#endif
}
};

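The repeated change across these kernels replaces per-statement guards with a whole-body #if/#else/#endif: without the sm90a feature macro only the diagnostic is compiled, otherwise the full kernel body is. A minimal CUDA sketch of the pattern, with a dummy body:

// Editor's illustration only -- not part of the CUTLASS diff.
#include <cstdio>

__global__ void sm90a_guarded_kernel_sketch() {
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
  printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
  // ... the real warp-specialized kernel body would go here ...
#endif
}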
View File

@ -219,13 +219,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
enum class WarpGroupRole {
Producer = 0,
@ -435,6 +432,7 @@ public:
epi_store_pipe_producer_state_next
);
}
#endif
}
};

View File

@ -298,13 +298,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads.");
@ -610,6 +607,7 @@ public:
);
}
} // Consumer Warp Groups End
#endif
}
private:

View File

@ -296,13 +296,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
@ -612,6 +609,7 @@ public:
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
} // Consumer Warp Groups End
#endif
}
};

View File

@ -223,13 +223,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
enum class WarpGroupRole {
Producer = 0,
@ -409,6 +406,7 @@ public:
shared_storage.tensors.epilogue
);
}
#endif
}
};

View File

@ -250,13 +250,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
@ -493,6 +490,7 @@ public:
);
}
} // Consumer Warp Groups End
#endif
}
private:

View File

@ -257,13 +257,10 @@ public:
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
@ -509,6 +506,7 @@ public:
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
} // Consumer Warp Groups End
#endif
}
};

View File

@ -55,7 +55,7 @@ private:
// Tracking current group, its starting linear idx and total tiles
struct GroupInfo {
int group_idx = 0;
uint64_t start_linear_idx = 0;
uint64_t total_tiles = 0;
} current_group_info_;
@ -115,7 +115,7 @@ public:
GroupProblemShape problem_shapes,
TileShape tile_shape,
ClusterShape cluster_shape,
KernelHardwareInfo const& hw_info,
Arguments const& arguments,
[[maybe_unused]] void* workspace=nullptr,
[[maybe_unused]] const uint32_t epilogue_subtile = 1) {
@ -126,14 +126,16 @@ public:
dim3 problem_blocks = get_tiled_cta_shape_mnl(
problem_shapes.groups(),
problem_shapes,
hw_info,
tile_shape, cluster_shape);
Params params;
params.initialize(
problem_blocks,
problem_shapes.groups(),
problem_shapes.problem_shapes,
problem_shapes.host_problem_shapes,
to_gemm_coord(tile_shape),
to_gemm_coord(cluster_shape),
hw_info,
@ -144,6 +146,64 @@ public:
return params;
}
// Given the inputs, computes the physical grid we should launch.
template<class TileShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(
GroupProblemShape problem_shapes,
TileShape tile_shape,
ClusterShape cluster_shape,
KernelHardwareInfo hw_info,
Arguments arguments,
bool truncate_by_problem_size=true) {
dim3 problem_blocks = get_tiled_cta_shape_mnl(
problem_shapes.groups(),
problem_shapes,
hw_info,
tile_shape, cluster_shape);
return Params::get_grid_shape(
problem_blocks,
to_gemm_coord(cluster_shape),
hw_info,
arguments.max_swizzle_size,
arguments.raster_order,
/* truncate_by_problem_size = */true
);
}
// Given the inputs, computes the total number of output blocks this problem will compute over
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
template<class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_tiled_cta_shape_mnl(int groups, GroupProblemShape problem_shapes, KernelHardwareInfo hw_info, BlockShape cta_shape, ClusterShape cluster_shape) {
uint32_t total_ctas = 0;
uint32_t cta_in_N_dim = 1; // We linearize the blocks across all the problems here
// If host problem shapes are not provided.
if (!problem_shapes.is_host_problem_shape_available()) {
total_ctas = hw_info.sm_count;
}
// If host problem shapes are provided, make a better decision about possibility to launch smaller grid.
else {
for (int group = 0; group < groups; group++) {
auto ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes.get_host_problem_shape(group)), cute::shape<0>(cta_shape)));
auto ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes.get_host_problem_shape(group)), cute::shape<1>(cta_shape)));
auto problem_blocks_m = round_up(ctas_along_m, cute::get<0>(cluster_shape));
auto problem_blocks_n = round_up(ctas_along_n, cute::get<1>(cluster_shape));
total_ctas += problem_blocks_m * problem_blocks_n;
}
}
return Params::get_tiled_cta_shape_mnl(
to_gemm_coord(cluster_shape),
total_ctas, cta_in_N_dim
);
}
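The host-side tile count above ceil-divides each group's M and N extents by the CTA tile, rounds up to cluster multiples, and accumulates across groups; on the device the same ceil-division is done with FastDivmod as divmod.divide(x + divmod.divisor - 1). A standalone sketch of the counting (plain integer math, no swizzle factor):

// Editor's illustration only -- not part of the CUTLASS diff.
#include <cstdint>

inline uint64_t ceil_div(uint64_t a, uint64_t b) { return (a + b - 1) / b; }
inline uint64_t round_up(uint64_t a, uint64_t b) { return ceil_div(a, b) * b; }

struct MNK { uint64_t m, n, k; };

uint64_t total_ctas(MNK const* shapes, int groups,
                    uint64_t cta_m, uint64_t cta_n,
                    uint64_t cluster_m, uint64_t cluster_n) {
  uint64_t total = 0;
  for (int g = 0; g < groups; ++g) {
    uint64_t blocks_m = round_up(ceil_div(shapes[g].m, cta_m), cluster_m);
    uint64_t blocks_n = round_up(ceil_div(shapes[g].n, cta_n), cluster_n);
    total += blocks_m * blocks_n;  // blocks are linearized across all groups
  }
  return total;
}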
CUTLASS_HOST_DEVICE
static bool
can_implement(Arguments const& args) {
@ -156,7 +216,7 @@ public:
// MSVC requires protecting use of CUDA-specific nonstandard syntax,
// like blockIdx and gridDim, with __CUDA_ARCH__.
#if defined(__CUDA_ARCH__)
if (scheduler_params.raster_order_ == RasterOrder::AlongN) {
current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x);
}
else {
@ -165,9 +225,19 @@ public:
total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z);
uint64_t ctas_along_m, ctas_along_n;
if (is_tuple<decltype(cute::shape<0>(params_.problem_shapes_[0]))>::value ||
is_tuple<decltype(cute::shape<1>(params_.problem_shapes_[0]))>::value) {
ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(params_.problem_shapes_[0]), scheduler_params.cta_shape_.m()));
ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(params_.problem_shapes_[0]), scheduler_params.cta_shape_.n()));
}
else {
ctas_along_m = scheduler_params.divmod_cta_shape_m_.divide(cute::shape<0>(params_.problem_shapes_[0]) + scheduler_params.divmod_cta_shape_m_.divisor - 1);
ctas_along_n = scheduler_params.divmod_cta_shape_n_.divide(cute::shape<1>(params_.problem_shapes_[0]) + scheduler_params.divmod_cta_shape_n_.divisor - 1);
}
auto problem_blocks_m = round_up(ctas_along_m, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.m());
auto problem_blocks_n = round_up(ctas_along_n, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.n());
current_group_info_.total_tiles = problem_blocks_m * problem_blocks_n;
#else
CUTLASS_ASSERT(false && "This line should never be reached");
#endif
@ -182,24 +252,22 @@ public:
CUTLASS_DEVICE
WorkTileInfo
get_current_work_for_linear_idx(uint64_t linear_idx) {
if (scheduler_params.pre_processed_problem_shapes && linear_idx >= scheduler_params.blocks_across_problem_) {
return WorkTileInfo::invalid_work_tile();
}
return get_work_idx_m_and_n(linear_idx,
current_group_info_,
scheduler_params.groups_,
scheduler_params.problem_shapes_,
scheduler_params.cta_shape_,
scheduler_params.cluster_shape_,
scheduler_params.divmod_cluster_shape_major_,
scheduler_params.divmod_cluster_shape_minor_,
scheduler_params.divmod_cta_shape_m_,
scheduler_params.divmod_cta_shape_n_,
scheduler_params.log_swizzle_size_,
scheduler_params.raster_order_);
}
CUTLASS_DEVICE
@ -208,34 +276,62 @@ public:
current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
}
// get work_idx_m, work_idx_n from linear_idx while applying swizzle
static CUTLASS_DEVICE
WorkTileInfo
get_work_idx_m_and_n(
uint64_t linear_idx,
struct GroupInfo& group_info,
int32_t total_problem_groups,
ProblemShape* problem_shapes,
GemmCoord cta_shape,
GemmCoord cluster_shape,
FastDivmodU64Pow2 const& divmod_cluster_shape_major,
FastDivmodU64Pow2 const& divmod_cluster_shape_minor,
FastDivmodU64 const& divmod_cta_shape_m,
FastDivmodU64 const& divmod_cta_shape_n,
int32_t log_swizzle_size,
RasterOrder raster_order) {
bool valid_tile = true;
uint64_t ctas_along_m, ctas_along_n;
if (is_tuple<decltype(cute::shape<0>(problem_shapes[group_info.group_idx]))>::value ||
is_tuple<decltype(cute::shape<1>(problem_shapes[group_info.group_idx]))>::value) {
ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group_idx]), cta_shape.m()));
ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group_idx]), cta_shape.n()));
}
else {
ctas_along_m = divmod_cta_shape_m.divide(cute::shape<0>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_m.divisor - 1);
ctas_along_n = divmod_cta_shape_n.divide(cute::shape<1>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_n.divisor - 1);
}
auto problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m());
auto problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n());
group_info.total_tiles = problem_blocks_m * problem_blocks_n;
while (group_info.start_linear_idx + group_info.total_tiles <= linear_idx) {
group_info.group_idx++;
if (group_info.group_idx >= total_problem_groups)
return WorkTileInfo::invalid_work_tile();
while (group_info.start_linear_idx + group_info.total_tiles <= blk_per_grid_dim) {
group_info.group++;
group_info.start_linear_idx += group_info.total_tiles; group_info.start_linear_idx += group_info.total_tiles;
cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group]), cta_shape.m())); if (is_tuple<decltype(cute::shape<0>(problem_shapes[group_info.group_idx]))>::value ||
cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group]), cta_shape.n())); is_tuple<decltype(cute::shape<1>(problem_shapes[group_info.group_idx]))>::value) {
group_info.total_tiles = cta_m * cta_n; ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group_idx]), cta_shape.m()));
ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group_idx]), cta_shape.n()));
}
else {
ctas_along_m = divmod_cta_shape_m.divide(cute::shape<0>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_m.divisor - 1);
ctas_along_n = divmod_cta_shape_n.divide(cute::shape<1>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_n.divisor - 1);
}
problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m());
problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n());
group_info.total_tiles = problem_blocks_m * problem_blocks_n;
} }
uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0; uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0;
divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim - group_info.start_linear_idx); uint64_t blk_per_grid_dim = divmod_cluster_shape_minor.divide(linear_idx - group_info.start_linear_idx);
divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim);
auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
if (raster_order == RasterOrder::AlongN) { if (raster_order == RasterOrder::AlongN) {
@@ -252,8 +348,13 @@ public:
    offset = cluster_id & ((1 << log_swizzle_size) - 1);
    extra = cluster_id >> log_swizzle_size;

-   uint64_t curr_group_cluster_blk_major, remainder;
-   divmod_cluster_shape_major(curr_group_cluster_blk_major, remainder, cta_m);
+   uint64_t curr_group_cluster_blk_major;
+   if (raster_order == RasterOrder::AlongN) {
+     curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(problem_blocks_n);
+   }
+   else {
+     curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(problem_blocks_m);
+   }
    cluster_idx_minor_div_swizzle = extra / curr_group_cluster_blk_major;
    cluster_idx_major = extra % curr_group_cluster_blk_major;
@@ -265,61 +366,14 @@ public:
                             cluster_major_offset);

    if (raster_order == RasterOrder::AlongN) {
-     return {minor_work_idx, major_work_idx, group_info, valid_tile};
+     return {minor_work_idx, major_work_idx, group_info.group_idx, valid_tile};
    }
    else {
-     return {major_work_idx, minor_work_idx, group_info, valid_tile};
+     return {major_work_idx, minor_work_idx, group_info.group_idx, valid_tile};
    }
  }
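The CTA-count arithmetic above combines two small idioms: ceiling division via a precomputed divisor (`divide(x + divisor - 1)`) and padding the result to whole clusters scaled by the swizzle factor (`round_up(ctas, (1 << log_swizzle_size) * cluster)`). A standalone sketch of just that arithmetic, with a plain struct standing in for `FastDivmodU64` (the real type precomputes a magic multiplier so device code avoids a hardware divide; only the interface is mirrored here):

```cpp
#include <cassert>
#include <cstdint>

// Illustrative stand-in for FastDivmodU64; stores the divisor only.
struct SlowDivmodU64 {
  uint64_t divisor;
  explicit SlowDivmodU64(uint64_t d) : divisor(d) {}
  uint64_t divide(uint64_t x) const { return x / divisor; }
};

// Round x up to the next multiple of `multiple`.
uint64_t round_up(uint64_t x, uint64_t multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

int main() {
  // Problem extent 300 along M, CTA tile 128 along M -> ceil(300/128) = 3 CTAs.
  SlowDivmodU64 divmod_cta_shape_m(128);
  uint64_t ctas_along_m =
      divmod_cta_shape_m.divide(300 + divmod_cta_shape_m.divisor - 1);
  assert(ctas_along_m == 3);

  // Pad to whole clusters (cluster M = 2) scaled by the swizzle factor (1 << 1 = 2).
  int log_swizzle_size = 1;
  uint64_t problem_blocks_m = round_up(ctas_along_m, (1u << log_swizzle_size) * 2);
  assert(problem_blocks_m == 4);
  return 0;
}
```

The `is_tuple` branch above keeps using `cute::ceil_div`, presumably because a problem or CTA mode of rank greater than one cannot be flattened into a single integer divide.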
- // Given the inputs, computes the total number of output blocks this problem will compute over
- // Note that this is only the logical size of our grid, not the physical grid we will actually launch.
- template<class BlockShape, class ClusterShape>
- CUTLASS_HOST_DEVICE static
- dim3
- get_tiled_cta_shape_mnl(int groups, ProblemShape const* problem_shapes, BlockShape cta_shape, ClusterShape cluster_shape) {
-   uint32_t total_ctas = 0;
-   uint32_t cta_in_N_dim = 1; // We linearize the blocks across all the problems here
-
-   for (int group = 0; group < groups; group++) {
-     auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group]), cute::shape<0>(cta_shape)));
-     auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group]), cute::shape<1>(cta_shape)));
-     total_ctas += cta_m * cta_n;
-   }
-
-   return Params::get_tiled_cta_shape_mnl(
-     to_gemm_coord(cluster_shape),
-     total_ctas, cta_in_N_dim
-   );
- }
-
- // Given the inputs, computes the physical grid we should launch.
- template<class BlockShape, class ClusterShape>
- CUTLASS_HOST_DEVICE static
- dim3
- get_grid_shape(
-     GroupProblemShape problem_shapes,
-     BlockShape cta_shape,
-     ClusterShape cluster_shape,
-     KernelHardwareInfo hw_info,
-     Arguments arguments,
-     bool truncate_by_problem_size=true) {
-
-   dim3 problem_blocks = get_tiled_cta_shape_mnl(
-     problem_shapes.groups(),
-     reinterpret_cast<ProblemShape const*>(problem_shapes.host_problem_shapes),
-     cta_shape, cluster_shape);
-
-   return Params::get_grid_shape(
-     problem_blocks,
-     to_gemm_coord(cluster_shape),
-     hw_info,
-     arguments.max_swizzle_size,
-     arguments.raster_order,
-     /* truncate_by_problem_size = */true
-   );
- }

  // Returns whether the block assigned this work should compute the epilogue for the corresponding
  // output tile. For the basic tile scheduler, this is always true.
  CUTLASS_HOST_DEVICE
@@ -1273,15 +1273,18 @@ struct PersistentTileSchedulerSm90GroupParams {
  FastDivmodU64Pow2 divmod_cluster_shape_major_{};
  FastDivmodU64Pow2 divmod_cluster_shape_minor_{};
- FastDivmodU64 divmod_batch_{};
+ FastDivmodU64 divmod_cta_shape_m_{};
+ FastDivmodU64 divmod_cta_shape_n_{};

- uint64_t blocks_per_problem_ = 0;
+ uint64_t blocks_across_problem_ = 0;
+ bool pre_processed_problem_shapes = true;
  int32_t log_swizzle_size_ = 0;
  RasterOrder raster_order_ = RasterOrder::AlongN;

  int32_t groups_ = 0;
  ProblemShape* problem_shapes_ = nullptr;
  GemmCoord cta_shape_;
+ GemmCoord cluster_shape_;

  // Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions.
  // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1,
@@ -1291,6 +1294,7 @@ struct PersistentTileSchedulerSm90GroupParams {
    dim3 problem_blocks,
    int32_t groups,
    ProblemShape* problem_shapes,
+   ProblemShape const* host_problem_shapes,
    GemmCoord cta_shape,
    GemmCoord cluster_shape,
    KernelHardwareInfo const& hw_info,
@@ -1317,11 +1321,12 @@ struct PersistentTileSchedulerSm90GroupParams {
    groups_ = groups;
    problem_shapes_ = problem_shapes;
    cta_shape_ = cta_shape;
+   cluster_shape_ = cluster_shape;
-   blocks_per_problem_ = problem_blocks_m * problem_blocks_n * problem_blocks.z;
+   blocks_across_problem_ = problem_blocks.x * problem_blocks.y * problem_blocks.z;
+   pre_processed_problem_shapes = (host_problem_shapes == nullptr) ? false : true;
    log_swizzle_size_ = log_swizzle_size;
    raster_order_ = raster_order;
-   divmod_batch_ = FastDivmodU64(problem_blocks_m * problem_blocks_n);

    if (raster_order == RasterOrder::AlongN) {
      divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.n());
@@ -1331,6 +1336,9 @@ struct PersistentTileSchedulerSm90GroupParams {
      divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.m());
      divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.n());
    }
+   divmod_cta_shape_m_ = FastDivmodU64(cta_shape_.m());
+   divmod_cta_shape_n_ = FastDivmodU64(cta_shape_.n());
  }

  // Version of get_tiled_cta_shape_mnl that takes in as input the number of CTAs in the M and N dimensions.
@@ -1344,8 +1352,8 @@ struct PersistentTileSchedulerSm90GroupParams {
    auto problem_blocks_n = ((cta_n + cluster_shape.n() - 1) / cluster_shape.n()) * cluster_shape.n();

    return {
-     static_cast<uint32_t>(problem_blocks_m),
-     static_cast<uint32_t>(problem_blocks_n),
+     static_cast<uint32_t>(cta_m),
+     static_cast<uint32_t>(cta_n),
      static_cast<uint32_t>(1) // Only a single batch per group is currently supported
    };
  }
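As a side note, the host-side setup above reduces to two decisions: linearize every group's tiles into a single range (`blocks_across_problem_`), and pick which cluster dimension is treated as "major" from the rasterization order. A compact sketch with simplified, hypothetical stand-in types (not the actual `PersistentTileSchedulerSm90GroupParams`, which stores `FastDivmodU64Pow2` objects instead of plain integers):

```cpp
#include <cstdint>
#include <cstdio>

enum class RasterOrder { AlongM, AlongN };

struct Dim3 { uint32_t x, y, z; };

struct ToyGroupParams {
  uint64_t blocks_across_problem = 0;
  uint64_t cluster_major = 1, cluster_minor = 1;  // stand-ins for the Pow2 divmods

  void initialize(Dim3 problem_blocks, uint32_t cluster_m, uint32_t cluster_n,
                  RasterOrder raster_order) {
    // All tiles of all groups live in one linearized 1-D range.
    blocks_across_problem =
        uint64_t(problem_blocks.x) * problem_blocks.y * problem_blocks.z;
    // The "major" cluster dimension follows the rasterization direction.
    if (raster_order == RasterOrder::AlongN) {
      cluster_major = cluster_n;
      cluster_minor = cluster_m;
    } else {
      cluster_major = cluster_m;
      cluster_minor = cluster_n;
    }
  }
};

int main() {
  ToyGroupParams p;
  p.initialize({64, 48, 1}, /*cluster_m=*/2, /*cluster_n=*/1, RasterOrder::AlongN);
  std::printf("%llu tiles, cluster major=%llu minor=%llu\n",
              (unsigned long long)p.blocks_across_problem,
              (unsigned long long)p.cluster_major,
              (unsigned long long)p.cluster_minor);
  return 0;
}
```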
include/cutlass/version.h (new file)
@@ -0,0 +1,80 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cstdint>
#include <string>
#define CUTLASS_MAJOR 3
#define CUTLASS_MINOR 4
#define CUTLASS_PATCH 1
#ifdef CUTLASS_VERSIONS_GENERATED
#include "cutlass/version_extended.h"
#else
#define CUTLASS_BUILD 0
#define CUTLASS_REVISION ""
#endif
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
namespace cutlass {
inline constexpr uint32_t getVersion() {
return CUTLASS_VERSION;
}
inline constexpr uint32_t getVersionMajor() {
return CUTLASS_MAJOR;
}
inline constexpr uint32_t getVersionMinor() {
return CUTLASS_MINOR;
}
inline constexpr uint32_t getVersionPatch() {
return CUTLASS_PATCH;
}
inline constexpr uint32_t getVersionBuild() {
return CUTLASS_BUILD + 0;
}
inline std::string getVersionString() {
std::string version = "@CUTLASS_VERSION@";
if (getVersionBuild()) {
version += "." + std::to_string(getVersionBuild());
}
return version;
}
inline std::string getGitRevision() {
return "@CUTLASS_REVISION@";
}
} // namespace cutlass
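A small usage sketch for this header: `CUTLASS_VERSION` encodes 3.4.1 as 341 under the formula above, so API differences between releases can be handled at compile time, and the constexpr accessors expose the same information to runtime code. The snippet assumes the CUTLASS include directory is on the include path; only macros and functions defined in this file are used.

```cpp
#include <iostream>

#include "cutlass/version.h"

int main() {
  // Handle API changes between CUTLASS releases at compile time.
#if CUTLASS_VERSION >= 341
  std::cout << "Built against CUTLASS "
            << cutlass::getVersionMajor() << "."
            << cutlass::getVersionMinor() << "."
            << cutlass::getVersionPatch();
  if (cutlass::getVersionBuild()) {
    std::cout << " (build " << cutlass::getVersionBuild() << ")";
  }
  std::cout << "\n";
#else
  std::cout << "Built against a pre-3.4.1 CUTLASS\n";
#endif
  return 0;
}
```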
@@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"
[project]
name = "nvidia-cutlass"
-version = "3.4.0.0"
+version = "3.4.1.0"
description = "CUTLASS"
readme = "README.md"
requires-python = ">=3.8"
-license = {file = "LICENSE.txt"}
+license = {text = "BSD-3-Clause"}
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: BSD License",
@@ -40,7 +40,7 @@ import cutlass_library
def _cuda_install_path_from_nvcc() -> str:
    import subprocess
    # Attempt to detect CUDA_INSTALL_PATH based on location of NVCC
-   result = subprocess.run(['which', 'nvcc'], capture_output=True)
+   result = subprocess.run(['/usr/bin/which', 'nvcc'], capture_output=True)
    if result.returncode != 0:
        raise Exception(f'Unable to find nvcc via `which` utility.')
@@ -121,7 +121,7 @@ def get_option_registry():
        this._option_registry = OptionRegistry(device_cc())
    return this._option_registry

-this.__version__ = '3.4.0'
+this.__version__ = '3.4.1'

from cutlass.backend import create_memory_pool
from cutlass.emit.pytorch import pytorch
@@ -169,7 +169,7 @@ def initialize_cuda_context():
        raise Exception("No CUDA devices found")
    device_id = 0

-   this._device_id = device_id
+   this._device_id = int(device_id)

def device_id() -> int:
@@ -213,8 +213,12 @@ def get_mainloop_arguments_3x(
    return _MainloopArgumentsTma

-def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args):
-    _EpilogueOutputOpParams = epilogue_functor.epilogue_type
+def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, default_epilogue):
+    if not default_epilogue and hasattr(epilogue_functor, "epilogue_type_evt"):
+        _EpilogueOutputOpParams = epilogue_functor.epilogue_type_evt
+    else:
+        _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    if hasattr(epilogue_functor, "visitor"):
        class _EpilogueArguments(ctypes.Structure):
            _fields_ = [
@@ -157,19 +157,41 @@ class LinearCombination(EpilogueFunctorBase):
        c_element_epilogue = dtype2ctype[self.element_epilogue]
        element_epilogue = self.element_epilogue

-       class _EpilogueOutputOpParams(ctypes.Structure):
+       class _EpilogueOutputOpParamsEVT(ctypes.Structure):
+           """
+           Epilogue params when using the default linear combination of EVT, which
+           does not currently use {alpha,beta}_ptr_array
+           """
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
-               ("beta_ptr", ctypes.c_void_p)
+               ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = to_ctype_value(alpha, element_epilogue)
                self.beta = to_ctype_value(beta, element_epilogue)

+       class _EpilogueOutputOpParams(ctypes.Structure):
+           _fields_ = [
+               ("alpha", c_element_epilogue),
+               ("beta", c_element_epilogue),
+               ("alpha_ptr", ctypes.c_void_p),
+               ("beta_ptr", ctypes.c_void_p),
+               ("alpha_ptr_array", ctypes.c_void_p),
+               ("beta_ptr_array", ctypes.c_void_p),
+           ]
+
+           def __init__(self, alpha, beta, *args) -> None:
+               self.alpha = to_ctype_value(alpha, element_epilogue)
+               self.beta = to_ctype_value(beta, element_epilogue)
+
+           def to_evt_params(self) -> _EpilogueOutputOpParamsEVT:
+               return _EpilogueOutputOpParamsEVT(self.alpha, self.beta)

        self.epilogue_type = _EpilogueOutputOpParams
+       self.epilogue_type_evt = _EpilogueOutputOpParamsEVT

    def emit(self):
        return super().emit(self.tag, self.template_arguments)
@@ -241,10 +241,10 @@ class EVTFrontendBase:
        :param name: the name of the graph
        """
        drawer = EVTGraphDrawer(self.dag_ir, name)
-       if drawer.dot_available:
+       try:
            for name, graph in drawer.get_dot_graph():
                graph.write_svg(f"./{name}.svg")
-       else:
+       except:
            raise RuntimeError(
                "'dot' is not found in path. GraphDrawer is disabled. "
                "Please install it with 'sudo apt-get install graphviz'."
@@ -61,22 +61,6 @@ class EVTGraphDrawer:
            self._dot_graphs = {}
            self._dot_graphs[name] = self._to_dot(graph, name)

-       self.dot_available = self._check_dot_availability()
-
-   def _check_dot_availability(self):
-       """
-       Check if graphviz is installed
-       """
-       try:
-           # Run the 'dot' command and capture its output
-           result = subprocess.run(
-               ["dot", "-V"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
-           # Check if the command was successful and the output contains version information
-           if result.returncode == 0 and "dot - graphviz" in result.stderr:
-               return True
-       except FileNotFoundError:
-           pass
-       return False

    def _get_node_style(self, node):
        template = {
@@ -325,7 +325,7 @@ class GemmArguments2x(ArgumentBase):
    def initialize(self):
        launch_config = self.operation.rt_module.plan(self)

-       # Get the host and evice workspace
+       # Get the host and device workspace
        device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)

        if device_workspace_size > 0:
@@ -512,6 +512,18 @@ class GemmArguments3x(GemmArguments2x):
        super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)

    def get_arguments(self):
+       mainloop_args = get_mainloop_arguments_3x(
+           self.operation.tile_description.kernel_schedule,
+           self.operation.A.element,
+           self.operation.B.element,
+           self.operation.A.alignment,
+           self.operation.B.alignment
+       )
+       scheduler_args = get_tile_scheduler_arguments_3x(self.operation.tile_description.tile_scheduler)
+       uses_default_epilogue = self.operation.rt_module.uses_default_epilogue()
+       argument_type, epilogue_args, epilogue_type, hw_info = get_gemm_arguments_3x(
+           mainloop_args, self.operation.epilogue_functor, scheduler_args, uses_default_epilogue)
+
        problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count)

        if self.batch_count > 1:
@@ -539,9 +551,12 @@ class GemmArguments3x(GemmArguments2x):
        )

        # Set of mainloop arguments needed for this kernel
-       mainloop = self.operation.rt_module.mainloop_args.from_generic_mainloop_args(generic_args)
+       mainloop = mainloop_args.from_generic_mainloop_args(generic_args)

-       epilogue = self.operation.rt_module.epilogue_args(
+       if not uses_default_epilogue and hasattr(self.output_op, "to_evt_params"):
+           self.output_op = self.output_op.to_evt_params()
+
+       epilogue = epilogue_args(
            self.output_op,
            int(self.ptr_C),
            stride_C,
@@ -550,15 +565,15 @@ class GemmArguments3x(GemmArguments2x):
        )

        # Set hardware info
-       hw_info = self.operation.rt_module.hw_info(0, device_sm_count())
+       hw_info_ = hw_info(0, device_sm_count())

-       self.arguments = self.operation.argument_type(
+       self.arguments = argument_type(
            int(self.gemm_mode),
            problem_size_,
            mainloop,
            epilogue,
-           hw_info,
-           self.operation.rt_module.scheduler_args
+           hw_info_,
+           scheduler_args
        )
        return self.arguments
@@ -1119,6 +1134,10 @@ extern "C" {
    using GemmType = ${operation_name}_base;

+   bool ${operation_name}_uses_default_epilogue() {
+       return std::is_same_v<GemmType::CollectiveEpilogue::DispatchPolicy, cutlass::gemm::EpilogueDefault>;
+   }
+
    // Get the workspace size
    uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) {
        return GemmType::get_workspace_size(*argument);
@@ -1163,19 +1182,10 @@ extern "C" {
            "get_grid_shape": dim3_,
            "get_block_shape": dim3_,
            "get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64,
-           "get_kernel_workspace_size": ctypes.c_uint64
+           "get_kernel_workspace_size": ctypes.c_uint64,
+           "uses_default_epilogue": ctypes.c_bool,
        }
        self.emitter = EmitGemmUniversalInstance3x("_type")

-       self.mainloop_args = get_mainloop_arguments_3x(
-           operation.tile_description.kernel_schedule,
-           operation.A.element,
-           operation.B.element,
-           operation.A.alignment,
-           operation.B.alignment
-       )
-       self.scheduler_args = get_tile_scheduler_arguments_3x(operation.tile_description.tile_scheduler)
-       self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x(
-           self.mainloop_args, operation.epilogue_functor, self.scheduler_args)

    def get_device_workspace_size(self, arguments: GemmArguments3x):
        return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()))
@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
    setup(
        name='cutlass_library',
-       version='3.4.0',
+       version='3.4.1',
        description='CUTLASS library generation scripts',
        packages=['cutlass_library']
    )
@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
    setup(
        name='pycute',
-       version='3.4.0',
+       version='3.4.1',
        description='Python implementation of CuTe',
        packages=['pycute'],
    )
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
@@ -56,9 +56,6 @@
#include "gemm_testbed_3x_evt.hpp"
#include "sm90_evt_operations.hpp"

-#define CUTLASS_ARCH_MMA_SM90_SUPPORTED
-
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

using namespace cute;
@@ -132,7 +129,7 @@ bool testEVTAuxStoreWithoutD() {
    D_block.reset(m * n);
    aux_store_D_block.reset(m * n);

    Gemm gemm_op_base;

    auto stride_A = cutlass::make_cute_packed_stride(
      typename GemmKernel::StrideA{}, cute::make_shape(m, k, cute::Int<1>{}));
    auto stride_B = cutlass::make_cute_packed_stride(
@@ -141,7 +138,7 @@ bool testEVTAuxStoreWithoutD() {
      typename GemmKernel::StrideC{}, cute::make_shape(m, n, cute::Int<1>{}));
    auto stride_D = cutlass::make_cute_packed_stride(
      typename GemmKernel::StrideD{}, cute::make_shape(m, n, cute::Int<1>{}));

    auto arguments_base = typename Gemm::Arguments {
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size,
@@ -178,12 +175,12 @@ bool testEVTAuxStoreWithoutD() {
      /*hw_info=*/{},
      /*scheduler_args=*/{}
    };

    constexpr float beta [[maybe_unused]] = 1.0;
    constexpr float alpha [[maybe_unused]] = 1.0;

    using ElementC = typename GemmWithoutD::ElementC;

    if constexpr (not has_c) {
      arguments_base.epilogue.thread = {
        // binary op : alpha * acc
@@ -282,7 +279,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
  using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
    EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
  >;

  using namespace cutlass::epilogue::fusion;
  constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
@@ -324,10 +321,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
    return *(GemmKernel *)(nullptr);
  };

  using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
  using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@@ -352,7 +349,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
  using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
    EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t
  >;

  using namespace cutlass::epilogue::fusion;
  constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
@@ -394,10 +391,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
    return *(GemmKernel *)(nullptr);
  };

  using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
  using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@@ -467,7 +464,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 12
  using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
  using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@@ -492,7 +489,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
  using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
    EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
  >;

  using namespace cutlass::epilogue::fusion;
  constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
@@ -534,10 +531,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
    return *(GemmKernel *)(nullptr);
  };

  using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
  using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@@ -562,7 +559,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
  using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
    EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t
  >;

  using namespace cutlass::epilogue::fusion;
  constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
@@ -604,10 +601,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
    return *(GemmKernel *)(nullptr);
  };

  using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
  using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@@ -677,7 +674,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 12
  using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
  using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@@ -35,7 +35,6 @@
#pragma once

#include "cute/layout.hpp"
-#include "cute/stride.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////