Add extended wgmma shapes for all data types (#1666)
This commit is contained in:
parent
1f2b590da6
commit
36cbfcf483
@ -310,6 +310,13 @@ list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_DEBUG_TRACE_LEVEL=${CUTLASS_DEBUG_
|
|||||||
set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
|
set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
|
||||||
"Enable PTX mma instruction for collective matrix multiply operations.")
|
"Enable PTX mma instruction for collective matrix multiply operations.")
|
||||||
|
|
||||||
|
set(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES OFF CACHE BOOL
|
||||||
|
"Enable an extended set of SM90 WGMMA instruction shapes (may lead to increased compilation times)")
|
||||||
|
if(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES)
|
||||||
|
message(STATUS "Enabled extended SM90 WGMMA instruction shapes")
|
||||||
|
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||||
|
endif()
|
||||||
|
|
||||||
#
|
#
|
||||||
# NOTE: running with asan and CUDA requires the following environment variable:
|
# NOTE: running with asan and CUDA requires the following environment variable:
|
||||||
#
|
#
|
||||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -147,10 +147,19 @@ using _12 = Int<12>;
|
|||||||
using _16 = Int<16>;
|
using _16 = Int<16>;
|
||||||
using _24 = Int<24>;
|
using _24 = Int<24>;
|
||||||
using _32 = Int<32>;
|
using _32 = Int<32>;
|
||||||
|
using _48 = Int<48>;
|
||||||
using _64 = Int<64>;
|
using _64 = Int<64>;
|
||||||
|
using _80 = Int<80>;
|
||||||
using _96 = Int<96>;
|
using _96 = Int<96>;
|
||||||
|
using _112 = Int<112>;
|
||||||
using _128 = Int<128>;
|
using _128 = Int<128>;
|
||||||
|
using _144 = Int<144>;
|
||||||
|
using _160 = Int<160>;
|
||||||
|
using _176 = Int<176>;
|
||||||
using _192 = Int<192>;
|
using _192 = Int<192>;
|
||||||
|
using _208 = Int<208>;
|
||||||
|
using _224 = Int<224>;
|
||||||
|
using _240 = Int<240>;
|
||||||
using _256 = Int<256>;
|
using _256 = Int<256>;
|
||||||
using _384 = Int<384>;
|
using _384 = Int<384>;
|
||||||
using _512 = Int<512>;
|
using _512 = Int<512>;
|
||||||
|
Loading…
Reference in New Issue
Block a user