[Kernel] Fixup for CUTLASS kernels in CUDA graphs (#4954)

Pass the CUDA stream into the CUTLASS GEMMs to avoid future issues with CUDA graphs.
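Background (not part of the commit message): a kernel launched without an explicit stream goes to the legacy default stream, and while any stream is capturing in cudaStreamCaptureModeGlobal, touching the legacy stream is expected to fail and invalidate the capture. A minimal standalone sketch of the failure mode this fix avoids (hypothetical file, not from the PR):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void noop_kernel() {}

int main() {
  cudaStream_t capture_stream;
  cudaStreamCreate(&capture_stream);
  cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal);

  // Launched on the capture stream: recorded into the graph as expected.
  noop_kernel<<<1, 1, 0, capture_stream>>>();

  // Launched with no stream argument, i.e. on the legacy default stream --
  // what a CUTLASS op does when no stream is passed through. This should
  // fail with a stream-capture error and poison the ongoing capture.
  noop_kernel<<<1, 1>>>();
  printf("legacy-stream launch: %s\n", cudaGetErrorString(cudaGetLastError()));

  cudaGraph_t graph;
  cudaError_t err = cudaStreamEndCapture(capture_stream, &graph);
  printf("end capture: %s\n", cudaGetErrorString(err));  // capture invalidated
  return 0;
}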
Tyler Michael Smith 2024-05-22 10:10:43 -04:00 committed by GitHub
parent c74c913bfb
commit 8674f9880e
3 changed files with 50 additions and 2 deletions


@@ -1,6 +1,8 @@
 #include <stddef.h>
 #include <torch/extension.h>
 
+#include <ATen/cuda/CUDAContext.h>
+
 // clang-format will break include orders
 // clang-format off
 #include "cute/tensor.hpp"
@@ -189,8 +191,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
   size_t workspace_size = gemm_op.get_workspace_size(args);
   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
 
+  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
+
   CUTLASS_CHECK(gemm_op.can_implement(args));
 
-  cutlass::Status status = gemm_op(args, workspace.get());
+  cutlass::Status status = gemm_op(args, workspace.get(), stream);
   CUTLASS_CHECK(status);
 }
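The pattern the diff applies, sketched as a standalone extension op (illustrative names, not vLLM code): every kernel a PyTorch extension launches should go on at::cuda::getCurrentCUDAStream(), because under torch.cuda.graph() that is the capture stream rather than the legacy default stream.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

__global__ void scale_kernel(float* x, float s, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) x[i] *= s;
}

// Hypothetical op: scales a float CUDA tensor in place.
void scale_inplace(torch::Tensor& t, float s) {
  // Forward PyTorch's current stream to the launch. During graph capture
  // this is the capture stream, so the kernel is recorded into the graph
  // instead of being issued on the legacy default stream.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(t.get_device());
  int64_t n = t.numel();
  constexpr int threads = 256;
  int blocks = static_cast<int>((n + threads - 1) / threads);
  scale_kernel<<<blocks, threads, 0, stream>>>(t.data_ptr<float>(), s, n);
}

The CUTLASS calls in this diff do the same thing: the stream is fetched per call (it can change between calls) and handed to the GEMM rather than left to default to stream 0.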


@@ -1,5 +1,7 @@
 #include <torch/extension.h>
 
+#include <ATen/cuda/CUDAContext.h>
+
 #include <iostream>
 #include <sstream>
 #include <vector>
@@ -178,7 +180,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
   size_t workspace_size = gemm_op.get_workspace_size(args);
   TORCH_CHECK(workspace_size == 0);
-  cutlass::Status status = gemm_op.run(args);
+  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
+  cutlass::Status status = gemm_op.run(args, stream);
   CUTLASS_CHECK(status);
 }
 
 }  // namespace


@@ -190,3 +190,44 @@ def test_cutlass_subset():
                             b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
 
     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
+
+
+# Test to make sure cuda graphs work
+class CutlassLayer(torch.nn.Module):
+    def __init__(self, b, scale_a, scale_b, out_dtype):
+        super().__init__()
+        self.b = b
+        self.scale_a = scale_a
+        self.scale_b = scale_b
+        self.out_dtype = out_dtype
+
+    def forward(self, a):
+        return ops.cutlass_scaled_mm_dq(a, self.b, self.scale_a, self.scale_b,
+                                        self.out_dtype)
+
+
+def test_cutlass_cuda_graph():
+    m, n, k = 512, 512, 512
+
+    a = to_int8(torch.randn((m, k), device="cuda"))
+    b = to_int8(torch.randn((n, k), device="cuda").t())
+
+    scale_a = (torch.randn((m, 1), device="cuda", dtype=torch.float32) / 10)
+    scale_b = (torch.randn((1, n), device="cuda", dtype=torch.float32) / 10)
+
+    # Construct a trivial model with a single layer that calls a CUTLASS kernel
+    model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)
+
+    # Run the model with a cuda graph
+    stream = torch.cuda.Stream()
+    with torch.cuda.stream(stream):
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            out = model(a)
+
+    out.zero_()
+    g.replay()
+
+    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
+                        scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
+    assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
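For reference (not part of the change), the driver-level shape of what the test exercises through torch.cuda.CUDAGraph() and g.replay(): capture work on a side stream, instantiate an executable graph once, then relaunch it cheaply. A minimal sketch, assuming the CUDA 12 cudaGraphInstantiate signature:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void add_one(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] += 1.0f;
}

int main() {
  const int n = 1 << 20;
  float* d;
  cudaMalloc(&d, n * sizeof(float));
  cudaMemset(d, 0, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Capture one iteration into a graph. Captured launches do not execute.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  add_one<<<(n + 255) / 256, 256, 0, stream>>>(d, n);
  cudaGraph_t graph;
  cudaStreamEndCapture(stream, &graph);

  // Instantiate once, replay many times -- the analogue of g.replay().
  cudaGraphExec_t exec;
  cudaGraphInstantiate(&exec, graph, 0);  // CUDA 12 signature
  for (int i = 0; i < 10; ++i) cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);

  float host;
  cudaMemcpy(&host, d, sizeof(float), cudaMemcpyDeviceToHost);
  printf("x[0] = %.0f (expected 10)\n", host);
  return 0;
}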