fix gather example (#574)

This commit is contained in:
Shang Zhang 2022-07-19 13:18:17 -07:00 committed by GitHub
parent 0b8cacd6f1
commit 5d05808072
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -40,12 +40,12 @@
// for (int j = 0; j < options.index_size; ++j) {
// int b_c_d_col = tensor_indices.at({j, 0});
//
// for (int k = 0; k < options.index_size; ++k) {
// int a_col = tensor_indices.at({k, 0});
// for (int k = 0; k < problem_size.k(); ++k) {
// tensor_d_ref.at({i, b_c_d_col}) +=
// alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
// alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
// }
// }
// }
//
// Note that the index vector contains unique random integers with max to be N - 1
//
@ -257,7 +257,7 @@ using Gemm = cutlass::gemm::device::GemmUniversal<ElementInputA,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
true, /*GatherA*/
false, /*GatherA*/
true, /*GatherB*/
true /*ScatterD*/
>;
@ -273,13 +273,13 @@ int run(Options &options) {
// Create a tuple of problem size for matrix multiplication
cutlass::gemm::GemmCoord problem_size_real(problem_size.m(),
options.index_size,
options.index_size);
problem_size.k());
// Initialize tensors using CUTLASS helper functions
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
problem_size.mk()); // <- Create matrix A with dimensions M x K
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
cutlass::make_Coord(options.index_size, problem_size.n())); // <- Create matrix B with dimensions K x N
problem_size.kn()); // <- Create matrix B with dimensions K x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
problem_size.mn()); // <- Create matrix C with dimensions M x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d_scattered(
@ -353,7 +353,7 @@ int run(Options &options) {
tensor_b.layout().stride(),
tensor_c.layout().stride(),
tensor_d_scattered.layout().stride(),
tensor_indices.device_data(), // <- pointer to index vector to gather A on device
nullptr, // <- pointer to index vector to gather A on device
tensor_indices.device_data(), // <- pointer to index vector to gather B on device
tensor_indices.device_data()}; // <- pointer to index vector to scatter D on device
@ -388,10 +388,9 @@ int run(Options &options) {
for (int j = 0; j < options.index_size; ++j) {
int b_c_d_col = tensor_indices.at({j, 0});
for (int k = 0; k < options.index_size; ++k) {
int a_col = tensor_indices.at({k, 0});
for (int k = 0; k < problem_size.k(); ++k) {
tensor_d_ref.at({i, b_c_d_col}) +=
alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
}
tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));