Update gather_scatter_fusion.cu

Correct the reference code in gather/scatter example to put bias add in the correct place.
This commit is contained in:
Haicheng Wu 2022-05-18 13:15:25 -04:00 committed by GitHub
parent d6f58b2d14
commit 858c735856
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -43,7 +43,7 @@
// for (int k = 0; k < options.index_size; ++k) {
// int a_col = tensor_indices.at({k, 0});
// tensor_d_ref.at({i, b_c_d_col}) +=
// alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col}) + beta * tensor_c.at({i, b_c_d_col});
// alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
// }
// }
//
@ -229,8 +229,7 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
// the vector width of math instructions in
// epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue, // <- data type for alpha in linear combination function
cutlass::epilogue::thread::ScaleType::Nothing>; // <- C
ElementComputeEpilogue>; // <- data type for alpha in linear combination function
// Number of pipelines you want to use
constexpr int NumStages = 5;
@ -301,8 +300,12 @@ int run(Options &options) {
ElementInputA(-8),
0); // <- Fill matrix B on host with uniform-distribution random data
cutlass::reference::host::TensorFill(
tensor_c.host_view()); // <- Fill matrix C on host with zeros
cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
ElementOutput(7),
ElementOutput(-8),
0); // <- Fill matrix C on host with uniform-distribution random data
cutlass::reference::host::TensorFill(
tensor_d_scattered.host_view()); // <- fill matrix D on host with zeros
@ -387,8 +390,10 @@ int run(Options &options) {
for (int k = 0; k < options.index_size; ++k) {
int a_col = tensor_indices.at({k, 0});
tensor_d_ref.at({i, b_c_d_col}) +=
alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col}) + beta * tensor_c.at({i, b_c_d_col});
alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
}
tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));
}
}