diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
index e62fe731..3a6b8a22 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
@@ -1,6 +1,8 @@
 #include <stddef.h>
 #include <torch/extension.h>
 
+#include <ATen/cuda/CUDAContext.h>
+
 // clang-format will break include orders
 // clang-format off
 #include "cute/tensor.hpp"
@@ -189,8 +191,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
   size_t workspace_size = gemm_op.get_workspace_size(args);
   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
 
+  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
+
   CUTLASS_CHECK(gemm_op.can_implement(args));
 
-  cutlass::Status status = gemm_op(args, workspace.get());
+  cutlass::Status status = gemm_op(args, workspace.get(), stream);
   CUTLASS_CHECK(status);
 }
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
index 12efcac7..5fd6d8ff 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
@@ -1,5 +1,7 @@
 #include <torch/extension.h>
 
+#include <ATen/cuda/CUDAContext.h>
+
 #include <iostream>
 #include <sstream>
 #include <vector>
@@ -178,7 +180,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
   size_t workspace_size = gemm_op.get_workspace_size(args);
   TORCH_CHECK(workspace_size == 0);
 
-  cutlass::Status status = gemm_op.run(args);
+  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
+  cutlass::Status status = gemm_op.run(args, stream);
   CUTLASS_CHECK(status);
 }
 }  // namespace
diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py
index fdfd1dee..2cf0e86e 100644
--- a/tests/kernels/test_cutlass.py
+++ b/tests/kernels/test_cutlass.py
@@ -190,3 +190,44 @@ def test_cutlass_subset():
                         b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
 
     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
+
+
+# Test to make sure cuda graphs work
+class CutlassLayer(torch.nn.Module):
+
+    def __init__(self, b, scale_a, scale_b, out_dtype):
+        super().__init__()
+        self.b = b
+        self.scale_a = scale_a
+        self.scale_b = scale_b
+        self.out_dtype = out_dtype
+
+    def forward(self, a):
+        return ops.cutlass_scaled_mm_dq(a, self.b, self.scale_a, self.scale_b,
+                                        self.out_dtype)
+
+
+def test_cutlass_cuda_graph():
+    m, n, k = 512, 512, 512
+
+    a = to_int8(torch.randn((m, k), device="cuda"))
+    b = to_int8(torch.randn((n, k), device="cuda").t())
+
+    scale_a = (torch.randn((m, 1), device="cuda", dtype=torch.float32) / 10)
+    scale_b = (torch.randn((1, n), device="cuda", dtype=torch.float32) / 10)
+
+    # Construct a trivial model with a single layer that calls a CUTLASS kernel
+    model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)
+
+    # Run the model with a cuda graph
+    stream = torch.cuda.Stream()
+    with torch.cuda.stream(stream):
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            out = model(a)
+        out.zero_()
+        g.replay()
+
+    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
+                        scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
+    assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
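
For context on the C++ changes: PyTorch's CUDA graph capture only records work submitted to the stream that is current during capture, so a kernel hard-coded to the legacy default stream either escapes capture or makes capture fail. Passing at::cuda::getCurrentCUDAStream(a.get_device()) into the CUTLASS GEMM ops makes them follow whatever stream PyTorch has made current, which is exactly what test_cutlass_cuda_graph exercises. Below is a minimal standalone sketch (not part of the diff) of the same capture/replay pattern the test relies on, including the warm-up pass the PyTorch docs recommend before capture:

import torch

# Static buffers: replay re-reads and overwrites this same memory.
x = torch.ones(4, device="cuda")

# Warm up on a side stream before capture (recommended by the PyTorch docs,
# since some backends lazily initialize state that cannot be captured).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    x * 2
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):  # capture runs on a non-default side stream
    y = x * 2              # recorded only because it launches on that stream

x.fill_(3.0)  # mutate the static input in place
g.replay()    # re-runs the captured multiply on the new contents of x
print(y)      # tensor([6., 6., 6., 6.], device='cuda:0')

As in the new test, the output buffer is filled by g.replay() itself; zeroing it between capture and replay (out.zero_() above) verifies that the replayed kernel, not the capture pass, produced the result, which only holds if the CUTLASS kernel launches on the capturing stream.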