Update gather_scatter_fusion.cu

Correct the reference code in gather/scatter example to put bias add in the correct place.
This commit is contained in:
Haicheng Wu 2022-05-18 13:15:25 -04:00 committed by GitHub
parent d6f58b2d14
commit 858c735856
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -43,7 +43,7 @@
// for (int k = 0; k < options.index_size; ++k) { // for (int k = 0; k < options.index_size; ++k) {
// int a_col = tensor_indices.at({k, 0}); // int a_col = tensor_indices.at({k, 0});
// tensor_d_ref.at({i, b_c_d_col}) += // tensor_d_ref.at({i, b_c_d_col}) +=
// alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col}) + beta * tensor_c.at({i, b_c_d_col}); // alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
// } // }
// } // }
// //
@ -229,8 +229,7 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
// the vector width of math instructions in // the vector width of math instructions in
// epilogue too // epilogue too
ElementAccumulator, // <- data type of accumulator ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue, // <- data type for alpha in linear combination function ElementComputeEpilogue>; // <- data type for alpha in linear combination function
cutlass::epilogue::thread::ScaleType::Nothing>; // <- C
// Number of pipelines you want to use // Number of pipelines you want to use
constexpr int NumStages = 5; constexpr int NumStages = 5;
@ -301,8 +300,12 @@ int run(Options &options) {
ElementInputA(-8), ElementInputA(-8),
0); // <- Fill matrix B on host with uniform-distribution random data 0); // <- Fill matrix B on host with uniform-distribution random data
cutlass::reference::host::TensorFill( cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view()); // <- Fill matrix C on host with zeros tensor_c.host_view(),
1,
ElementOutput(7),
ElementOutput(-8),
0); // <- Fill matrix C on host with uniform-distribution random data
cutlass::reference::host::TensorFill( cutlass::reference::host::TensorFill(
tensor_d_scattered.host_view()); // <- fill matrix D on host with zeros tensor_d_scattered.host_view()); // <- fill matrix D on host with zeros
@ -387,8 +390,10 @@ int run(Options &options) {
for (int k = 0; k < options.index_size; ++k) { for (int k = 0; k < options.index_size; ++k) {
int a_col = tensor_indices.at({k, 0}); int a_col = tensor_indices.at({k, 0});
tensor_d_ref.at({i, b_c_d_col}) += tensor_d_ref.at({i, b_c_d_col}) +=
alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col}) + beta * tensor_c.at({i, b_c_d_col}); alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
} }
tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));
} }
} }