Update gather_scatter_fusion.cu
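Correct the host reference code in the gather/scatter example so that the bias term (beta * C) is added once per output element, after the k-loop accumulation, instead of being added on every k iteration.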
This commit is contained in:
parent
d6f58b2d14
commit
858c735856
@ -43,7 +43,7 @@
|
||||
// for (int k = 0; k < options.index_size; ++k) {
|
||||
// int a_col = tensor_indices.at({k, 0});
|
||||
// tensor_d_ref.at({i, b_c_d_col}) +=
|
||||
// alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col}) + beta * tensor_c.at({i, b_c_d_col});
|
||||
// alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
|
||||
// }
|
||||
// }
|
||||
//
|
||||
@ -229,8 +229,7 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
// the vector width of math instructions in
|
||||
// epilogue too
|
||||
ElementAccumulator, // <- data type of accumulator
|
||||
ElementComputeEpilogue, // <- data type for alpha in linear combination function
|
||||
cutlass::epilogue::thread::ScaleType::Nothing>; // <- C
|
||||
ElementComputeEpilogue>; // <- data type for alpha in linear combination function
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 5;
|
||||
@ -301,8 +300,12 @@ int run(Options &options) {
|
||||
ElementInputA(-8),
|
||||
0); // <- Fill matrix B on host with uniform-distribution random data
|
||||
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_c.host_view()); // <- Fill matrix C on host with zeros
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c.host_view(),
|
||||
1,
|
||||
ElementOutput(7),
|
||||
ElementOutput(-8),
|
||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d_scattered.host_view()); // <- fill matrix D on host with zeros
|
||||
@ -387,8 +390,10 @@ int run(Options &options) {
|
||||
for (int k = 0; k < options.index_size; ++k) {
|
||||
int a_col = tensor_indices.at({k, 0});
|
||||
tensor_d_ref.at({i, b_c_d_col}) +=
|
||||
alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col}) + beta * tensor_c.at({i, b_c_d_col});
|
||||
alpha * tensor_a.at({i, a_col}) * tensor_b.at({k, b_c_d_col});
|
||||
}
|
||||
|
||||
tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col}));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user