Updates for CUTLASS 3.4.1 (#1346)
* Updates for CUTLASS 3.4.1
* minor epi change

parent 47a3ebbea9
commit bbe579a9e3
@@ -1,5 +1,11 @@
# NVIDIA CUTLASS Changelog

## [3.4](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12)

## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14)

- Statically available [CUTLASS Version macros](/include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side.
- Improvements for Hopper [Group-GEMMs](/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm).
- Updates and bugfixes from the community (thanks!).
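The version macros called out above are plain preprocessor constants, so user code can branch on them at build time. A minimal sketch of the intended use, assuming the macro names and the `((MAJOR)*100 + (MINOR)*10 + PATCH)` encoding from the `version.h` template shown further down in this diff:

```cpp
#include <cstdio>
#include "cutlass/version.h"  // provides CUTLASS_MAJOR, CUTLASS_MINOR, CUTLASS_PATCH, CUTLASS_VERSION

// Guard an API difference between CUTLASS releases at compile time.
// With the encoding above, 3.4.1 maps to 341.
#if defined(CUTLASS_VERSION) && (CUTLASS_VERSION >= 341)
  // use the 3.4.1+ spelling of the API here
#else
  // fall back to the pre-3.4.1 spelling here
#endif

int main() {
  std::printf("Built against CUTLASS %d.%d.%d\n",
              CUTLASS_MAJOR, CUTLASS_MINOR, CUTLASS_PATCH);
  return 0;
}
```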

## [3.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12)

* Expanded [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
* Performance improvements to [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm)
* Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).

@@ -8,7 +14,6 @@
* NamedBarriers usability improvement and list of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) has been officially released.
* Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved.

## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31)

* [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
* [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8}, and upcast on operandA {s8, u8} x {fp16, bf16}.
@@ -40,7 +40,25 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")

project(CUTLASS VERSION 3.4.0 LANGUAGES CXX)
# To reduce duplicate version locations, parse the version out of the
# main versions.h file and reuse it here.

file(READ ${CMAKE_CURRENT_SOURCE_DIR}/include/cutlass/version.h VERSION_FILE_CONTENTS)
string(REGEX MATCH "#define CUTLASS_MAJOR ([0-9]+)" _CUTLASS_VERSION_MAJOR "${VERSION_FILE_CONTENTS}")
set(_CUTLASS_VERSION_MAJOR ${CMAKE_MATCH_1})
string(REGEX MATCH "#define CUTLASS_MINOR ([0-9]+)" _CUTLASS_VERSION_MINOR "${VERSION_FILE_CONTENTS}")
set(_CUTLASS_VERSION_MINOR ${CMAKE_MATCH_1})
string(REGEX MATCH "#define CUTLASS_PATCH ([0-9]+)" _CUTLASS_VERSION_PATCH "${VERSION_FILE_CONTENTS}")
set(_CUTLASS_VERSION_PATCH ${CMAKE_MATCH_1})

message(STATUS "CUTLASS ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_CUTLASS_VERSION_PATCH}")

## CUTLASS PROJECT #############################################################

project(CUTLASS VERSION ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_CUTLASS_VERSION_PATCH} LANGUAGES CXX)

################################################################################

include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)

if (CUDA_VERSION VERSION_LESS 11.3)

@@ -178,6 +196,9 @@ if(WIN32)
set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE)
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUTLASS_VERSIONS_GENERATED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DCUTLASS_VERSIONS_GENERATED")

if (WIN32)
  # Enable more warnings. Add "-Xcompiler=/WX" to enable warnings as errors.
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3)

@@ -589,8 +610,8 @@ if (NOT DEFINED CUTLASS_REVISION)
endif()

configure_file(
  ${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in
  ${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h
  ${CMAKE_CURRENT_SOURCE_DIR}/cmake/version_extended.h.in
  ${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version_extended.h
  @ONLY)

target_include_directories(
@@ -2,7 +2,8 @@

## 2023

- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi and Jay Shah. _arXiv_, December 2023.
- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi, Jay Shah. _arXiv_, December 2023.

- ["A Speed Odyssey for Deployable Quantization of LLMs"](https://arxiv.org/abs/2311.09550). Qingyuan Li, Ran Meng, Yiduo Li, Bo Zhang, Liang Li, Yifan Lu, Xiangxiang Chu, Yerui Sun, Yuchen Xie. _arXiv_, November 2023.
README.md

@@ -2,7 +2,7 @@

# CUTLASS 3.4

_CUTLASS 3.4 - January 2024_
_CUTLASS 3.4 - February 2024_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels

@@ -43,13 +43,18 @@ In addition to GEMMs, CUTLASS implements high-performance convolution via the im

# What's New in CUTLASS 3.4

CUTLASS 3.4.1 is an update to CUTLASS adding:
- Statically available [CUTLASS Version macros](/include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side.
- Improvements for Hopper [Group-GEMM](/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMM](/examples/56_hopper_ptr_array_batched_gemm).
- Updates and bugfixes from the community (thanks!).

CUTLASS 3.4.0 is an update to CUTLASS adding:

- Improved [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) supporting {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors tuned for optimal performance on Hopper H100.
- Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) utilizing TMA and Hopper H100 tensor cores now available. (Requires CUDA 12.3 or above)
- Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) - commonly used in optimization of Mixture-Of-Expert models, is now available on Hopper GPUs taking advantage of TMA and Hopper H100 tensor cores. (Requires CUDA 12.3 or above)
- [Ampere Sparse GEMM](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now.
- Impovements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library.
- Improvements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library.
- Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved.
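For the Group-GEMM and Pointer-Array items above, the shape of the host-side setup is easiest to see in examples 56/57, whose diffs appear later in this commit. The fragment below is a trimmed, hedged illustration rather than a complete program: `Gemm` is the collective-builder kernel type, and `num_groups`, `problem_sizes`, the `ptr_*`/`stride_*` arrays, and `epilogue_params` stand in for the per-group device allocations the example prepares.

```cpp
// Describe a grouped problem: per-group <M,N,K> extents plus per-group
// pointer/stride arrays, launched in kGrouped mode (see example 57 below).
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
    cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

typename Gemm::Arguments arguments{
  cutlass::gemm::GemmUniversalMode::kGrouped,
  {num_groups, problem_sizes.get(), problem_sizes_host.data()},  // host shapes may be nullptr
  {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
  {epilogue_params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
  hw_info
};

Gemm gemm;
size_t workspace_size = Gemm::get_workspace_size(arguments);
```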

Minimum requirements:

@@ -93,8 +98,8 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA

# Compatibility

CUTLASS requires a C++17 host compiler and
performs best when built with the [**CUDA 12.2.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit-archive).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2 and CUDA 12.3.1
performs best when built with the [**CUDA 12.3.2 Toolkit**](https://developer.nvidia.com/cuda-downloads).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2, CUDA 12.3.1 and CUDA 12.3.2.

## Operating Systems
We have tested the following environments.
cmake/version.h.in (deleted)

@@ -1,38 +0,0 @@
#include <cstdint>
#include <string>

#define CUTLASS_MAJOR @CUTLASS_VERSION_MAJOR@
#define CUTLASS_MINOR @CUTLASS_VERSION_MINOR@
#define CUTLASS_PATCH @CUTLASS_VERSION_PATCH@
#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)

namespace cutlass {

inline uint32_t getVersion() {
  return CUTLASS_VERSION;
}
inline uint32_t getVersionMajor() {
  return CUTLASS_MAJOR;
}
inline uint32_t getVersionMinor() {
  return CUTLASS_MINOR;
}
inline uint32_t getVersionPatch() {
  return CUTLASS_PATCH;
}
inline uint32_t getVersionBuild() {
  return CUTLASS_BUILD + 0;
}
inline std::string getVersionString() {
  std::string version = "@CUTLASS_VERSION@";
  if (getVersionBuild()) {
    version += "." + std::to_string(getVersionBuild());
  }
  return version;
}
inline std::string getGitRevision() {
  return "@CUTLASS_REVISION@";
}

} // namespace cutlass
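This template goes away because the version macros are now statically available in `include/cutlass/version.h`, which the CMake change above parses instead of generating. Assuming the helper functions keep the names shown in the deleted template, a runtime query is a one-liner; this is a hedged sketch, not a guaranteed API listing:

```cpp
#include <iostream>
#include "cutlass/version.h"  // now a static header; CMake parses CUTLASS_MAJOR/MINOR/PATCH out of it

int main() {
  std::cout << "CUTLASS " << cutlass::getVersionString()
            << " (encoded value " << cutlass::getVersion() << ")\n"
            << "major=" << cutlass::getVersionMajor()
            << " minor=" << cutlass::getVersionMinor()
            << " patch=" << cutlass::getVersionPatch() << std::endl;
  return 0;
}
```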
cmake/version_extended.h.in (new file, 34 lines)

@@ -0,0 +1,34 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@
#define CUTLASS_REVISION "@CUTLASS_REVISION@"
@@ -31,4 +31,5 @@

cutlass_example_add_executable(
  02_dump_reg_shmem
  dump_reg_shmem.cu
  DISABLE_TESTS ON
)
@@ -70,7 +70,7 @@

using namespace cute;

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations

@@ -98,8 +98,8 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::KernelArrayTmaWarpSpecializedCooperative; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedArray; // Epilogue to launch
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
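The rename visible in this hunk (KernelArrayTmaWarpSpecializedCooperative becomes KernelPtrArrayTmaWarpSpecializedCooperative, NoSmemWarpSpecializedArray becomes PtrArrayNoSmemWarpSpecialized) is exactly the kind of release-to-release API change the new version macros are meant to absorb. A hedged sketch of keeping downstream code buildable against both 3.4.0 and 3.4.1 (include paths are assumptions; per this diff, only the schedule names changed):

```cpp
#include "cutlass/version.h"
#include "cutlass/gemm/dispatch_policy.hpp"      // kernel schedule tags (path assumed)
#include "cutlass/epilogue/dispatch_policy.hpp"  // epilogue schedule tags (path assumed)

// Select the pointer-array schedule names that exist in the CUTLASS being built against.
#if defined(CUTLASS_VERSION) && (CUTLASS_VERSION >= 341)
using KernelSchedule   = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;
#else
using KernelSchedule   = cutlass::gemm::KernelArrayTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedArray;
#endif
```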
@@ -169,7 +169,7 @@ cutlass::DeviceAllocation<const typename Gemm::ElementC *> ptr_C;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_D;
cutlass::DeviceAllocation<typename Gemm::EpilogueOutputOp::ElementOutput *> ptr_ref_D;

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types

@@ -245,7 +245,7 @@ struct Result
bool passed = false;
};

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM setup and evaluation

@@ -468,7 +468,7 @@ int run(Options &options)
return 0;
}

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

///////////////////////////////////////////////////////////////////////////////////////////////////

@@ -510,7 +510,7 @@ int main(int argc, char const **args) {
// Evaluate CUTLASS kernels
//

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
run<Gemm>(options);
#endif
@@ -27,17 +27,17 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
# Only the correctness check will be run by these commands.
set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=1) # Square problem sizes
set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=1) # Square problem sizes

set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=0) # Square problem sizes
set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=0) # Square problem sizes

set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=1) # Default problem sizes
set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=1) # Default problem sizes

set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Default problem sizes
set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=0) # Default problem sizes

set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Default problem sizes w/ Epilogue Op test
set(TEST_EPILOGUE_OP_LARGE_BATCH --alpha=1.5 -l=500 --iterations=1) # Default problem sizes w/ Epilogue Op test

set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=0) # Small-k problem sizes
set(TEST_SMALLK_LARGE_BATCH --m=1024 --n=512 --k=64 --l=500 --iterations=0) # Small-k problem sizes

set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=1) # Small-k problem sizes
set(TEST_SMALLK_LARGE_BATCH --m=1024 --n=512 --k=64 --l=500 --iterations=1) # Small-k problem sizes

cutlass_example_add_executable(
  56_hopper_ptr_array_batched_gemm

@@ -47,6 +47,8 @@ cutlass_example_add_executable(
  TEST_SQUARE_LARGE_BATCH
  TEST_EPILOGUE
  TEST_EPILOGUE_LARGE_BATCH
  TEST_EPILOGUE_OP
  TEST_EPILOGUE_OP_LARGE_BATCH
  TEST_SMALLK
  TEST_SMALLK_LARGE_BATCH
)
@ -44,6 +44,7 @@
|
||||
|
||||
The above example command makes all 10 groups to be sized at the given m, n, k sizes.
|
||||
Skipping any of the problem dimensions randomizes it across the different groups.
|
||||
Same applies for alpha and beta values that are randomized across the different groups.
|
||||
|
||||
To run this example for a set of problems using the benchmark option:
|
||||
|
||||
@ -62,6 +63,7 @@
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <float.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
@ -91,9 +93,9 @@ using namespace cute;
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int,int,int>>; // <M,N,K> per group
|
||||
using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand
|
||||
using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand
|
||||
using ElementC = float; // Element type for C and D matrix operands
|
||||
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
@ -101,40 +103,40 @@ using ElementC = float; // Element type
|
||||
|
||||
// A matrix configuration
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// B matrix configuration
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// C/D matrix configuration
|
||||
using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes)
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size
|
||||
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
using KernelSchedule = cutlass::gemm::KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedGroup; // Epilogue to launch
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC, AlignmentC,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
ElementC, LayoutC *, AlignmentC,
|
||||
EpilogueSchedule
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
ElementA, LayoutA, AlignmentA,
|
||||
ElementB, LayoutB, AlignmentB,
|
||||
ElementA, LayoutA *, AlignmentA,
|
||||
ElementB, LayoutB *, AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
@ -161,10 +163,10 @@ using DeviceGemmReference = cutlass::reference::device::Gemm<
|
||||
ElementAccumulator,
|
||||
ElementAccumulator>;
|
||||
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
|
||||
|
||||
// Host-side allocations
|
||||
std::vector<int64_t> offset_A;
|
||||
@ -177,6 +179,9 @@ std::vector<StrideB> stride_B_host;
|
||||
std::vector<StrideC> stride_C_host;
|
||||
std::vector<StrideD> stride_D_host;
|
||||
|
||||
std::vector<ElementAccumulator> alpha_host;
|
||||
std::vector<ElementAccumulator> beta_host;
|
||||
|
||||
// Device-side allocations
|
||||
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape> problem_sizes;
|
||||
|
||||
@ -197,7 +202,13 @@ cutlass::DeviceAllocation<StrideB> stride_B;
|
||||
cutlass::DeviceAllocation<StrideC> stride_C;
|
||||
cutlass::DeviceAllocation<StrideD> stride_D;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
// Note, this is an array of pointers to alpha and beta scaling values per group
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> alpha_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator*> beta_device;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_alpha;
|
||||
cutlass::DeviceAllocation<ElementAccumulator> block_beta;
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Testbed utility types
|
||||
@ -208,8 +219,8 @@ struct Options {
|
||||
|
||||
bool help = false;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
float alpha = FLT_MAX;
|
||||
float beta = FLT_MAX;
|
||||
int iterations = 10;
|
||||
int m = 1024, n = 2048, k = 512, groups = 10;
|
||||
std::string benchmark_path;
|
||||
@ -230,8 +241,8 @@ struct Options {
|
||||
cmd.get_cmd_line_argument("n", n);
|
||||
cmd.get_cmd_line_argument("k", k);
|
||||
cmd.get_cmd_line_argument("groups", groups);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
|
||||
cmd.get_cmd_line_argument("beta", beta, 0.f);
|
||||
cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX);
|
||||
cmd.get_cmd_line_argument("beta", beta, FLT_MAX);
|
||||
cmd.get_cmd_line_argument("iterations", iterations);
|
||||
cmd.get_cmd_line_argument("benchmark", benchmark_path);
|
||||
|
||||
@ -248,10 +259,7 @@ struct Options {
|
||||
}
|
||||
|
||||
void randomize_problems(cutlass::CommandLine &cmd) {
|
||||
int cmd_line_m = -1;
|
||||
int cmd_line_n = -1;
|
||||
int cmd_line_k = -1;
|
||||
|
||||
int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1;
|
||||
cmd.get_cmd_line_argument("m", cmd_line_m);
|
||||
cmd.get_cmd_line_argument("n", cmd_line_n);
|
||||
cmd.get_cmd_line_argument("k", cmd_line_k);
|
||||
@ -259,19 +267,15 @@ struct Options {
|
||||
problem_sizes_host.reserve(groups);
|
||||
|
||||
for (int i = groups; i > 0; i--) {
|
||||
|
||||
int m = cmd_line_m;
|
||||
int n = cmd_line_n;
|
||||
int k = cmd_line_k;
|
||||
|
||||
if (m < 1) {
|
||||
m = ((rand() % 512) + 1);
|
||||
}
|
||||
|
||||
if (n < 1) {
|
||||
n = ((rand() % 512) + 1);
|
||||
}
|
||||
|
||||
if (k < 1) {
|
||||
k = alignment * ((rand() % 64) + 1);
|
||||
}
|
||||
@ -317,6 +321,7 @@ struct Options {
|
||||
problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()});
|
||||
}
|
||||
}
|
||||
groups = static_cast<int>(problem_sizes_host.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -351,7 +356,9 @@ struct Options {
|
||||
uint64_t fmas = uint64_t();
|
||||
|
||||
for (auto const & problem : problem_sizes_host) {
|
||||
fmas += cute::size(problem);
|
||||
fmas += static_cast<uint64_t>(get<0>(problem)) *
|
||||
static_cast<uint64_t>(get<1>(problem)) *
|
||||
static_cast<uint64_t>(get<2>(problem));
|
||||
}
|
||||
// Two flops per multiply-add
|
||||
uint64_t flop = uint64_t(2) * uint64_t(fmas);
|
||||
@ -370,7 +377,7 @@ struct Result
|
||||
bool passed = false;
|
||||
};
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM setup and evaluation
|
||||
@ -435,6 +442,7 @@ void allocate(const Options &options) {
|
||||
stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{})));
|
||||
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{})));
|
||||
stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{})));
|
||||
|
||||
}
|
||||
|
||||
block_A.reset(total_elements_A);
|
||||
@ -442,6 +450,8 @@ void allocate(const Options &options) {
|
||||
block_C.reset(total_elements_C);
|
||||
block_D.reset(total_elements_D);
|
||||
block_ref_D.reset(total_elements_D);
|
||||
block_alpha.reset(options.groups);
|
||||
block_beta.reset(options.groups);
|
||||
}
|
||||
|
||||
/// Initialize operands to be used in the GEMM and reference GEMM
|
||||
@ -460,12 +470,18 @@ void initialize(const Options &options) {
|
||||
std::vector<ElementB *> ptr_B_host(options.groups);
|
||||
std::vector<ElementC *> ptr_C_host(options.groups);
|
||||
std::vector<ElementC *> ptr_D_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
|
||||
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
|
||||
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
|
||||
ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
|
||||
ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
|
||||
ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
|
||||
alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast<ElementAccumulator>((rand() % 5) + 1) : options.alpha);
|
||||
beta_host.push_back((options.beta == FLT_MAX) ? static_cast<ElementAccumulator>(rand() % 5) : options.beta);
|
||||
ptr_alpha_host.at(i) = block_alpha.get() + i;
|
||||
ptr_beta_host.at(i) = block_beta.get() + i;
|
||||
}
|
||||
|
||||
ptr_A.reset(options.groups);
|
||||
@ -492,13 +508,20 @@ void initialize(const Options &options) {
|
||||
stride_D.reset(options.groups);
|
||||
stride_D.copy_from_host(stride_D_host.data());
|
||||
|
||||
alpha_device.reset(options.groups);
|
||||
alpha_device.copy_from_host(ptr_alpha_host.data());
|
||||
beta_device.reset(options.groups);
|
||||
beta_device.copy_from_host(ptr_beta_host.data());
|
||||
|
||||
initialize_block(block_A, seed + 2023);
|
||||
initialize_block(block_B, seed + 2022);
|
||||
initialize_block(block_C, seed + 2021);
|
||||
block_alpha.copy_from_host(alpha_host.data());
|
||||
block_beta.copy_from_host(beta_host.data());
|
||||
}
|
||||
|
||||
/// Populates a Gemm::Arguments structure from the given commandline options
|
||||
typename Gemm::Arguments args_from_options(const Options &options)
|
||||
typename Gemm::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
|
||||
@ -506,13 +529,36 @@ typename Gemm::Arguments args_from_options(const Options &options)
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
|
||||
{{options.alpha, options.beta}, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
typename Gemm::EpilogueOutputOp::Params params;
|
||||
if (options.alpha != FLT_MAX && options.beta != FLT_MAX) {
|
||||
// If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches.
|
||||
params = typename Gemm::EpilogueOutputOp::Params(
|
||||
ElementAccumulator(options.alpha), ElementAccumulator(options.beta));
|
||||
}
|
||||
else {
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups.
|
||||
params = typename Gemm::EpilogueOutputOp::Params(alpha_device.get(), beta_device.get());
|
||||
}
|
||||
|
||||
typename Gemm::Arguments arguments;
|
||||
if (host_problem_shapes_available) {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), options.problem_sizes_host.data()},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
|
||||
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
else {
|
||||
arguments = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{options.groups, problem_sizes.get(), nullptr},
|
||||
{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()},
|
||||
{params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
|
||||
return arguments;
|
||||
}
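Since `args_from_options` now takes a `host_problem_shapes_available` flag (and `main()` below passes `false` through `run<Gemm>`), the call site follows the usual pattern in these examples. A minimal sketch, assuming the example's `CUTLASS_CHECK` helper macro and the standard adapter entry points:

```cpp
// Build arguments without host-side problem shapes, then run the kernel.
bool host_problem_shapes_available = false;  // mirrors run<Gemm>(options, false) in main()
auto arguments = args_from_options(options, host_problem_shapes_available);

Gemm gemm;
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

CUTLASS_CHECK(gemm.can_implement(arguments));
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
```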
|
||||
@ -539,10 +585,10 @@ bool verify(const Options &options) {
|
||||
// Launch device reference gemm kernel
|
||||
gemm_reference(
|
||||
{M, N, K},
|
||||
ElementAccumulator(options.alpha),
|
||||
ElementAccumulator(alpha_host.at(i)),
|
||||
ref_A,
|
||||
ref_B,
|
||||
ElementAccumulator(options.beta),
|
||||
ElementAccumulator(beta_host.at(i)),
|
||||
ref_C,
|
||||
ref_D);
|
||||
|
||||
@ -560,7 +606,7 @@ bool verify(const Options &options) {
|
||||
|
||||
/// Execute a given example GEMM computation
|
||||
template <typename Gemm>
|
||||
int run(Options &options)
|
||||
int run(Options &options, bool host_problem_shapes_available = true)
|
||||
{
|
||||
allocate(options);
|
||||
initialize(options);
|
||||
@ -569,7 +615,7 @@ int run(Options &options)
|
||||
Gemm gemm;
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
|
||||
auto arguments = args_from_options(options);
|
||||
auto arguments = args_from_options(options, host_problem_shapes_available);
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
@ -612,12 +658,12 @@ int run(Options &options)
|
||||
result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
|
||||
result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host);
|
||||
|
||||
std::cout << " Problem Sizes: " << std::endl;
|
||||
for (auto const & problem : options.problem_sizes_host) {
|
||||
std::cout << " " << problem << std::endl;
|
||||
std::cout << " Problem Sizes, Alpha, Beta " << std::endl;
|
||||
for (int32_t i = 0; i < options.groups; ++i) {
|
||||
std::cout << " " << options.problem_sizes_host.at(i);
|
||||
std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl;
|
||||
}
|
||||
std::cout << " Groups : " << options.groups << std::endl;
|
||||
std::cout << " Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl;
|
||||
std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
|
||||
std::cout << " GFLOPS : " << result.gflops << std::endl;
|
||||
}
|
||||
@ -625,7 +671,7 @@ int run(Options &options)
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -667,8 +713,9 @@ int main(int argc, char const **args) {
|
||||
// Evaluate CUTLASS kernels
|
||||
//
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
run<Gemm>(options);
|
||||
run<Gemm>(options, false /*host_problem_shapes_available*/);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
|
||||
@ -35,9 +35,15 @@ set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0)
|
||||
set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Random problem sizes
|
||||
set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes
|
||||
|
||||
set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Random problem sizes
|
||||
set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes
|
||||
|
||||
set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0) # Fixed problem sizes
|
||||
set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) # Fixed problem sizes
|
||||
|
||||
set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes
|
||||
set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes
|
||||
|
||||
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
|
||||
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=500 --iterations=10) # Random problem sizes
|
||||
|
||||
@ -49,8 +55,12 @@ cutlass_example_add_executable(
|
||||
TEST_RANDOM_LARGE_GROUP
|
||||
TEST_EPILOGUE
|
||||
TEST_EPILOGUE_LARGE_GROUP
|
||||
TEST_EPILOGUE_OP
|
||||
TEST_EPILOGUE_OP_LARGE_GROUP
|
||||
TEST_FIXED
|
||||
TEST_FIXED_LARGE_GROUP
|
||||
TEST_SMALL
|
||||
TEST_SMALL_LARGE_GROUP
|
||||
TEST_RANDOM_PERF
|
||||
TEST_RANDOM_PERF_LARGE_GROUP
|
||||
)
|
||||
|
||||
@ -265,7 +265,7 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor
|
||||
asm volatile (
|
||||
"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;"
|
||||
:: "l"(smem_int64_desc), "r"(prob_shape[2]));
|
||||
// Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1
|
||||
// Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1
|
||||
asm volatile (
|
||||
"tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;"
|
||||
:: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4));
|
||||
|
||||
@ -391,7 +391,7 @@ struct TiledMMA : MMA_Atom
|
||||
} else {
|
||||
return cute::max(core_size, perm_size);
|
||||
}
|
||||
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
|
||||
@ -125,6 +125,9 @@ using CUTE_STL_NAMESPACE::invoke_result_t;
|
||||
using CUTE_STL_NAMESPACE::common_type;
|
||||
using CUTE_STL_NAMESPACE::common_type_t;
|
||||
|
||||
using CUTE_STL_NAMESPACE::remove_pointer;
|
||||
using CUTE_STL_NAMESPACE::remove_pointer_t;
|
||||
|
||||
// <utility>
|
||||
using CUTE_STL_NAMESPACE::declval;
|
||||
|
||||
|
||||
@ -64,6 +64,10 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3)))
|
||||
#define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -80,6 +80,40 @@ struct TagToStrideB<layout::ColumnMajor> {
|
||||
using tag = layout::ColumnMajor;
|
||||
};
|
||||
|
||||
// For each cutlass::layout *, provides its corresponding cute stride types, 64b by default
|
||||
// Used by pointer array and grouped gemm
|
||||
// Maps to modes [M, K, L]
|
||||
template <>
|
||||
struct TagToStrideA<layout::RowMajor *> {
|
||||
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>;
|
||||
using type = UnderlyingType*;
|
||||
using tag = layout::RowMajor;
|
||||
};
|
||||
|
||||
// Maps to modes [M, K, L]
|
||||
template <>
|
||||
struct TagToStrideA<layout::ColumnMajor *> {
|
||||
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>;
|
||||
using type = UnderlyingType*;
|
||||
using tag = layout::ColumnMajor;
|
||||
};
|
||||
|
||||
// Maps to modes [N, K, L]
|
||||
template <>
|
||||
struct TagToStrideB<layout::RowMajor *> {
|
||||
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>;
|
||||
using type = UnderlyingType*;
|
||||
using tag = layout::RowMajor;
|
||||
};
|
||||
|
||||
// Maps to modes [N, K, L]
|
||||
template <>
|
||||
struct TagToStrideB<layout::ColumnMajor *> {
|
||||
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>;
|
||||
using type = UnderlyingType*;
|
||||
using tag = layout::ColumnMajor;
|
||||
};
|
||||
|
||||
// Maps to modes [M, N, L]
|
||||
template <class LayoutTag>
|
||||
struct TagToStrideC : TagToStrideA<LayoutTag> { };
|
||||
@ -101,7 +135,7 @@ template<int ModeIndex, class Stride>
|
||||
constexpr bool
|
||||
is_major(Stride = {}) {
|
||||
// Account for stride types with and without batch mode and batch modes with static zero stride
|
||||
return cute::is_constant<1, decltype(cute::front(cute::get<ModeIndex>(Stride{})))>::value;
|
||||
return cute::is_constant<1, decltype(cute::front(cute::get<ModeIndex>(cute::remove_pointer_t<Stride>{})))>::value;
|
||||
}
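The `is_major` change strips a possible pointer from the stride type before inspecting it, so the same trait answers correctly for the new pointer-to-stride types used by pointer-array and grouped GEMM. A compile-time sketch of what that enables; the include path and the `cutlass::detail` spellings are assumptions inferred from this hunk, not verified signatures:

```cpp
#include "cutlass/detail/layout.hpp"  // assumed home of TagToStrideA / is_major in this diff

// Regular stride for a row-major A operand (modes [M, K, L]; K is the unit-stride mode).
using StrideA        = cutlass::detail::TagToStrideA_t<cutlass::layout::RowMajor>;
// Grouped/pointer-array variant: one stride object per group, passed by pointer.
using GroupedStrideA = cutlass::detail::TagToStrideA_t<cutlass::layout::RowMajor *>;

static_assert(cutlass::detail::is_major<1, StrideA>(),        "K mode is unit stride");
static_assert(cutlass::detail::is_major<1, GroupedStrideA>(), "still answerable through the pointer");
```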
|
||||
|
||||
// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices
|
||||
|
||||
@ -268,7 +268,7 @@ struct Sm90TmaBuilderImpl {
|
||||
// Passing void C disables source load + smem allocation
|
||||
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,ElementD,ElementC_>; // prevents void ref breakages
|
||||
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,GmemLayoutTagD,GmemLayoutTagC_>;
|
||||
|
||||
|
||||
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
|
||||
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
|
||||
|
||||
@ -434,8 +434,7 @@ struct CollectiveBuilder<
|
||||
Schedule,
|
||||
fusion::LinearCombination<ElementD,ElementCompute,ElementC_,ElementCompute,RoundStyle>,
|
||||
cute::enable_if_t<cute::is_same_v<Schedule, NoSmemWarpSpecialized> ||
|
||||
cute::is_same_v<Schedule, NoSmemWarpSpecializedArray> ||
|
||||
cute::is_same_v<Schedule, NoSmemWarpSpecializedGroup> >> {
|
||||
cute::is_same_v<Schedule, PtrArrayNoSmemWarpSpecialized> >> {
|
||||
|
||||
// Passing void C disables source load
|
||||
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,
|
||||
|
||||
@ -63,6 +63,7 @@ public:
|
||||
// Type Aliases
|
||||
//
|
||||
using EpilogueSchedule = EpilogueSchedule_;
|
||||
using DispatchPolicy = EpilogueSchedule_;
|
||||
|
||||
// derived types of output thread level operator
|
||||
using ThreadEpilogueOp = ThreadEpilogueOp_;
|
||||
|
||||
@ -73,12 +73,10 @@ public:
|
||||
using ElementScalar = ElementCompute;
|
||||
using ElementC = typename ThreadEpilogueOp::ElementC;
|
||||
using StrideC = StrideC_;
|
||||
using UnderlyingStrideC = cute::remove_pointer_t<StrideC>;
|
||||
using ElementD = typename ThreadEpilogueOp::ElementD;
|
||||
using StrideD = StrideD_;
|
||||
using StridesC = cute::conditional_t<cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>,
|
||||
StrideC const*, StrideC>;
|
||||
using StridesD = cute::conditional_t<cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>,
|
||||
StrideD const*, StrideD>;
|
||||
using UnderlyingStrideD = cute::remove_pointer_t<StrideD>;
|
||||
|
||||
using GmemTiledCopyC = void;
|
||||
using GmemTiledCopyD = void;
|
||||
@ -86,10 +84,9 @@ public:
|
||||
static const int kOutputAlignment = ThreadEpilogueOp::kCount;
|
||||
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
|
||||
|
||||
static_assert(cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup> ||
|
||||
cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedArray>, "Incompatible epilogue schedule.");
|
||||
static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
||||
static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
||||
static_assert(cute::is_same_v<EpilogueSchedule, PtrArrayNoSmemWarpSpecialized>, "Incompatible epilogue schedule.");
|
||||
static_assert(rank(UnderlyingStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
||||
static_assert(rank(UnderlyingStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
||||
|
||||
struct SharedStorage { };
|
||||
|
||||
@ -97,9 +94,9 @@ public:
|
||||
struct Arguments {
|
||||
typename ThreadEpilogueOp::Params thread{};
|
||||
ElementC const** ptr_C = nullptr;
|
||||
StridesC dC{};
|
||||
StrideC dC{};
|
||||
ElementD** ptr_D = nullptr;
|
||||
StridesD dD{};
|
||||
StrideD dD{};
|
||||
};
|
||||
|
||||
// Device side epilogue params
|
||||
@ -140,12 +137,13 @@ public:
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
DefaultEpilogueArray(Params const& params_)
|
||||
: params(params_), epilogue_op(params_.thread) { }
|
||||
: params(params_) { }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool
|
||||
is_source_needed() {
|
||||
return epilogue_op.is_source_needed();
|
||||
// For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta.
|
||||
return true;
|
||||
}
|
||||
|
||||
template<
|
||||
@ -185,10 +183,23 @@ public:
|
||||
// Slice to get the tile this CTA is responsible for
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
|
||||
|
||||
StrideC stride_c;
|
||||
StrideD stride_d;
|
||||
if constexpr (cute::is_same_v<EpilogueSchedule, NoSmemWarpSpecializedGroup>) {
|
||||
stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC[l_coord]);
|
||||
// If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups.
|
||||
// If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups,
|
||||
// we get the correct alpha/beta values for the current batch/group using group index.
|
||||
ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord);
|
||||
|
||||
if (epilogue_op.is_source_needed() && params.dC == nullptr) {
|
||||
// Beta value is non-zero while pointer to C is a nullptr
|
||||
assert(0);
|
||||
}
|
||||
|
||||
UnderlyingStrideC stride_c;
|
||||
UnderlyingStrideD stride_d;
|
||||
if constexpr (!cute::is_same_v<UnderlyingStrideC, StrideC>) {
|
||||
// If grouped gemm
|
||||
if (epilogue_op.is_source_needed()) {
|
||||
stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC[l_coord]);
|
||||
}
|
||||
stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD[l_coord]);
|
||||
}
|
||||
else {
|
||||
@ -197,7 +208,11 @@ public:
|
||||
}
|
||||
|
||||
// Represent the full output tensor
|
||||
Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C[l_coord]), make_shape(M,N,mock_L), stride_c); // (m,n,l)
|
||||
ElementC const* ptr_C_l = nullptr;
|
||||
if (epilogue_op.is_source_needed()) {
|
||||
ptr_C_l = params.ptr_C[l_coord];
|
||||
}
|
||||
Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l)
|
||||
Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l)
|
||||
Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
||||
@ -242,7 +257,6 @@ public:
|
||||
|
||||
private:
|
||||
Params params;
|
||||
ThreadEpilogueOp epilogue_op;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -148,12 +148,12 @@ private:
|
||||
|
||||
constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
|
||||
constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{});
|
||||
|
||||
|
||||
using EmptyType = cute::tuple<>;
|
||||
using SmemCStorage = cute::conditional_t<is_source_supported and (not ReuseSmemC),
|
||||
using SmemCStorage = cute::conditional_t<is_source_supported and (not ReuseSmemC),
|
||||
array_aligned<SmemElementC, size(SmemLayoutC{}), SmemAlignmentC>,
|
||||
EmptyType>;
|
||||
using SmemDStorage = cute::conditional_t<is_destination_supported,
|
||||
using SmemDStorage = cute::conditional_t<is_destination_supported,
|
||||
array_aligned<SmemElementD, size(SmemLayoutD{}), SmemAlignmentD>,
|
||||
EmptyType>;
|
||||
|
||||
@ -189,6 +189,7 @@ public:
|
||||
|
||||
struct SharedStorage {
|
||||
using TensorStorage = TensorStorageImpl;
|
||||
|
||||
TensorStorage tensors;
|
||||
|
||||
using PipelineStorage = typename LoadPipeline::SharedStorage;
|
||||
@ -249,12 +250,12 @@ public:
|
||||
Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M_C,N,L), args.dC));
|
||||
tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutC{}(_,_,0));
|
||||
}
|
||||
|
||||
|
||||
typename Params::TMA_D tma_store_d;
|
||||
if constexpr (is_destination_supported) {
|
||||
Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD));
|
||||
tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutD{}(_,_,0));
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace),
|
||||
@ -385,13 +386,13 @@ public:
|
||||
|
||||
// Apply epilogue subtile, get matching smem tensor
|
||||
SmemElementC* ptr_sC = nullptr;
|
||||
|
||||
|
||||
if constexpr (is_source_supported) {
|
||||
if constexpr (ReuseSmemC) {
|
||||
ptr_sC = reinterpret_cast<SmemElementC*>(shared_tensors.smem_D().data());
|
||||
} else {
|
||||
ptr_sC = shared_tensors.smem_C().data();
|
||||
}
|
||||
}
|
||||
}
|
||||
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C)
|
||||
@ -559,7 +560,7 @@ public:
|
||||
// Vectorized fragment view
|
||||
constexpr int FragmentSize = DispatchPolicy::FragmentSize;
|
||||
Tensor tRS_rAcc_frg = recast<Array<ElementAccumulator, FragmentSize>>(tRS_rAcc);
|
||||
Tensor tRS_rD_frg = recast<Array<SmemElementD , FragmentSize>>(tRS_rD);
|
||||
Tensor tRS_rD_frg = recast<Array<SmemElementD , FragmentSize>>(tRS_rD);
|
||||
CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly");
|
||||
|
||||
// (t)hread-partition for (s)mem to (r)egister copy (tSR_)
|
||||
|
||||
@ -46,8 +46,7 @@ namespace cutlass::epilogue {
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct NoSmemWarpSpecialized {};
|
||||
struct NoSmemWarpSpecializedArray {};
|
||||
struct NoSmemWarpSpecializedGroup {};
|
||||
struct PtrArrayNoSmemWarpSpecialized {};
|
||||
struct TmaWarpSpecialized {};
|
||||
struct TmaWarpSpecializedCooperative {};
|
||||
// DEPRECATED schedules, will be removed in next release
|
||||
|
||||
@ -1247,6 +1247,7 @@ struct FusionCallbacks<
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
template <class FusionOpOrCallbacks, class = cute::void_t<>>
|
||||
struct get_element_aux {
|
||||
@ -1257,7 +1258,7 @@ template <class FusionOpOrCallbacks>
|
||||
struct get_element_aux<FusionOpOrCallbacks, cute::void_t<typename FusionOpOrCallbacks::ElementAux>> {
|
||||
using type = typename FusionOpOrCallbacks::ElementAux;
|
||||
};
|
||||
|
||||
|
||||
template <class NodeOp, class... ChildOps>
|
||||
struct get_element_aux<Sm90TreeVisitor<NodeOp, ChildOps...>, cute::void_t<>> {
|
||||
using type = typename get_element_aux<NodeOp>::type;
|
||||
@ -1270,7 +1271,7 @@ struct get_element_aux<FusionCallbacks<Ts...>, cute::void_t<typename FusionCallb
|
||||
public:
|
||||
using type = typename get_element_aux<Operation>::type;
|
||||
};
|
||||
}
|
||||
} // namespace cutlass:epilogue::fusion::detail
|
||||
|
||||
template <class Callbacks>
|
||||
using get_element_aux_t = typename detail::get_element_aux<Callbacks>::type;
|
||||
|
||||
@ -88,43 +88,72 @@ public:
|
||||
/// Host-constructable parameters structure
|
||||
struct Params
|
||||
{
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
||||
ElementCompute const* const* alpha_ptr_array; ///< array of pointers to accumulator scalar per group/batch
|
||||
ElementCompute const* const* beta_ptr_array; ///< array of pointers to source scalar per group/batch
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
alpha(ElementCompute(1)),
|
||||
beta(ElementCompute(0)),
|
||||
alpha_ptr(nullptr),
|
||||
beta_ptr(nullptr) { }
|
||||
beta_ptr(nullptr),
|
||||
alpha_ptr_array(nullptr),
|
||||
beta_ptr_array(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta
|
||||
):
|
||||
alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { }
|
||||
alpha(alpha), beta(beta),
|
||||
alpha_ptr(nullptr), beta_ptr(nullptr),
|
||||
alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute alpha
|
||||
):
|
||||
alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { }
|
||||
alpha(alpha), beta(0),
|
||||
alpha_ptr(nullptr), beta_ptr(nullptr),
|
||||
alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute const *alpha_ptr,
|
||||
ElementCompute const *beta_ptr
|
||||
):
|
||||
alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { }
|
||||
alpha(0), beta(0),
|
||||
alpha_ptr(alpha_ptr), beta_ptr(beta_ptr),
|
||||
alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute const *alpha_ptr
|
||||
):
|
||||
alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { }
|
||||
alpha(0), beta(0),
|
||||
alpha_ptr(alpha_ptr), beta_ptr(nullptr),
|
||||
alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute const* const* alpha_ptr_array,
|
||||
ElementCompute const* const* beta_ptr_array
|
||||
):
|
||||
alpha(0), beta(0),
|
||||
alpha_ptr(nullptr), beta_ptr(nullptr),
|
||||
alpha_ptr_array(alpha_ptr_array), beta_ptr_array(beta_ptr_array) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute const* const* alpha_ptr_array
|
||||
):
|
||||
alpha(0), beta(0),
|
||||
alpha_ptr(nullptr), beta_ptr(nullptr),
|
||||
alpha_ptr_array(alpha_ptr_array), beta_ptr_array(nullptr) { }
|
||||
};
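With these additions, `LinearCombination::Params` can carry scaling factors as host scalars, single device pointers, or per-group pointer arrays, and the constructor further below selects among them using the group index. A minimal sketch of the host-side choice, mirroring `args_from_options` in example 57 earlier in this diff (`use_single_alpha_beta` is a hypothetical flag; the example itself keys off `alpha/beta == FLT_MAX`, and `alpha_device`/`beta_device` are its per-group pointer arrays):

```cpp
using Params = typename Gemm::EpilogueOutputOp::Params;

Params params;
if (use_single_alpha_beta) {
  // One alpha/beta applied uniformly to every batch/group.
  params = Params(ElementAccumulator(alpha), ElementAccumulator(beta));
}
else {
  // Per-group scaling: device arrays of pointers (ElementCompute const* const*),
  // one entry per batch/group, read by LinearCombination(Params, group_idx).
  params = Params(alpha_device.get(), beta_device.get());
}
```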
|
||||
|
||||
private:
|
||||
@ -140,9 +169,25 @@ public:
|
||||
|
||||
/// Constructs the function object, possibly loading from pointers in host memory
|
||||
CUTLASS_HOST_DEVICE
|
||||
LinearCombination(Params const ¶ms) {
|
||||
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
||||
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
||||
LinearCombination(Params const ¶ms, int group_idx = 0) {
|
||||
if (params.alpha_ptr_array != nullptr && params.alpha_ptr_array[group_idx] != nullptr) {
|
||||
alpha_ = *(params.alpha_ptr_array[group_idx]);
|
||||
}
|
||||
else if (params.alpha_ptr != nullptr) {
|
||||
alpha_ = *params.alpha_ptr;
|
||||
}
|
||||
else {
|
||||
alpha_ = params.alpha;
|
||||
}
|
||||
if (params.beta_ptr_array != nullptr && params.beta_ptr_array[group_idx] != nullptr) {
|
||||
beta_ = *(params.beta_ptr_array[group_idx]);
|
||||
}
|
||||
else if (params.beta_ptr != nullptr) {
|
||||
beta_ = *params.beta_ptr;
|
||||
}
|
||||
else {
|
||||
beta_ = params.beta;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if source is needed
|
||||
|
||||
@@ -185,8 +185,7 @@ struct CollectiveBuilder<
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperative>) &&
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>) &&
not detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>()>
> {
static_assert(is_static<TileShape_MNK>::value);
@@ -197,8 +196,7 @@ struct CollectiveBuilder<
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");

static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperative>);
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n");
@@ -515,8 +513,7 @@ struct CollectiveBuilder<
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum>>
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>
> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
@@ -534,8 +531,7 @@ struct CollectiveBuilder<
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();

static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum> ||
cute::is_same_v<KernelScheduleType, KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum>);
static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
IsArrayOfPointersGemm,
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

@@ -93,8 +93,10 @@ struct CollectiveMma<
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using UnderlyingStrideA = cute::remove_pointer_t<StrideA>;
using ElementB = ElementB_;
using StrideB = StrideB_;
using UnderlyingStrideB = cute::remove_pointer_t<StrideB>;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
@@ -149,14 +151,14 @@ struct CollectiveMma<
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(
GmemTiledCopyA{},
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
make_tensor(static_cast<InternalElementA const*>(nullptr), repeat_like(UnderlyingStrideA{}, int32_t(0)), UnderlyingStrideA{}),
SmemLayoutA{}(_,_,cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(
GmemTiledCopyB{},
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
make_tensor(static_cast<InternalElementB const*>(nullptr), repeat_like(UnderlyingStrideB{}, int32_t(0)), UnderlyingStrideB{}),
SmemLayoutB{}(_,_,cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
@@ -179,16 +181,14 @@ struct CollectiveMma<
using TensorMapStorage = typename SharedStorage::TensorMapStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;

static constexpr bool IsGroupedGemmKernel = cute::is_base_of_v<KernelGroupTmaWarpSpecializedCooperative, KernelSchedule>;
using StridesA = cute::conditional_t<IsGroupedGemmKernel, StrideA const*, StrideA>;
using StridesB = cute::conditional_t<IsGroupedGemmKernel, StrideB const*, StrideB>;
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<UnderlyingStrideA, StrideA>;

// Host side kernel arguments
struct Arguments {
ElementA const** ptr_A;
StridesA dA;
StrideA dA;
ElementB const** ptr_B;
StridesB dB;
StrideB dB;
};

// Device side kernel params
@@ -197,9 +197,9 @@ struct CollectiveMma<
TMA_B tma_load_b;
void* tensormaps;
InternalElementA const** ptr_A;
StridesA dA;
StrideA dA;
InternalElementB const** ptr_B;
StridesB dB;
StrideB dB;
};

//
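The switch from an explicit KernelGroup schedule check to `!cute::is_same_v<UnderlyingStrideA, StrideA>` works because grouped GEMM now passes a pointer to per-group strides while ptr-array GEMM passes a single stride, and `cute::remove_pointer_t` strips the pointer only in the grouped case. A hedged sketch of the idea using the standard library (type names here are illustrative stand-ins, not the CUTLASS definitions):

#include <type_traits>

// Stand-ins for the real CuTe stride types.
struct StrideMKL {};                       // one stride shared by all batches (ptr-array GEMM)
using GroupedStride = StrideMKL const*;    // one stride per group (grouped GEMM)

template <class Stride>
constexpr bool is_grouped_v =
    !std::is_same_v<std::remove_pointer_t<Stride>, Stride>;

static_assert(!is_grouped_v<StrideMKL>);      // ptr-array: remove_pointer_t leaves the type unchanged
static_assert(is_grouped_v<GroupedStride>);   // grouped: remove_pointer_t yields a different type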
@@ -212,30 +212,36 @@ struct CollectiveMma<
ProblemShape problem_shapes,
Arguments const& args,
void* workspace) {
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(0), 1);
auto [M,N,K,L] = problem_shape_MNKL;
// These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc.
// These will be replaced with correct values before the initial tma load.
auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1));
auto init_M = get<0>(init_shape);
auto init_N = get<1>(init_shape);
auto init_K = get<2>(init_shape);
// Batches/Groups are managed by using appropriate pointers to input matrices
const uint32_t mock_L = 1;

// These tensor pointers are only used to create tensormap/tma desc.
// This address to the tensor will be replaced with correct address before the initial tma load
InternalElementA const* ptr_A_first_batch = reinterpret_cast<InternalElementA const*>(args.ptr_A);
InternalElementB const* ptr_B_first_batch = reinterpret_cast<InternalElementA const*>(args.ptr_B);
cudaError_t cuda_error = cudaGetLastError(); // clear previous error

StrideA stride_a;
StrideB stride_b;
UnderlyingStrideA stride_a;
UnderlyingStrideB stride_b;
if constexpr (IsGroupedGemmKernel) {
// Strides for Grouped Gemm will be replaced prior to the first access regardless
stride_a = StrideA{};
stride_b = StrideB{};
// Strides for Grouped Gemm will be replaced prior to the first access regardless.
stride_a = UnderlyingStrideA{};
stride_b = UnderlyingStrideB{};
}
else {
// Tensor shapes for Ptr-Array are initialized correctly only here.
auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0);
init_M = get<0>(problem_shape_MNK);
init_N = get<1>(problem_shape_MNK);
init_K = get<2>(problem_shape_MNK);

stride_a = args.dA;
stride_b = args.dB;
}
Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(M,K,mock_L), stride_a));
Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(N,K,mock_L), stride_b));
Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,mock_L), stride_a));
Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,mock_L), stride_b));
TMA_A tma_load_a = make_tma_copy(
GmemTiledCopyA{},
tensor_a,
@@ -287,12 +293,14 @@ struct CollectiveMma<
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;

bool implementable = true;
// Check alignment for all problem sizes
for (int i = 0; i < problem_shapes.groups(); i++) {
auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1);
auto [M,N,K,L] = problem_shape_MNKL;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
if (problem_shapes.is_host_problem_shape_available()) {
// Check alignment for all problem sizes
for (int i = 0; i < problem_shapes.groups(); i++) {
auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1);
auto [M,N,K,L] = problem_shape_MNKL;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), UnderlyingStrideA{});
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), UnderlyingStrideB{});
}
}

if (!implementable) {
@@ -676,6 +684,14 @@ struct CollectiveMma<
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b,
prob_shape_B, prob_stride_B);

// Convert strides to byte strides
for (uint64_t& stride : prob_stride_A) {
stride = (stride * sizeof_bits_v<InternalElementA>) / 8;
}
for (uint64_t& stride : prob_stride_B) {
stride = (stride * sizeof_bits_v<InternalElementB>) / 8;
}

cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_A,
prob_shape_A,
prob_stride_A);
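The loop above converts element-count strides into the byte strides that the TMA descriptors expect, going through sizeof_bits so sub-byte element types divide out correctly. A small hedged sketch of the same arithmetic:

#include <cstdint>

// Convert an element-count stride into a byte stride for elements of `bits_per_element` bits.
constexpr uint64_t to_byte_stride(uint64_t elem_stride, uint64_t bits_per_element) {
  return (elem_stride * bits_per_element) / 8;
}

static_assert(to_byte_stride(1024, 16) == 2048);  // 16-bit elements: 2 bytes each
static_assert(to_byte_stride(1024, 8)  == 1024);  // 8-bit elements
static_assert(to_byte_stride(1024, 4)  == 512);   // sub-byte (e.g. 4-bit) elements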
@@ -53,8 +53,7 @@ struct KernelTma { };
struct KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperative { };
struct KernelArrayTmaWarpSpecializedCooperative { };
struct KernelGroupTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedCooperative { };

//////////////////////////////////////////////////////////////////////////////

@@ -67,8 +66,7 @@ struct KernelGroupTmaWarpSpecializedCooperative { };
struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { };
struct KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelArrayTmaWarpSpecializedCooperative { };
struct KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum : KernelGroupTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelPtrArrayTmaWarpSpecializedCooperative { };

// Policies to opt into mixed type GEMMs
struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { };
@@ -233,7 +231,7 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelGroupTmaWarpSpecializedCooperative
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative
>
struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
constexpr static int Stages = Stages_;
@@ -241,8 +239,7 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
using ArchTag = arch::Sm90;
using Schedule = KernelSchedule;
static_assert(
cute::is_base_of_v<KernelArrayTmaWarpSpecializedCooperative, KernelSchedule> ||
cute::is_base_of_v<KernelGroupTmaWarpSpecializedCooperative, KernelSchedule>,
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, KernelSchedule>,
"KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies");
};

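Because KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum derives from KernelPtrArrayTmaWarpSpecializedCooperative, the single is_base_of_v check in the static_assert above covers both the plain and the FP8 fast-accumulation schedules. A minimal illustration using std::is_base_of_v (cute::is_base_of_v is assumed to behave the same way):

#include <type_traits>

struct KernelPtrArrayTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum
    : KernelPtrArrayTmaWarpSpecializedCooperative { };

// Both schedules satisfy a single base-class check.
static_assert(std::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative,
                                KernelPtrArrayTmaWarpSpecializedCooperative>);
static_assert(std::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative,
                                KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);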
@@ -71,6 +71,12 @@ struct GroupProblemShape {
get_host_problem_shape(int32_t group_idx) const {
return host_problem_shapes[group_idx];
}

CUTLASS_HOST_DEVICE
bool
is_host_problem_shape_available() {
return host_problem_shapes != nullptr;
}
};

template <class ProblemShape_>
@@ -104,6 +110,12 @@ public:
get_host_problem_shape(int32_t /* unused */ = 0) const {
return problem_shape_;
}

CUTLASS_HOST_DEVICE
bool
is_host_problem_shape_available() {
return true;
}
private:
UnderlyingProblemShape problem_shape_{};
};

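is_host_problem_shape_available() lets host-side checks degrade gracefully when grouped problem shapes live only in device memory: GroupProblemShape reports whether a host copy exists, while the single-problem wrapper always returns true. A hedged sketch of the calling pattern (illustrative only; the can_implement and grid-sizing code later in this diff follows the same shape):

// ProblemShapes stands in for either problem-shape wrapper above.
template <class ProblemShapes>
bool host_checks_pass(ProblemShapes const& problem_shapes) {
  if (!problem_shapes.is_host_problem_shape_available()) {
    // No host copy (e.g. shapes filled in on the device for grouped GEMM):
    // skip host-side validation and rely on device-side setup.
    return true;
  }
  bool ok = true;
  for (int i = 0; i < problem_shapes.groups(); ++i) {
    auto shape = problem_shapes.get_host_problem_shape(i);
    // ... per-group alignment / rank checks would go here ...
    (void)shape;
  }
  return ok;
}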
@@ -62,8 +62,7 @@ class GemmUniversal<
CollectiveMainloop_,
CollectiveEpilogue_,
TileScheduler_,
cute::enable_if_t<cute::is_base_of_v<KernelArrayTmaWarpSpecializedCooperative, typename CollectiveMainloop_::DispatchPolicy::Schedule> ||
cute::is_base_of_v<KernelGroupTmaWarpSpecializedCooperative, typename CollectiveMainloop_::DispatchPolicy::Schedule>>
cute::enable_if_t<cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, typename CollectiveMainloop_::DispatchPolicy::Schedule>>
>
{
public:
@@ -80,7 +79,9 @@ public:
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using UnderlyingStrideA = typename CollectiveMainloop::UnderlyingStrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using UnderlyingStrideB = typename CollectiveMainloop::UnderlyingStrideB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using Schedule = typename DispatchPolicy::Schedule;
@@ -93,8 +94,10 @@ public:
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using UnderlyingStrideC = typename CollectiveEpilogue::UnderlyingStrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using UnderlyingStrideD = typename CollectiveEpilogue::UnderlyingStrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;

@@ -102,7 +105,7 @@ public:
static_assert(cute::is_void_v<TileScheduler_>,
"Ptr-Array Cooperative and Grouped Gemm Cooperative kernel only supports the default scheduler.");

static constexpr bool IsGroupedGemmKernel = cute::is_base_of_v<KernelGroupTmaWarpSpecializedCooperative, Schedule>;
static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<UnderlyingStrideA, StrideA>;

using TileScheduler = cute::conditional_t<IsGroupedGemmKernel,
typename detail::TileSchedulerSelector<
@@ -204,7 +207,7 @@ public:

void* scheduler_workspace = workspace_ptr;
workspace_offset += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, problem_shapes.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups);
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);

void* epilogue_workspace = workspace_ptr + workspace_offset;
@@ -244,14 +247,11 @@ public:
bool
can_implement(Arguments const& args) {
bool implementable = true;
if constexpr (cute::is_base_of_v<KernelArrayTmaWarpSpecializedCooperative, Schedule>) {
implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4);
} else if constexpr (IsGroupedGemmKernel) {
if constexpr (IsGroupedGemmKernel) {
// Group GEMM currently only supports rank-3 problem shapes
implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3);
}
else {
implementable = false;
} else {
implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4);
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n");
@@ -269,7 +269,7 @@ public:
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});

workspace_size += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);

workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
@@ -297,9 +297,9 @@ public:
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});

status = TileScheduler::template initialize_workspace<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_offset += TileScheduler::template get_workspace_size<typename ProblemShape::UnderlyingProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess) {
return status;
@@ -350,23 +350,20 @@ public:
using namespace cute;
using X = Underscore;

// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else

// Preconditions
static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads.");
static_assert(size<0>(TileShape{}) >= 128,
"Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension.");

static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(UnderlyingStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(UnderlyingStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(UnderlyingStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(UnderlyingStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");

/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */
enum class WarpGroupRole {
@@ -441,8 +438,6 @@ public:
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// Purpose of maintaining this pipeline state is to make sure TMA loads have finished before doing descriptor updates
typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state;

// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
@@ -554,7 +549,8 @@ public:
shared_storage.tensors.mainloop
);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state.advance(work_k_tile_count);
// Wait for the last TMA stage to complete loading, before issuing tensormap updates
mainloop_pipe_producer_state.advance(work_k_tile_count - 1);

// Signal for the epilogue load warp to begin
if (do_load_order_arrive) {
@@ -570,8 +566,10 @@ public:
if constexpr (IsGroupedGemmKernel) {
problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(next_batch), Int<1>{});
}
// Wait for the last TMA stage to complete loading, before issuing tensormap updates
mainloop_pipe_tma_consumer_state.advance(work_k_tile_count-1);
// Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates
// Since this state is waiting for loads to finish, it must start in the inverted phase.
typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state =
{mainloop_pipe_producer_state.index(), !mainloop_pipe_producer_state.phase(), mainloop_pipe_producer_state.count()};
mainloop_pipeline.consumer_wait(mainloop_pipe_tma_consumer_state);
collective_mainloop.tensormaps_perform_update(
shared_storage.tensormaps.mainloop,
@@ -585,13 +583,9 @@ public:
// Entire warp must do this (ie its aligned)
collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps);
curr_batch = next_batch;
// Advance the TMA consumer state for the last remaining stage that was being waited for above
mainloop_pipe_tma_consumer_state.advance(1);
}
else if (work_tile_info.is_valid()) { // case where batch/group didn't change between tiles
// Advance the TMA consumer state for all the stages to be in sync
mainloop_pipe_tma_consumer_state.advance(work_k_tile_count);
}
// Advance the producer state for the last remaining stage that was being waited for above
mainloop_pipe_producer_state.advance(1);
} // Scheduler work fetch loop

// Make sure all Consumer Warp Groups have been waited upon
@@ -720,6 +714,7 @@ public:
);
}
} // Consumer Warp Groups End
#endif
}

private:

@@ -211,13 +211,10 @@ public:
using namespace cute;
using X = Underscore;

// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else

// Preconditions
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
@@ -311,6 +308,7 @@ public:
thread_idx,
smem_buf
);
#endif
}
};

@@ -219,13 +219,10 @@ public:
using namespace cute;
using X = Underscore;

// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else

enum class WarpGroupRole {
Producer = 0,
@@ -435,6 +432,7 @@ public:
epi_store_pipe_producer_state_next
);
}
#endif
}
};

@@ -298,13 +298,10 @@ public:
using namespace cute;
using X = Underscore;

// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else

// Preconditions
static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads.");
@@ -610,6 +607,7 @@ public:
);
}
} // Consumer Warp Groups End
#endif
}

private:

@@ -296,13 +296,10 @@ public:
using namespace cute;
using X = Underscore;

// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else

// Preconditions
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
@@ -612,6 +609,7 @@ public:
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
} // Consumer Warp Groups End
#endif
}
};

@@ -223,13 +223,10 @@ public:
using namespace cute;
using X = Underscore;

// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else

enum class WarpGroupRole {
Producer = 0,
@@ -409,6 +406,7 @@ public:
shared_storage.tensors.epilogue
);
}
#endif
}
};

@@ -250,13 +250,10 @@ public:
using namespace cute;
using X = Underscore;

// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else

static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
@@ -493,6 +490,7 @@ public:
);
}
} // Consumer Warp Groups End
#endif
}

private:

@@ -257,13 +257,10 @@ public:
using namespace cute;
using X = Underscore;

// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) {
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
return;
}
#endif
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n");
#else

// Preconditions
static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
@@ -509,6 +506,7 @@ public:
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
} // Consumer Warp Groups End
#endif
}
};

@@ -55,7 +55,7 @@ private:

// Tracking current group, its starting linear idx and total tiles
struct GroupInfo {
uint64_t group = 0;
int group_idx = 0;
uint64_t start_linear_idx = 0;
uint64_t total_tiles = 0;
} current_group_info_;
@@ -115,7 +115,7 @@ public:
GroupProblemShape problem_shapes,
TileShape tile_shape,
ClusterShape cluster_shape,
[[maybe_unused]] KernelHardwareInfo const& hw_info,
KernelHardwareInfo const& hw_info,
Arguments const& arguments,
[[maybe_unused]] void* workspace=nullptr,
[[maybe_unused]] const uint32_t epilogue_subtile = 1) {
@@ -126,14 +126,16 @@ public:

dim3 problem_blocks = get_tiled_cta_shape_mnl(
problem_shapes.groups(),
reinterpret_cast<ProblemShape const*>(problem_shapes.host_problem_shapes),
problem_shapes,
hw_info,
tile_shape, cluster_shape);

Params params;
params.initialize(
problem_blocks,
problem_shapes.groups(),
reinterpret_cast<ProblemShape*>(problem_shapes.problem_shapes),
problem_shapes.problem_shapes,
problem_shapes.host_problem_shapes,
to_gemm_coord(tile_shape),
to_gemm_coord(cluster_shape),
hw_info,
@@ -144,6 +146,64 @@ public:
return params;
}

// Given the inputs, computes the physical grid we should launch.
template<class TileShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(
GroupProblemShape problem_shapes,
TileShape tile_shape,
ClusterShape cluster_shape,
KernelHardwareInfo hw_info,
Arguments arguments,
bool truncate_by_problem_size=true) {

dim3 problem_blocks = get_tiled_cta_shape_mnl(
problem_shapes.groups(),
problem_shapes,
hw_info,
tile_shape, cluster_shape);

return Params::get_grid_shape(
problem_blocks,
to_gemm_coord(cluster_shape),
hw_info,
arguments.max_swizzle_size,
arguments.raster_order,
/* truncate_by_problem_size = */true
);
}

// Given the inputs, computes the total number of output blocks this problem will compute over
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
template<class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_tiled_cta_shape_mnl(int groups, GroupProblemShape problem_shapes, KernelHardwareInfo hw_info, BlockShape cta_shape, ClusterShape cluster_shape) {
uint32_t total_ctas = 0;
uint32_t cta_in_N_dim = 1; // We linearize the blocks across all the problems here

// If host problem shapes are not provided.
if (!problem_shapes.is_host_problem_shape_available()) {
total_ctas = hw_info.sm_count;
}
// If host problem shapes are provided, make a better decision about possibility to launch smaller grid.
else {
for (int group = 0; group < groups; group++) {
auto ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes.get_host_problem_shape(group)), cute::shape<0>(cta_shape)));
auto ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes.get_host_problem_shape(group)), cute::shape<1>(cta_shape)));
auto problem_blocks_m = round_up(ctas_along_m, cute::get<0>(cluster_shape));
auto problem_blocks_n = round_up(ctas_along_n, cute::get<1>(cluster_shape));
total_ctas += problem_blocks_m * problem_blocks_n;
}
}

return Params::get_tiled_cta_shape_mnl(
to_gemm_coord(cluster_shape),
total_ctas, cta_in_N_dim
);
}

CUTLASS_HOST_DEVICE
static bool
can_implement(Arguments const& args) {
@@ -156,7 +216,7 @@ public:
// MSVC requires protecting use of CUDA-specific nonstandard syntax,
// like blockIdx and gridDim, with __CUDA_ARCH__.
#if defined(__CUDA_ARCH__)
if (params_.raster_order_ == RasterOrder::AlongN) {
if (scheduler_params.raster_order_ == RasterOrder::AlongN) {
current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x);
}
else {
@@ -165,9 +225,19 @@ public:

total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z);

auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(params_.problem_shapes_[0]), params_.cta_shape_.m()));
auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(params_.problem_shapes_[0]), params_.cta_shape_.n()));
current_group_info_.total_tiles = cta_m * cta_n;
uint64_t ctas_along_m, ctas_along_n;
if (is_tuple<decltype(cute::shape<0>(params_.problem_shapes_[0]))>::value ||
is_tuple<decltype(cute::shape<1>(params_.problem_shapes_[0]))>::value) {
ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(params_.problem_shapes_[0]), scheduler_params.cta_shape_.m()));
ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(params_.problem_shapes_[0]), scheduler_params.cta_shape_.n()));
}
else {
ctas_along_m = scheduler_params.divmod_cta_shape_m_.divide(cute::shape<0>(params_.problem_shapes_[0]) + scheduler_params.divmod_cta_shape_m_.divisor - 1);
ctas_along_n = scheduler_params.divmod_cta_shape_n_.divide(cute::shape<1>(params_.problem_shapes_[0]) + scheduler_params.divmod_cta_shape_n_.divisor - 1);
}
auto problem_blocks_m = round_up(ctas_along_m, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.m());
auto problem_blocks_n = round_up(ctas_along_n, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.n());
current_group_info_.total_tiles = problem_blocks_m * problem_blocks_n;
#else
CUTLASS_ASSERT(false && "This line should never be reached");
#endif
@@ -182,24 +252,22 @@ public:
CUTLASS_DEVICE
WorkTileInfo
get_current_work_for_linear_idx(uint64_t linear_idx) {
if (linear_idx >= scheduler_params.blocks_per_problem_) {
if (scheduler_params.pre_processed_problem_shapes && linear_idx >= scheduler_params.blocks_across_problem_) {
return WorkTileInfo::invalid_work_tile();
}

uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(linear_idx);

auto [work_idx_m, work_idx_n, new_group_info, valid_tile] = get_work_idx_m_and_n(blk_per_grid_dim,
current_group_info_,
scheduler_params.groups_,
scheduler_params.problem_shapes_,
scheduler_params.cta_shape_,
scheduler_params.divmod_cluster_shape_major_,
scheduler_params.divmod_cluster_shape_minor_,
scheduler_params.log_swizzle_size_,
scheduler_params.raster_order_);

current_group_info_ = new_group_info;
return {work_idx_m, work_idx_n, static_cast<int>(current_group_info_.group), valid_tile};
return get_work_idx_m_and_n(linear_idx,
current_group_info_,
scheduler_params.groups_,
scheduler_params.problem_shapes_,
scheduler_params.cta_shape_,
scheduler_params.cluster_shape_,
scheduler_params.divmod_cluster_shape_major_,
scheduler_params.divmod_cluster_shape_minor_,
scheduler_params.divmod_cta_shape_m_,
scheduler_params.divmod_cta_shape_n_,
scheduler_params.log_swizzle_size_,
scheduler_params.raster_order_);
}

CUTLASS_DEVICE
@@ -208,34 +276,62 @@ public:
current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
}

// get work_idx_m, work_idx_n from blk_per_grid_dim while applying swizzle
// get work_idx_m, work_idx_n from linear_idx while applying swizzle
static CUTLASS_DEVICE
cute::tuple<int32_t, int32_t, struct GroupInfo, bool>
WorkTileInfo
get_work_idx_m_and_n(
uint64_t blk_per_grid_dim,
struct GroupInfo group_info,
uint64_t linear_idx,
struct GroupInfo& group_info,
int32_t total_problem_groups,
ProblemShape* problem_shapes,
GemmCoord cta_shape,
GemmCoord cluster_shape,
FastDivmodU64Pow2 const& divmod_cluster_shape_major,
FastDivmodU64Pow2 const& divmod_cluster_shape_minor,
FastDivmodU64 const& divmod_cta_shape_m,
FastDivmodU64 const& divmod_cta_shape_n,
int32_t log_swizzle_size,
RasterOrder raster_order) {

bool valid_tile = true;
int cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group]), cta_shape.m()));
int cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group]), cta_shape.n()));
uint64_t ctas_along_m, ctas_along_n;
if (is_tuple<decltype(cute::shape<0>(problem_shapes[group_info.group_idx]))>::value ||
is_tuple<decltype(cute::shape<1>(problem_shapes[group_info.group_idx]))>::value) {
ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group_idx]), cta_shape.m()));
ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group_idx]), cta_shape.n()));
}
else {
ctas_along_m = divmod_cta_shape_m.divide(cute::shape<0>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_m.divisor - 1);
ctas_along_n = divmod_cta_shape_n.divide(cute::shape<1>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_n.divisor - 1);
}
auto problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m());
auto problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n());
group_info.total_tiles = problem_blocks_m * problem_blocks_n;

while (group_info.start_linear_idx + group_info.total_tiles <= linear_idx) {
group_info.group_idx++;

if (group_info.group_idx >= total_problem_groups)
return WorkTileInfo::invalid_work_tile();

while (group_info.start_linear_idx + group_info.total_tiles <= blk_per_grid_dim) {
group_info.group++;
group_info.start_linear_idx += group_info.total_tiles;
cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group]), cta_shape.m()));
cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group]), cta_shape.n()));
group_info.total_tiles = cta_m * cta_n;
if (is_tuple<decltype(cute::shape<0>(problem_shapes[group_info.group_idx]))>::value ||
is_tuple<decltype(cute::shape<1>(problem_shapes[group_info.group_idx]))>::value) {
ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group_idx]), cta_shape.m()));
ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group_idx]), cta_shape.n()));
}
else {
ctas_along_m = divmod_cta_shape_m.divide(cute::shape<0>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_m.divisor - 1);
ctas_along_n = divmod_cta_shape_n.divide(cute::shape<1>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_n.divisor - 1);
}
problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m());
problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n());
group_info.total_tiles = problem_blocks_m * problem_blocks_n;
}

uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0;
divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim - group_info.start_linear_idx);
uint64_t blk_per_grid_dim = divmod_cluster_shape_minor.divide(linear_idx - group_info.start_linear_idx);
divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim);

auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
if (raster_order == RasterOrder::AlongN) {
@@ -252,8 +348,13 @@ public:
offset = cluster_id & ((1 << log_swizzle_size) - 1);
extra = cluster_id >> log_swizzle_size;

uint64_t curr_group_cluster_blk_major, remainder;
divmod_cluster_shape_major(curr_group_cluster_blk_major, remainder, cta_m);
uint64_t curr_group_cluster_blk_major;
if (raster_order == RasterOrder::AlongN) {
curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(problem_blocks_n);
}
else {
curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(problem_blocks_m);
}
cluster_idx_minor_div_swizzle = extra / curr_group_cluster_blk_major;
cluster_idx_major = extra % curr_group_cluster_blk_major;

@@ -265,61 +366,14 @@ public:
cluster_major_offset);

if (raster_order == RasterOrder::AlongN) {
return {minor_work_idx, major_work_idx, group_info, valid_tile};
return {minor_work_idx, major_work_idx, group_info.group_idx, valid_tile};
}
else {
return {major_work_idx, minor_work_idx, group_info, valid_tile};
return {major_work_idx, minor_work_idx, group_info.group_idx, valid_tile};
}

}

// Given the inputs, computes the total number of output blocks this problem will compute over
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
template<class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_tiled_cta_shape_mnl(int groups, ProblemShape const* problem_shapes, BlockShape cta_shape, ClusterShape cluster_shape) {
uint32_t total_ctas = 0;
uint32_t cta_in_N_dim = 1; // We linearize the blocks across all the problems here
for (int group = 0; group < groups; group++) {
auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group]), cute::shape<0>(cta_shape)));
auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group]), cute::shape<1>(cta_shape)));
total_ctas += cta_m * cta_n;
}

return Params::get_tiled_cta_shape_mnl(
to_gemm_coord(cluster_shape),
total_ctas, cta_in_N_dim
);
}

// Given the inputs, computes the physical grid we should launch.
template<class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(
GroupProblemShape problem_shapes,
BlockShape cta_shape,
ClusterShape cluster_shape,
KernelHardwareInfo hw_info,
Arguments arguments,
bool truncate_by_problem_size=true) {

dim3 problem_blocks = get_tiled_cta_shape_mnl(
problem_shapes.groups(),
reinterpret_cast<ProblemShape const*>(problem_shapes.host_problem_shapes),
cta_shape, cluster_shape);

return Params::get_grid_shape(
problem_blocks,
to_gemm_coord(cluster_shape),
hw_info,
arguments.max_swizzle_size,
arguments.raster_order,
/* truncate_by_problem_size = */true
);
}

// Returns whether the block assigned this work should compute the epilogue for the corresponding
// output tile. For the basic tile scheduler, this is always true.
CUTLASS_HOST_DEVICE

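The reworked get_work_idx_m_and_n above walks the groups until it finds the one whose rounded-up tile count covers the requested linear index; rounding to the cluster and swizzle granularity keeps per-group tile counts consistent with the launched grid. A simplified host-side sketch of that scan (illustrative names only, with no swizzle or cluster offsets):

#include <cstdint>
#include <vector>

struct GroupTile { int group_idx; uint64_t tile_in_group; bool valid; };

// tiles_per_group[g] is the (already rounded-up) tile count of group g.
GroupTile find_group(uint64_t linear_idx, std::vector<uint64_t> const& tiles_per_group) {
  uint64_t start = 0;
  for (int g = 0; g < static_cast<int>(tiles_per_group.size()); ++g) {
    if (linear_idx < start + tiles_per_group[g]) {
      return {g, linear_idx - start, true};    // tile falls inside group g
    }
    start += tiles_per_group[g];               // advance start_linear_idx, as in the scheduler
  }
  return {-1, 0, false};                       // past the last group: invalid work tile
}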
@@ -1273,15 +1273,18 @@ struct PersistentTileSchedulerSm90GroupParams {

FastDivmodU64Pow2 divmod_cluster_shape_major_{};
FastDivmodU64Pow2 divmod_cluster_shape_minor_{};
FastDivmodU64 divmod_batch_{};
FastDivmodU64 divmod_cta_shape_m_{};
FastDivmodU64 divmod_cta_shape_n_{};

uint64_t blocks_per_problem_ = 0;
uint64_t blocks_across_problem_ = 0;
bool pre_processed_problem_shapes = true;
int32_t log_swizzle_size_ = 0;
RasterOrder raster_order_ = RasterOrder::AlongN;

int32_t groups_ = 0;
ProblemShape* problem_shapes_ = nullptr;
GemmCoord cta_shape_;
GemmCoord cluster_shape_;

// Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions.
// This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1,
@@ -1291,6 +1294,7 @@ struct PersistentTileSchedulerSm90GroupParams {
dim3 problem_blocks,
int32_t groups,
ProblemShape* problem_shapes,
ProblemShape const* host_problem_shapes,
GemmCoord cta_shape,
GemmCoord cluster_shape,
KernelHardwareInfo const& hw_info,
@@ -1317,11 +1321,12 @@ struct PersistentTileSchedulerSm90GroupParams {
groups_ = groups;
problem_shapes_ = problem_shapes;
cta_shape_ = cta_shape;
cluster_shape_ = cluster_shape;

blocks_per_problem_ = problem_blocks_m * problem_blocks_n * problem_blocks.z;
blocks_across_problem_ = problem_blocks.x * problem_blocks.y * problem_blocks.z;
pre_processed_problem_shapes = (host_problem_shapes == nullptr) ? false : true;
log_swizzle_size_ = log_swizzle_size;
raster_order_ = raster_order;
divmod_batch_ = FastDivmodU64(problem_blocks_m * problem_blocks_n);

if (raster_order == RasterOrder::AlongN) {
divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.n());
@@ -1331,6 +1336,9 @@ struct PersistentTileSchedulerSm90GroupParams {
divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.m());
divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.n());
}

divmod_cta_shape_m_ = FastDivmodU64(cta_shape_.m());
divmod_cta_shape_n_ = FastDivmodU64(cta_shape_.n());
}

// Version of get_tiled_cta_shape_mnl that takes in as input the number of CTAs in the M and N dimensions.
@@ -1344,8 +1352,8 @@ struct PersistentTileSchedulerSm90GroupParams {
auto problem_blocks_n = ((cta_n + cluster_shape.n() - 1) / cluster_shape.n()) * cluster_shape.n();

return {
static_cast<uint32_t>(problem_blocks_m),
static_cast<uint32_t>(problem_blocks_n),
static_cast<uint32_t>(cta_m),
static_cast<uint32_t>(cta_n),
static_cast<uint32_t>(1) // Only a single batch per group is currently supported
};
}

include/cutlass/version.h (new file, +80 lines)
@@ -0,0 +1,80 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

#pragma once

#include <cstdint>
#include <string>

#define CUTLASS_MAJOR 3
#define CUTLASS_MINOR 4
#define CUTLASS_PATCH 1

#ifdef CUTLASS_VERSIONS_GENERATED
#include "cutlass/version_extended.h"
#else
#define CUTLASS_BUILD 0
#define CUTLASS_REVISION ""
#endif

#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)

namespace cutlass {

inline constexpr uint32_t getVersion() {
return CUTLASS_VERSION;
}
inline constexpr uint32_t getVersionMajor() {
return CUTLASS_MAJOR;
}
inline constexpr uint32_t getVersionMinor() {
return CUTLASS_MINOR;
}
inline constexpr uint32_t getVersionPatch() {
return CUTLASS_PATCH;
}
inline constexpr uint32_t getVersionBuild() {
return CUTLASS_BUILD + 0;
}

inline std::string getVersionString() {
std::string version = "@CUTLASS_VERSION@";
if (getVersionBuild()) {
version += "." + std::to_string(getVersionBuild());
}
return version;
}

inline std::string getGitRevision() {
return "@CUTLASS_REVISION@";
}

} // namespace cutlass
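With version.h in place, user code can branch on the CUTLASS release at compile time or report it at runtime. A small hedged usage sketch (the numeric packing follows the CUTLASS_VERSION definition above: major*100 + minor*10 + patch):

#include <cstdio>
#include "cutlass/version.h"

int main() {
#if CUTLASS_VERSION >= 341   // 3*100 + 4*10 + 1, i.e. CUTLASS 3.4.1 or newer
  std::printf("CUTLASS %u.%u.%u (packed: %u)\n",
              cutlass::getVersionMajor(),
              cutlass::getVersionMinor(),
              cutlass::getVersionPatch(),
              cutlass::getVersion());
#else
  std::printf("Older CUTLASS release\n");
#endif
  return 0;
}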
@@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"

[project]
name = "nvidia-cutlass"
version = "3.4.0.0"
version = "3.4.1.0"
description = "CUTLASS"
readme = "README.md"
requires-python = ">=3.8"
license = {file = "LICENSE.txt"}
license = {text = "BSD-3-Clause"}
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",

@@ -40,7 +40,7 @@ import cutlass_library
def _cuda_install_path_from_nvcc() -> str:
import subprocess
# Attempt to detect CUDA_INSTALL_PATH based on location of NVCC
result = subprocess.run(['which', 'nvcc'], capture_output=True)
result = subprocess.run(['/usr/bin/which', 'nvcc'], capture_output=True)
if result.returncode != 0:
raise Exception(f'Unable to find nvcc via `which` utility.')

@@ -121,7 +121,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry

this.__version__ = '3.4.0'
this.__version__ = '3.4.1'

from cutlass.backend import create_memory_pool
from cutlass.emit.pytorch import pytorch
@@ -169,7 +169,7 @@ def initialize_cuda_context():
raise Exception("No CUDA devices found")
device_id = 0

this._device_id = device_id
this._device_id = int(device_id)

def device_id() -> int:

@@ -213,8 +213,12 @@ def get_mainloop_arguments_3x(
return _MainloopArgumentsTma

def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, default_epilogue):
if not default_epilogue and hasattr(epilogue_functor, "epilogue_type_evt"):
_EpilogueOutputOpParams = epilogue_functor.epilogue_type_evt
else:
_EpilogueOutputOpParams = epilogue_functor.epilogue_type

if hasattr(epilogue_functor, "visitor"):
class _EpilogueArguments(ctypes.Structure):
_fields_ = [

@@ -157,19 +157,41 @@ class LinearCombination(EpilogueFunctorBase):
c_element_epilogue = dtype2ctype[self.element_epilogue]
element_epilogue = self.element_epilogue

class _EpilogueOutputOpParams(ctypes.Structure):
class _EpilogueOutputOpParamsEVT(ctypes.Structure):
"""
Epilogue params when using the default linear combination of EVT, which
does not currently use {alpha,beta}_ptr_array
"""
_fields_ = [
("alpha", c_element_epilogue),
("beta", c_element_epilogue),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p)
("beta_ptr", ctypes.c_void_p),
]

def __init__(self, alpha, beta, *args) -> None:
self.alpha = to_ctype_value(alpha, element_epilogue)
self.beta = to_ctype_value(beta, element_epilogue)

class _EpilogueOutputOpParams(ctypes.Structure):
_fields_ = [
("alpha", c_element_epilogue),
("beta", c_element_epilogue),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p),
("alpha_ptr_array", ctypes.c_void_p),
("beta_ptr_array", ctypes.c_void_p),
]

def __init__(self, alpha, beta, *args) -> None:
self.alpha = to_ctype_value(alpha, element_epilogue)
self.beta = to_ctype_value(beta, element_epilogue)

def to_evt_params(self) -> _EpilogueOutputOpParamsEVT:
return _EpilogueOutputOpParamsEVT(self.alpha, self.beta)

self.epilogue_type = _EpilogueOutputOpParams
self.epilogue_type_evt = _EpilogueOutputOpParamsEVT

def emit(self):
return super().emit(self.tag, self.template_arguments)

@ -241,10 +241,10 @@ class EVTFrontendBase:
|
||||
:param name: the name of the graph
|
||||
"""
|
||||
drawer = EVTGraphDrawer(self.dag_ir, name)
|
||||
if drawer.dot_available:
|
||||
try:
|
||||
for name, graph in drawer.get_dot_graph():
|
||||
graph.write_svg(f"./{name}.svg")
|
||||
else:
|
||||
except:
|
||||
raise RuntimeError(
|
||||
"'dot' is not found in path. GraphDrawer is disabled. "
|
||||
"Please install it with 'sudo apt-get install graphviz'."
|
||||
|
||||
@ -61,22 +61,6 @@ class EVTGraphDrawer:
self._dot_graphs = {}

self._dot_graphs[name] = self._to_dot(graph, name)
self.dot_available = self._check_dot_availability()

def _check_dot_availability(self):
"""
Check if graphviz is installed
"""
try:
# Run the 'dot' command and capture its output
result = subprocess.run(
["dot", "-V"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# Check if the command was successful and the output contains version information
if result.returncode == 0 and "dot - graphviz" in result.stderr:
return True
except FileNotFoundError:
pass
return False

def _get_node_style(self, node):
template = {
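The net effect of the two hunks above is a shift from probing 'dot -V' up front to simply attempting the render and reporting a helpful error on failure. A standalone sketch of that pattern using only the standard library; the real code calls the graph object's write_svg() rather than invoking dot directly.

import subprocess

def render_svg(dot_source: str, output_path: str) -> None:
    # Try to render; translate any failure into an actionable error message.
    try:
        subprocess.run(
            ["dot", "-Tsvg", "-o", output_path],
            input=dot_source, text=True, check=True)
    except (FileNotFoundError, subprocess.CalledProcessError) as err:
        raise RuntimeError(
            "'dot' is not found in path. GraphDrawer is disabled. "
            "Please install it with 'sudo apt-get install graphviz'.") from err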
@ -325,7 +325,7 @@ class GemmArguments2x(ArgumentBase):
def initialize(self):
launch_config = self.operation.rt_module.plan(self)

# Get the host and evice workspace
# Get the host and device workspace
device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)

if device_workspace_size > 0:
@ -512,6 +512,18 @@ class GemmArguments3x(GemmArguments2x):
super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)

def get_arguments(self):
mainloop_args = get_mainloop_arguments_3x(
self.operation.tile_description.kernel_schedule,
self.operation.A.element,
self.operation.B.element,
self.operation.A.alignment,
self.operation.B.alignment
)
scheduler_args = get_tile_scheduler_arguments_3x(self.operation.tile_description.tile_scheduler)
uses_default_epilogue = self.operation.rt_module.uses_default_epilogue()
argument_type, epilogue_args, epilogue_type, hw_info = get_gemm_arguments_3x(
mainloop_args, self.operation.epilogue_functor, scheduler_args, uses_default_epilogue)

problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count)

if self.batch_count > 1:
@ -539,9 +551,12 @@ class GemmArguments3x(GemmArguments2x):
)

# Set of mainloop arguments needed for this kernel
mainloop = self.operation.rt_module.mainloop_args.from_generic_mainloop_args(generic_args)
mainloop = mainloop_args.from_generic_mainloop_args(generic_args)

epilogue = self.operation.rt_module.epilogue_args(
if not uses_default_epilogue and hasattr(self.output_op, "to_evt_params"):
self.output_op = self.output_op.to_evt_params()

epilogue = epilogue_args(
self.output_op,
int(self.ptr_C),
stride_C,
@ -550,15 +565,15 @@ class GemmArguments3x(GemmArguments2x):
)

# Set hardware info
hw_info = self.operation.rt_module.hw_info(0, device_sm_count())
hw_info_ = hw_info(0, device_sm_count())

self.arguments = self.operation.argument_type(
self.arguments = argument_type(
int(self.gemm_mode),
problem_size_,
mainloop,
epilogue,
hw_info,
self.operation.rt_module.scheduler_args
hw_info_,
scheduler_args
)
return self.arguments
@ -1119,6 +1134,10 @@ extern "C" {

using GemmType = ${operation_name}_base;

bool ${operation_name}_uses_default_epilogue() {
return std::is_same_v<GemmType::CollectiveEpilogue::DispatchPolicy, cutlass::gemm::EpilogueDefault>;
}

// Get the workspace size
uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) {
return GemmType::get_workspace_size(*argument);
@ -1163,19 +1182,10 @@ extern "C" {
"get_grid_shape": dim3_,
"get_block_shape": dim3_,
"get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64,
"get_kernel_workspace_size": ctypes.c_uint64
"get_kernel_workspace_size": ctypes.c_uint64,
"uses_default_epilogue": ctypes.c_bool,
}
self.emitter = EmitGemmUniversalInstance3x("_type")
self.mainloop_args = get_mainloop_arguments_3x(
operation.tile_description.kernel_schedule,
operation.A.element,
operation.B.element,
operation.A.alignment,
operation.B.alignment
)
self.scheduler_args = get_tile_scheduler_arguments_3x(operation.tile_description.tile_scheduler)
self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x(
self.mainloop_args, operation.epilogue_functor, self.scheduler_args)

def get_device_workspace_size(self, arguments: GemmArguments3x):
return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()))
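The added "uses_default_epilogue": ctypes.c_bool entry records the return type of the newly exported C helper. A minimal sketch of how such a {name: restype} table can be bound against a loaded module is below; the lib handle and the "example_gemm" prefix are hypothetical and not the CUTLASS runtime loader.

import ctypes

signatures = {
    "get_kernel_workspace_size": ctypes.c_uint64,
    "uses_default_epilogue": ctypes.c_bool,
}

def bind(lib: ctypes.CDLL, operation_name: str) -> dict:
    bound = {}
    for suffix, restype in signatures.items():
        fn = getattr(lib, f"{operation_name}_{suffix}")
        fn.restype = restype  # tell ctypes how to interpret the C return value
        bound[suffix] = fn
    return bound

# Hypothetical usage once the compiled module exists:
#   lib = ctypes.CDLL("./example_gemm_module.so")
#   funcs = bind(lib, "example_gemm")
#   if funcs["uses_default_epilogue"]():
#       ...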
@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='cutlass_library',
version='3.4.0',
version='3.4.1',
description='CUTLASS library generation scripts',
packages=['cutlass_library']
)
@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='pycute',
version='3.4.0',
version='3.4.1',
description='Python implementation of CuTe',
packages=['pycute'],
)
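With both Python packages now reporting 3.4.1, scripts that need to gate on a minimum release should compare the version numerically rather than as a string. A small generic sketch, with no CUTLASS-specific API assumed:

def at_least(version: str, required: tuple) -> bool:
    # Compare a dotted "major.minor.patch" string numerically.
    return tuple(int(part) for part in version.split(".")[:3]) >= required

assert at_least("3.4.1", (3, 4, 1))
assert not at_least("3.4.0", (3, 4, 1))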
@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without

@ -56,9 +56,6 @@
#include "gemm_testbed_3x_evt.hpp"
#include "sm90_evt_operations.hpp"

#define CUTLASS_ARCH_MMA_SM90_SUPPORTED

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

using namespace cute;
@ -132,7 +129,7 @@ bool testEVTAuxStoreWithoutD() {
D_block.reset(m * n);
aux_store_D_block.reset(m * n);
Gemm gemm_op_base;

auto stride_A = cutlass::make_cute_packed_stride(
typename GemmKernel::StrideA{}, cute::make_shape(m, k, cute::Int<1>{}));
auto stride_B = cutlass::make_cute_packed_stride(

@ -141,7 +138,7 @@ bool testEVTAuxStoreWithoutD() {
typename GemmKernel::StrideC{}, cute::make_shape(m, n, cute::Int<1>{}));
auto stride_D = cutlass::make_cute_packed_stride(
typename GemmKernel::StrideD{}, cute::make_shape(m, n, cute::Int<1>{}));

auto arguments_base = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,

@ -178,12 +175,12 @@ bool testEVTAuxStoreWithoutD() {
/*hw_info=*/{},
/*scheduler_args=*/{}
};

constexpr float beta [[maybe_unused]] = 1.0;
constexpr float alpha [[maybe_unused]] = 1.0;

using ElementC = typename GemmWithoutD::ElementC;

if constexpr (not has_c) {
arguments_base.epilogue.thread = {
// binary op : alpha * acc
@ -282,7 +279,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
>;

using namespace cutlass::epilogue::fusion;

constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;

@ -324,10 +321,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25

return *(GemmKernel *)(nullptr);
};

using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;

@ -352,7 +349,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t
>;

using namespace cutlass::epilogue::fusion;

constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;

@ -394,10 +391,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25

return *(GemmKernel *)(nullptr);
};

using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;

@ -467,7 +464,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 12

using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;

@ -492,7 +489,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
>;

using namespace cutlass::epilogue::fusion;

constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;

@ -534,10 +531,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25

return *(GemmKernel *)(nullptr);
};

using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;

@ -562,7 +559,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t
>;

using namespace cutlass::epilogue::fusion;

constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;

@ -604,10 +601,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25

return *(GemmKernel *)(nullptr);
};

using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;

@ -677,7 +674,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 12

using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@ -35,7 +35,6 @@
#pragma once

#include "cute/layout.hpp"
#include "cute/stride.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////