Merge pull request #1713 from NVIDIA/351_sparse_update

update 3.5.1 readme/changelog
This commit is contained in:
Dustyn Blasig 2024-08-15 11:44:49 -05:00 committed by GitHub
commit 865be73a97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 34 additions and 498 deletions

View File

@ -7,10 +7,15 @@
- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and
[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu).
- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) leveraging 2:4 structured sparsity and [support for LLM friendly tile sizes](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu).
- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inference:
+ [FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu#L269-L393) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu#L269-L411).
+ [int8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
+ [int4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452).
+ [FP32 TN](./test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu#L427-L642) and [NT](./test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu#L427-L456).
- [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs.
- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py).
- Support for residual add (beta != 0) in convolution kernels.
- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt).
- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md).
- Better support for MSVC as a host compiler.

View File

@ -51,10 +51,11 @@ CUTLASS 3.5.1 is an update to CUTLASS adding:
- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and
[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu).
- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) leveraging 2:4 structured sparsity and [support for LLM friendly tile sizes](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu).
- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inference.
- [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs.
- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py).
- Support for residual add (beta != 0) in convolution kernels.
- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output.
- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt).
- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md).
- Better support for MSVC as a host compiler.

View File

@ -182,7 +182,7 @@ TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f32, 32x256x64_32x64x64) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 5>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
@ -220,7 +220,7 @@ TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f32, 256x32x64_64x32x64) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 5>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
@ -257,7 +257,7 @@ TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f32, 16x256x64_16x64x64) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 5>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}
@ -295,7 +295,7 @@ TEST(SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f32, 256x16x64_64x16x64) {
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 5>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
}

View File

@ -320,24 +320,6 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x32x64_64x32x64) {
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x32x128_64x32x128) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 32, 128>,
cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 32x128x128_32x32x128) {
using ElementOutput = float;
using ElementAccumulator = float;
@ -351,45 +333,7 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 32x128x128_32x32x128)
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 900)
TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 32x256x64_32x64x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 256, 64>,
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
#endif
TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 32x256x128_32x64x128) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t,
cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 256, 128>,
cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
@ -461,10 +405,11 @@ TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x16x128_64x16x128)
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED

View File

@ -279,45 +279,7 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 32x128x128_32x32x128)
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 900)
TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 32x256x64_32x64x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 256, 64>,
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
#endif
TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 32x256x128_32x64x128) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 256, 128>,
cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
@ -376,24 +338,6 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x32x64_64x32x64) {
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x32x128_64x32x128) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 32, 128>,
cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x16x64_32x16x64) {
using ElementOutput = float;
using ElementAccumulator = float;
@ -448,24 +392,6 @@ TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x16x64_64x16x64) {
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x16x128_64x16x128) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 16, 128>,
cutlass::gemm::GemmShape<64, 16, 128>, cutlass::gemm::GemmShape<16, 8, 32>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED

View File

@ -449,75 +449,12 @@ TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 32x128x64_32x32x64) {
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
6
3
>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 900)
TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 32x256x32_32x64x32) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 256, 32>,
cutlass::gemm::GemmShape<32, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
6
>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
#endif
TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 32x256x64_32x64x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 256, 64>,
cutlass::gemm::GemmShape<32, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
6
>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED)

View File

