fix a broken sparse gemm example. found by the community.

commit 26556d7206
parent 4839b6cb61
Author: Manikandan Ananth
Date:   2021-04-07 13:32:55 -07:00


@@ -111,7 +111,8 @@ using Gemm = cutlass::gemm::device::SparseGemm<ElementInputA,
 // Data type and layout of meta data matrix E can be inferred from template Gemm.
 using ElementInputE = typename Gemm::ElementE;
-using LayoutInputE = typename Gemm::LayoutE;
+using LayoutInputE = cutlass::layout::RowMajor;
+using ReorderedLayoutInputE = typename Gemm::LayoutE;
 // Below property is defined in include/cutlass/arch/sp_mma_sm80.h
 // 50% Sparsity on Ampere
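
The fix separates two roles the old code conflated: LayoutInputE is now the plain row-major layout in which the metadata matrix E is generated on the host, while ReorderedLayoutInputE (the Gemm template's native LayoutE) is the hardware-specific interleaved layout the sparse tensor cores actually read. The comment above refers to sparsity constants exposed by the Gemm template; a minimal sketch of those, assuming the definitions that follow this hunk in the full example file (not shown in the diff):

    // Sketch (assumed from the surrounding example, not part of this diff):
    // 2:4 structured sparsity constants exposed by cutlass::gemm::device::SparseGemm.
    constexpr int kSparse = Gemm::kSparse;  // A physically stores k / kSparse columns (50% sparse)
    // How many elements of A a single ElementE metadata item covers.
    constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;

These two constants explain the `problem_size.k() / kSparse / kElementsPerElementE` extent used for tensor_e in the next hunk.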
@@ -151,27 +152,27 @@ int run() {
   cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
       cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
   // Same size as the above. The above one needs to be reordered and stored in this one.
-  cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e_reordered(
+  cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
       cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
   // Fill input and output matrices on host using CUTLASS helper functions
   cutlass::reference::host::TensorFillRandomUniform(
       tensor_a.host_view(),
       1,
-      ElementInputA(1),
-      ElementInputA(-1),
+      ElementInputA(2),
+      ElementInputA(-2),
       0);  // <- Fill matrix A on host with uniform-distribution random data
   cutlass::reference::host::TensorFillRandomUniform(
       tensor_b.host_view(),
       1,
-      ElementInputB(1),
-      ElementInputB(-1),
+      ElementInputB(2),
+      ElementInputB(-2),
       0);  // <- Fill matrix B on host with uniform-distribution random data
   cutlass::reference::host::TensorFillRandomUniform(
       tensor_c.host_view(),
       1,
-      ElementOutput(1),
-      ElementOutput(-1),
+      ElementOutput(2),
+      ElementOutput(-2),
       0);  // <- Fill matrix C on host with uniform-distribution random data
   cutlass::reference::host::TensorFillRandomSparseMeta(
       tensor_e.host_view(),
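
Allocating tensor_e_reordered only pays off together with the reorder step that happens later in the example, outside this hunk. A minimal sketch of that step, assuming cutlass::reorder_meta from the utility header include/cutlass/util/host_reorder.h and the names used above:

    // Sketch: rearrange the row-major metadata in tensor_e into the
    // interleaved layout the kernel expects, writing into tensor_e_reordered.
    cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(),
                          {problem_size.m(), problem_size.n(),
                           problem_size.k() / kSparse / kElementsPerElementE});

    // Only the reordered copy needs to reach the device; the kernel never
    // touches the row-major original.
    tensor_e_reordered.sync_device();

This is why the final hunk below swaps tensor_e for tensor_e_reordered in the kernel arguments.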
@@ -210,7 +211,7 @@ int run() {
      tensor_b.device_ref(),            // <- reference to matrix B on device
      tensor_c.device_ref(),            // <- reference to matrix C on device
      tensor_d.device_ref(),            // <- reference to matrix D on device
-     tensor_e.device_ref(),            // <- reference to matrix E on device
+     tensor_e_reordered.device_ref(),  // <- reference to matrix E on device
      {alpha, beta},                    // <- tuple of alpha and beta
      split_k_slices};                  // <- k-dimension split factor
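
With the reordered metadata wired into the arguments, the rest of the example follows the standard CUTLASS 2.x device-API flow. A sketch of that flow, assuming the usual Gemm and arguments names from above:

    // Sketch: instantiate the sparse GEMM and launch it with the arguments above.
    Gemm gemm_op;

    // Query and allocate whatever scratch the kernel needs (e.g. for split-K).
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    // Reject unsupported problems early, then initialize and run.
    cutlass::Status status = gemm_op.can_implement(arguments);
    if (status != cutlass::Status::kSuccess) return -1;

    status = gemm_op.initialize(arguments, workspace.get());
    if (status != cutlass::Status::kSuccess) return -1;

    status = gemm_op();  // launches the sparse GEMM kernel
    if (status != cutlass::Status::kSuccess) return -1;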