Update gather_scatter_fusion.cu
Correct the reference code in gather/scatter example to put bias add in the correct place.
This commit is contained in:
parent
d6f58b2d14
commit
858c735856
@ -43,7 +43,7 @@
|
|||||||
// for (int k = 0; k < options.index_size; ++k) {
|
// for (int k = 0; k < options.index_size; ++k) {
|
||||||
// int a_col = tensor_indices.at({k, 0});
|
// int a_col = tensor_indices.at({k, 0});
|
||||||
// tensor_d_ref.at({i, b_c_d_col}) +=
|
// tensor_d_ref.at({i, b_c_d_col}) +=
|
||||||
// alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col}) + beta * tensor_c.at({i, b_c_d_col});
|
// alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
@ -229,8 +229,7 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
|||||||
// the vector width of math instructions in
|
// the vector width of math instructions in
|
||||||
// epilogue too
|
// epilogue too
|
||||||
ElementAccumulator, // <- data type of accumulator
|
ElementAccumulator, // <- data type of accumulator
|
||||||
ElementComputeEpilogue, // <- data type for alpha in linear combination function
|
ElementComputeEpilogue>; // <- data type for alpha in linear combination function
|
||||||
cutlass::epilogue::thread::ScaleType::Nothing>; // <- C
|
|
||||||
|
|
||||||
// Number of pipelines you want to use
|
// Number of pipelines you want to use
|
||||||
constexpr int NumStages = 5;
|
constexpr int NumStages = 5;
|
||||||
@ -301,8 +300,12 @@ int run(Options &options) {
|
|||||||
ElementInputA(-8),
|
ElementInputA(-8),
|
||||||
0); // <- Fill matrix B on host with uniform-distribution random data
|
0); // <- Fill matrix B on host with uniform-distribution random data
|
||||||
|
|
||||||
cutlass::reference::host::TensorFill(
|
cutlass::reference::host::TensorFillRandomUniform(
|
||||||
tensor_c.host_view()); // <- Fill matrix C on host with zeros
|
tensor_c.host_view(),
|
||||||
|
1,
|
||||||
|
ElementOutput(7),
|
||||||
|
ElementOutput(-8),
|
||||||
|
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||||
|
|
||||||
cutlass::reference::host::TensorFill(
|
cutlass::reference::host::TensorFill(
|
||||||
tensor_d_scattered.host_view()); // <- fill matrix D on host with zeros
|
tensor_d_scattered.host_view()); // <- fill matrix D on host with zeros
|
||||||
@ -387,8 +390,10 @@ int run(Options &options) {
|
|||||||
for (int k = 0; k < options.index_size; ++k) {
|
for (int k = 0; k < options.index_size; ++k) {
|
||||||
int a_col = tensor_indices.at({k, 0});
|
int a_col = tensor_indices.at({k, 0});
|
||||||
tensor_d_ref.at({i, b_c_d_col}) +=
|
tensor_d_ref.at({i, b_c_d_col}) +=
|
||||||
alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col}) + beta * tensor_c.at({i, b_c_d_col});
|
alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user