Add extended wgmma shapes for all data types (#1666)

This commit is contained in:
Sergey Klevtsov 2024-07-31 15:33:14 -07:00 committed by GitHub
parent 1f2b590da6
commit 36cbfcf483
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 37395 additions and 267 deletions

View File

@ -310,6 +310,13 @@ list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_DEBUG_TRACE_LEVEL=${CUTLASS_DEBUG_
set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
"Enable PTX mma instruction for collective matrix multiply operations.") "Enable PTX mma instruction for collective matrix multiply operations.")
set(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES OFF CACHE BOOL
"Enable an extended set of SM90 WGMMA instruction shapes (may lead to increased compilation times)")
if(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES)
message(STATUS "Enabled extended SM90 WGMMA instruction shapes")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
endif()
# #
# NOTE: running with asan and CUDA requires the following environment variable: # NOTE: running with asan and CUDA requires the following environment variable:
# #

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -147,10 +147,19 @@ using _12 = Int<12>;
using _16 = Int<16>; using _16 = Int<16>;
using _24 = Int<24>; using _24 = Int<24>;
using _32 = Int<32>; using _32 = Int<32>;
using _48 = Int<48>;
using _64 = Int<64>; using _64 = Int<64>;
using _80 = Int<80>;
using _96 = Int<96>; using _96 = Int<96>;
using _112 = Int<112>;
using _128 = Int<128>; using _128 = Int<128>;
using _144 = Int<144>;
using _160 = Int<160>;
using _176 = Int<176>;
using _192 = Int<192>; using _192 = Int<192>;
using _208 = Int<208>;
using _224 = Int<224>;
using _240 = Int<240>;
using _256 = Int<256>; using _256 = Int<256>;
using _384 = Int<384>; using _384 = Int<384>;
using _512 = Int<512>; using _512 = Int<512>;