diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 523e970c..5cc0fbbd 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -120,9 +120,8 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, # cutlass impl timers.append( - bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"), - torch.bfloat16, label, sub_label, cutlass_impl, - "cutlass_i8_i8_bf16_scaled_mm")) + bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, + cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm")) return timers