fix a broken sparse gemm example. found by the community.
This commit is contained in:
parent
4839b6cb61
commit
26556d7206
@ -111,7 +111,8 @@ using Gemm = cutlass::gemm::device::SparseGemm<ElementInputA,
|
||||
|
||||
// Data type and layout of meta data matrix E can be inferred from template Gemm.
|
||||
using ElementInputE = typename Gemm::ElementE;
|
||||
using LayoutInputE = typename Gemm::LayoutE;
|
||||
using LayoutInputE = cutlass::layout::RowMajor;
|
||||
using ReorderedLayoutInputE = typename Gemm::LayoutE;
|
||||
|
||||
// Blow property is defined in include/cutlass/arch/sp_mma_sm80.h
|
||||
// 50% Sparsity on Ampere
|
||||
@ -151,27 +152,27 @@ int run() {
|
||||
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||
// Same size as the above. The above one needs to be reordered and stored in this one.
|
||||
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e_reordered(
|
||||
cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
|
||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(1),
|
||||
ElementInputA(-1),
|
||||
ElementInputA(2),
|
||||
ElementInputA(-2),
|
||||
0); // <- Fill matrix A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(1),
|
||||
ElementInputB(-1),
|
||||
ElementInputB(2),
|
||||
ElementInputB(-2),
|
||||
0); // <- Fill matrix B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c.host_view(),
|
||||
1,
|
||||
ElementOutput(1),
|
||||
ElementOutput(-1),
|
||||
ElementOutput(2),
|
||||
ElementOutput(-2),
|
||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomSparseMeta(
|
||||
tensor_e.host_view(),
|
||||
@ -210,7 +211,7 @@ int run() {
|
||||
tensor_b.device_ref(), // <- reference to matrix B on device
|
||||
tensor_c.device_ref(), // <- reference to matrix C on device
|
||||
tensor_d.device_ref(), // <- reference to matrix D on device
|
||||
tensor_e.device_ref(), // <- reference to matrix E on device
|
||||
tensor_e_reordered.device_ref(), // <- reference to matrix E on device
|
||||
{alpha, beta}, // <- tuple of alpha and beta
|
||||
split_k_slices}; // <- k-dimension split factor
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user