Fix a broken sparse GEMM example (issue found by the community).
This commit is contained in:
parent
4839b6cb61
commit
26556d7206
@ -111,7 +111,8 @@ using Gemm = cutlass::gemm::device::SparseGemm<ElementInputA,
|
|||||||
|
|
||||||
// Data type and layout of meta data matrix E can be inferred from template Gemm.
|
// Data type and layout of meta data matrix E can be inferred from template Gemm.
|
||||||
using ElementInputE = typename Gemm::ElementE;
|
using ElementInputE = typename Gemm::ElementE;
|
||||||
using LayoutInputE = typename Gemm::LayoutE;
|
using LayoutInputE = cutlass::layout::RowMajor;
|
||||||
|
using ReorderedLayoutInputE = typename Gemm::LayoutE;
|
||||||
|
|
||||||
// Below property is defined in include/cutlass/arch/sp_mma_sm80.h
|
// Below property is defined in include/cutlass/arch/sp_mma_sm80.h
|
||||||
// 50% Sparsity on Ampere
|
// 50% Sparsity on Ampere
|
||||||
@ -151,27 +152,27 @@ int run() {
|
|||||||
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
|
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e(
|
||||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||||
// Same size as the above. The above one needs to be reordered and stored in this one.
|
// Same size as the above. The above one needs to be reordered and stored in this one.
|
||||||
cutlass::HostTensor<ElementInputE, LayoutInputE> tensor_e_reordered(
|
cutlass::HostTensor<ElementInputE, ReorderedLayoutInputE> tensor_e_reordered(
|
||||||
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE));
|
||||||
|
|
||||||
// Fill input and output matrices on host using CUTLASS helper functions
|
// Fill input and output matrices on host using CUTLASS helper functions
|
||||||
cutlass::reference::host::TensorFillRandomUniform(
|
cutlass::reference::host::TensorFillRandomUniform(
|
||||||
tensor_a.host_view(),
|
tensor_a.host_view(),
|
||||||
1,
|
1,
|
||||||
ElementInputA(1),
|
ElementInputA(2),
|
||||||
ElementInputA(-1),
|
ElementInputA(-2),
|
||||||
0); // <- Fill matrix A on host with uniform-distribution random data
|
0); // <- Fill matrix A on host with uniform-distribution random data
|
||||||
cutlass::reference::host::TensorFillRandomUniform(
|
cutlass::reference::host::TensorFillRandomUniform(
|
||||||
tensor_b.host_view(),
|
tensor_b.host_view(),
|
||||||
1,
|
1,
|
||||||
ElementInputB(1),
|
ElementInputB(2),
|
||||||
ElementInputB(-1),
|
ElementInputB(-2),
|
||||||
0); // <- Fill matrix B on host with uniform-distribution random data
|
0); // <- Fill matrix B on host with uniform-distribution random data
|
||||||
cutlass::reference::host::TensorFillRandomUniform(
|
cutlass::reference::host::TensorFillRandomUniform(
|
||||||
tensor_c.host_view(),
|
tensor_c.host_view(),
|
||||||
1,
|
1,
|
||||||
ElementOutput(1),
|
ElementOutput(2),
|
||||||
ElementOutput(-1),
|
ElementOutput(-2),
|
||||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||||
cutlass::reference::host::TensorFillRandomSparseMeta(
|
cutlass::reference::host::TensorFillRandomSparseMeta(
|
||||||
tensor_e.host_view(),
|
tensor_e.host_view(),
|
||||||
@ -210,7 +211,7 @@ int run() {
|
|||||||
tensor_b.device_ref(), // <- reference to matrix B on device
|
tensor_b.device_ref(), // <- reference to matrix B on device
|
||||||
tensor_c.device_ref(), // <- reference to matrix C on device
|
tensor_c.device_ref(), // <- reference to matrix C on device
|
||||||
tensor_d.device_ref(), // <- reference to matrix D on device
|
tensor_d.device_ref(), // <- reference to matrix D on device
|
||||||
tensor_e.device_ref(), // <- reference to matrix E on device
|
tensor_e_reordered.device_ref(), // <- reference to matrix E on device
|
||||||
{alpha, beta}, // <- tuple of alpha and beta
|
{alpha, beta}, // <- tuple of alpha and beta
|
||||||
split_k_slices}; // <- k-dimension split factor
|
split_k_slices}; // <- k-dimension split factor
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user