fix gather example (#574)

parent 0b8cacd6f1
commit 5d05808072
@@ -40,10 +40,10 @@
 //   for (int j = 0; j < options.index_size; ++j) {
 //     int b_c_d_col = tensor_indices.at({j, 0});
 //
-//     for (int k = 0; k < options.index_size; ++k) {
-//       int a_col = tensor_indices.at({k, 0});
+//     for (int k = 0; k < problem_size.k(); ++k) {
 //       tensor_d_ref.at({i, b_c_d_col}) +=
-//         alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
+//         alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
 //     }
 //   }
 // }
 //
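For reference, the corrected computation in the comment above can be written as a standalone host routine. This is a minimal sketch using plain row-major std::vector buffers and illustrative names, not the example's HostTensor-based reference: A is read over its full K extent, and only the columns of B, C and D named in the index vector participate.

#include <vector>

// Host reference for the gathered-B / scattered-D GEMM sketched above.
void reference_gather_gemm(int m, int n, int k, float alpha, float beta,
                           const std::vector<float> &a,        // m x k, dense
                           const std::vector<float> &b,        // k x n
                           const std::vector<float> &c,        // m x n
                           const std::vector<int>   &indices,  // index_size entries, each < n
                           std::vector<float>       &d) {      // m x n
  for (int i = 0; i < m; ++i) {
    for (int b_c_d_col : indices) {            // gathered B column == scattered D column
      float acc = 0.f;
      for (int kk = 0; kk < k; ++kk) {         // full K extent; A is not gathered
        acc += a[i * k + kk] * b[kk * n + b_c_d_col];
      }
      d[i * n + b_c_d_col] = alpha * acc + beta * c[i * n + b_c_d_col];
    }
  }
}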
@@ -257,7 +257,7 @@ using Gemm = cutlass::gemm::device::GemmUniversal<ElementInputA,
     cutlass::arch::OpMultiplyAdd,
     cutlass::ComplexTransform::kNone,
     cutlass::ComplexTransform::kNone,
-    true,  /*GatherA*/
+    false, /*GatherA*/
     true,  /*GatherB*/
     true   /*ScatterD*/
 >;
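Since the example only gathers columns of B and scatters columns of D, operand A is streamed contiguously, which is why GatherA flips to false. A hedged summary of what the three booleans request, written as named constants rather than CUTLASS API:

// Illustrative constants, not part of the example's source.
constexpr bool kGatherA  = false;  // A is dense: every A[i][k] is read
constexpr bool kGatherB  = true;   // only B columns listed in the index vector are loaded
constexpr bool kScatterD = true;   // each result column is written to the indexed column of D

Because the same index vector drives both GatherB and ScatterD, its entries index columns of B, C and D only; it plays no role on the A side after this fix.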
@@ -273,13 +273,13 @@ int run(Options &options) {
   // Create a tuple of problem size for matrix multiplication
   cutlass::gemm::GemmCoord problem_size_real(problem_size.m(),
                                              options.index_size,
-                                             options.index_size);
+                                             problem_size.k());
 
   // Initialize tensors using CUTLASS helper functions
   cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
       problem_size.mk());  // <- Create matrix A with dimensions M x K
   cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
-      cutlass::make_Coord(options.index_size, problem_size.n()));  // <- Create matrix B with dimensions K x N
+      problem_size.kn());  // <- Create matrix B with dimensions K x N
   cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
       problem_size.mn());  // <- Create matrix C with dimensions M x N
   cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d_scattered(
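The real problem size handed to the kernel keeps M and the full K reduction but replaces N with the number of gathered columns, and B must now be allocated at its full K x N extent. A hypothetical helper spelling out that shape (a plain struct standing in for cutlass::gemm::GemmCoord; names are assumptions):

#include <cassert>

struct ProblemShape { int m, n, k; };  // illustrative stand-in for cutlass::gemm::GemmCoord

// With GatherB/ScatterD enabled and GatherA disabled, the kernel computes only
// index_size output columns but still reduces over the full K dimension, i.e.
// (M, index_size, K), which is what problem_size_real encodes above.
ProblemShape effective_problem(int m, int n, int k, int index_size) {
  assert(index_size <= n);  // indices must select existing columns of B/C/D
  return ProblemShape{m, index_size, k};
}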
@@ -353,7 +353,7 @@ int run(Options &options) {
     tensor_b.layout().stride(),
     tensor_c.layout().stride(),
     tensor_d_scattered.layout().stride(),
-    tensor_indices.device_data(),  // <- pointer to index vector to gather A on device
+    nullptr,                       // <- pointer to index vector to gather A on device
     tensor_indices.device_data(),  // <- pointer to index vector to gather B on device
     tensor_indices.device_data()}; // <- pointer to index vector to scatter D on device
 
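The A-gather pointer becomes nullptr because GatherA is disabled, while the same device index buffer is reused for gathering B columns and scattering D columns, so every entry must be a valid column index in [0, N). A hedged host sketch that produces index_size distinct column indices (the example's actual initialization may differ):

#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

// Build index_size distinct column indices in [0, n) for gather-B / scatter-D.
std::vector<int> make_column_indices(int n, int index_size, unsigned seed = 2023) {
  std::vector<int> cols(n);
  std::iota(cols.begin(), cols.end(), 0);       // 0, 1, ..., n-1
  std::mt19937 rng(seed);
  std::shuffle(cols.begin(), cols.end(), rng);  // random permutation of the columns
  cols.resize(index_size);                      // keep index_size distinct columns
  return cols;                                  // copy to device before launching the kernel
}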
@@ -388,10 +388,9 @@ int run(Options &options) {
     for (int j = 0; j < options.index_size; ++j) {
       int b_c_d_col = tensor_indices.at({j, 0});
 
-      for (int k = 0; k < options.index_size; ++k) {
-        int a_col = tensor_indices.at({k, 0});
+      for (int k = 0; k < problem_size.k(); ++k) {
         tensor_d_ref.at({i, b_c_d_col}) +=
-          alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
+          alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
       }
 
       tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));
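Because only the gathered/scattered columns are produced by the kernel, any check of the device result against this host reference should be restricted to those columns. An illustrative host-side comparison with plain buffers (names and tolerance are assumptions, not the example's actual verification code):

#include <cmath>
#include <vector>

// Compare only the columns named in `indices`; other columns of D are untouched.
bool scattered_columns_match(int m, int n,
                             const std::vector<int> &indices,
                             const std::vector<float> &d_kernel,  // m x n, row-major
                             const std::vector<float> &d_ref,     // m x n, row-major
                             float tol = 1e-3f) {
  for (int i = 0; i < m; ++i) {
    for (int col : indices) {
      if (std::fabs(d_kernel[i * n + col] - d_ref[i * n + col]) > tol) {
        return false;
      }
    }
  }
  return true;
}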