Fix flops calculation and tensor b stride calculation in the example 36 (#1278)
* Fix flops calculation and tensor b stride calculation in the example 36
* Fix datatype
* Update gather_scatter_fusion.cu
This commit is contained in:
parent
74d1f3e63a
commit
acba5beee5
@ -174,7 +174,7 @@ struct Options {
|
||||
double gflops(double runtime_s) const {
|
||||
|
||||
// Number of real-valued multiply-adds
|
||||
int64_t fmas = problem_size.product();
|
||||
int64_t fmas = problem_size.m() * int64_t(index_size) * problem_size.k();
|
||||
|
||||
// Two flops per multiply-add
|
||||
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
||||
@ -349,7 +349,7 @@ int run(Options &options) {
|
||||
tensor_c.device_data(), // <- reference to matrix C on device
|
||||
tensor_d_scattered.device_data(), // <- reference to matrix D on device
|
||||
tensor_a.layout().capacity(problem_size.mk()),
|
||||
tensor_b.layout().capacity(cutlass::make_Coord(options.index_size, problem_size.n())),
|
||||
tensor_b.layout().capacity(cutlass::make_Coord(options.index_size, problem_size.k())),
|
||||
tensor_c.layout().capacity(problem_size.mn()),
|
||||
tensor_d_scattered.layout().capacity(problem_size.mn()),
|
||||
tensor_a.layout().stride(),
|
||||
|
Loading…
Reference in New Issue
Block a user