fix gather example (#574)
This commit is contained in:
parent
0b8cacd6f1
commit
5d05808072
@ -40,12 +40,12 @@
|
|||||||
// for (int j = 0; j < options.index_size; ++j) {
|
// for (int j = 0; j < options.index_size; ++j) {
|
||||||
// int b_c_d_col = tensor_indices.at({j, 0});
|
// int b_c_d_col = tensor_indices.at({j, 0});
|
||||||
//
|
//
|
||||||
// for (int k = 0; k < options.index_size; ++k) {
|
// for (int k = 0; k < problem_size.k(); ++k) {
|
||||||
// int a_col = tensor_indices.at({k, 0});
|
|
||||||
// tensor_d_ref.at({i, b_c_d_col}) +=
|
// tensor_d_ref.at({i, b_c_d_col}) +=
|
||||||
// alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
|
// alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
// }
|
||||||
//
|
//
|
||||||
// Note that the index vector contains unique random integers with max to be N - 1
|
// Note that the index vector contains unique random integers with max to be N - 1
|
||||||
//
|
//
|
||||||
@ -257,7 +257,7 @@ using Gemm = cutlass::gemm::device::GemmUniversal<ElementInputA,
|
|||||||
cutlass::arch::OpMultiplyAdd,
|
cutlass::arch::OpMultiplyAdd,
|
||||||
cutlass::ComplexTransform::kNone,
|
cutlass::ComplexTransform::kNone,
|
||||||
cutlass::ComplexTransform::kNone,
|
cutlass::ComplexTransform::kNone,
|
||||||
true, /*GatherA*/
|
false, /*GatherA*/
|
||||||
true, /*GatherB*/
|
true, /*GatherB*/
|
||||||
true /*ScatterD*/
|
true /*ScatterD*/
|
||||||
>;
|
>;
|
||||||
@ -273,13 +273,13 @@ int run(Options &options) {
|
|||||||
// Create a tuple of problem size for matrix multiplication
|
// Create a tuple of problem size for matrix multiplication
|
||||||
cutlass::gemm::GemmCoord problem_size_real(problem_size.m(),
|
cutlass::gemm::GemmCoord problem_size_real(problem_size.m(),
|
||||||
options.index_size,
|
options.index_size,
|
||||||
options.index_size);
|
problem_size.k());
|
||||||
|
|
||||||
// Initialize tensors using CUTLASS helper functions
|
// Initialize tensors using CUTLASS helper functions
|
||||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||||
problem_size.mk()); // <- Create matrix A with dimensions M x K
|
problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||||
cutlass::make_Coord(options.index_size, problem_size.n())); // <- Create matrix B with dimensions K x N
|
problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
|
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
|
||||||
problem_size.mn()); // <- Create matrix C with dimensions M x N
|
problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d_scattered(
|
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d_scattered(
|
||||||
@ -353,7 +353,7 @@ int run(Options &options) {
|
|||||||
tensor_b.layout().stride(),
|
tensor_b.layout().stride(),
|
||||||
tensor_c.layout().stride(),
|
tensor_c.layout().stride(),
|
||||||
tensor_d_scattered.layout().stride(),
|
tensor_d_scattered.layout().stride(),
|
||||||
tensor_indices.device_data(), // <- pointer to index vector to gather A on device
|
nullptr, // <- pointer to index vector to gather A on device
|
||||||
tensor_indices.device_data(), // <- pointer to index vector to gather B on device
|
tensor_indices.device_data(), // <- pointer to index vector to gather B on device
|
||||||
tensor_indices.device_data()}; // <- pointer to index vector to scatter D on device
|
tensor_indices.device_data()}; // <- pointer to index vector to scatter D on device
|
||||||
|
|
||||||
@ -388,10 +388,9 @@ int run(Options &options) {
|
|||||||
for (int j = 0; j < options.index_size; ++j) {
|
for (int j = 0; j < options.index_size; ++j) {
|
||||||
int b_c_d_col = tensor_indices.at({j, 0});
|
int b_c_d_col = tensor_indices.at({j, 0});
|
||||||
|
|
||||||
for (int k = 0; k < options.index_size; ++k) {
|
for (int k = 0; k < problem_size.k(); ++k) {
|
||||||
int a_col = tensor_indices.at({k, 0});
|
|
||||||
tensor_d_ref.at({i, b_c_d_col}) +=
|
tensor_d_ref.at({i, b_c_d_col}) +=
|
||||||
alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
|
alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col});
|
||||||
}
|
}
|
||||||
|
|
||||||
tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));
|
tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));
|
||||||
|
Loading…
Reference in New Issue
Block a user