diff --git a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu index f22e235f..f8fbcc33 100644 --- a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu +++ b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu @@ -40,12 +40,12 @@ // for (int j = 0; j < options.index_size; ++j) { // int b_c_d_col = tensor_indices.at({j, 0}); // -// for (int k = 0; k < options.index_size; ++k) { -// int a_col = tensor_indices.at({k, 0}); +// for (int k = 0; k < problem_size.k(); ++k) { // tensor_d_ref.at({i, b_c_d_col}) += -// alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col}); +// alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col}); // } // } +// } // // Note that the index vector contains unique random integers with max to be N - 1 // @@ -257,7 +257,7 @@ using Gemm = cutlass::gemm::device::GemmUniversal; @@ -273,13 +273,13 @@ int run(Options &options) { // Create a tuple of problem size for matrix multiplication cutlass::gemm::GemmCoord problem_size_real(problem_size.m(), options.index_size, - options.index_size); + problem_size.k()); // Initialize tensors using CUTLASS helper functions cutlass::HostTensor tensor_a( problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b( - cutlass::make_Coord(options.index_size, problem_size.n())); // <- Create matrix B with dimensions K x N + problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c( problem_size.mn()); // <- Create matrix C with dimensions M x N cutlass::HostTensor tensor_d_scattered( @@ -353,7 +353,7 @@ int run(Options &options) { tensor_b.layout().stride(), tensor_c.layout().stride(), tensor_d_scattered.layout().stride(), - tensor_indices.device_data(), // <- pointer to index vector to gather A on device + nullptr, // <- pointer to index vector to gather A on device tensor_indices.device_data(), // <- pointer to index vector to gather B on device tensor_indices.device_data()}; // <- pointer to index vector to scatter D on device @@ -388,10 +388,9 @@ int run(Options &options) { for (int j = 0; j < options.index_size; ++j) { int b_c_d_col = tensor_indices.at({j, 0}); - for (int k = 0; k < options.index_size; ++k) { - int a_col = tensor_indices.at({k, 0}); + for (int k = 0; k < problem_size.k(); ++k) { tensor_d_ref.at({i, b_c_d_col}) += - alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col}); + alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col}); } tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));