@ -449,71 +449,7 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 32x128x64_32x32x64) {
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
6
>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 900)
TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 32x256x32_32x64x32) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 256, 32>,
cutlass::gemm::GemmShape<32, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
6
>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
#endif
TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 32x256x64_32x64x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 256, 64>,
cutlass::gemm::GemmShape<32, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
6
3
>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
@ -612,37 +548,6 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x32x32_64x32x32) {
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x32x64_64x32x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 32, 64>,
cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
6
>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x16x32_32x16x32) {
using ElementOutput = float;
@ -736,37 +641,6 @@ TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x16x32_64x16x32) {
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x16x64_64x16x64) {
using ElementOutput = float;
using ElementAccumulator = float;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
float,
cutlass::layout::RowMajor,
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 16, 64>,
cutlass::gemm::GemmShape<64, 16, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
6
>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
/////////////////////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED)

View File

@ -275,7 +275,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 32x128x512_32x32x512) {
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
@ -313,26 +313,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 32x256x256_32x64x256) {
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 32x256x512_32x64x512) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 256, 512>,
cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 128>,
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
@ -351,26 +332,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 16x128x512_16x32x512) {
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 16x256x512_16x64x512) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<16, 256, 512>,
cutlass::gemm::GemmShape<16, 64, 512>, cutlass::gemm::GemmShape<16, 8, 128>,
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
@ -408,7 +370,7 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x32x512_32x32x512) {
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
@ -432,25 +394,6 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x32x256_64x32x256) {
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x32x512_64x32x512) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 32, 512>,
cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 128>,
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x16x256_32x16x256) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
@ -508,25 +451,6 @@ TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x16x256_16x64x256) {
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x16x512_16x64x512) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 16, 512>,
cutlass::gemm::GemmShape<64, 16, 512>, cutlass::gemm::GemmShape<16, 8, 128>,
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED)

View File

@ -294,7 +294,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 32x128x256_32x32x256) {
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
@ -313,26 +313,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 32x256x128_32x64x128) {
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 32x256x256_32x64x256) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
int8_t, cutlass::layout::RowMajor, int8_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 256, 256>,
cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 64>,
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
@ -351,26 +332,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 16x128x256_16x32x256) {
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 16x128x256_32x32x256) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
int8_t, cutlass::layout::RowMajor, int8_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<16, 128, 256>,
cutlass::gemm::GemmShape<16, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>,
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
@ -408,7 +370,7 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x32x256_32x32x256) {
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
@ -432,25 +394,6 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x32x128_64x32x128) {
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x32x256_64x32x256) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
int8_t, cutlass::layout::RowMajor, int8_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 32, 256>,
cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>,
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x16x128_32x16x128) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
@ -508,25 +451,6 @@ TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x16x128_64x16x128) {
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x16x256_64x16x256) {
using ElementOutput = int32_t;
using ElementAccumulator = int32_t;
using ElementCompute = int32_t;
using Gemm = cutlass::gemm::device::GemmSparseUniversal<
int8_t, cutlass::layout::RowMajor, int8_t,
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor,
ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 16, 256>,
cutlass::gemm::GemmShape<64, 16, 256>, cutlass::gemm::GemmShape<16, 8, 64>,
cutlass::epilogue::thread::LinearCombinationClamp<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementCompute>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>;
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED)

View File

@ -141,10 +141,10 @@ struct MultistageTestbed {
ElementCompute alpha = ElementCompute(1),
ElementCompute beta = ElementCompute(0)) {
// Waives test if CUDA device is insufficient
if (!sufficient()) {
return true;
}
// Waives test if CUDA device is insufficient
if (!sufficient()) {
return true;
}
//
// Allocate the GEMM workspace

View File

@ -78,7 +78,7 @@ TEST(SM80_Device_Syrk_f32n_f32t_l_tensor_op_fast_f32, 128x256x32_64x64x32) {
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4
3
>;
EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal<RankK>());

View File

@ -78,7 +78,7 @@ TEST(SM80_Device_Syrk_f32t_f32t_l_tensor_op_fast_f32, 128x256x32_64x64x32) {
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4
3
>;
EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal<RankK>());

View File

@ -78,7 +78,7 @@ TEST(SM80_Device_Syrk_tf32n_f32t_l_tensor_op_f32, 128x256x32_64x64x32) {
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4
3
>;
EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal<RankK>());

View File

@ -78,7 +78,7 @@ TEST(SM80_Device_Syrk_tf32t_f32t_l_tensor_op_f32, 128x256x32_64x64x32) {
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4
3
>;
EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal<RankK>());