[Kernel] Fixup for CUTLASS kernels in CUDA graphs (#4954)
Pass the CUDA stream into the CUTLASS GEMMs to avoid future issues with CUDA graphs
parent c74c913bfb
commit 8674f9880e
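In brief: torch.cuda.graph records work on a dedicated capture stream, so a kernel that implicitly launches on the legacy default stream is not recorded and capture can fail. Fetching the stream from at::cuda::getCurrentCUDAStream() keeps the CUTLASS launches on whatever stream PyTorch currently has active. A minimal Python sketch of that capture-stream behavior (illustrative only, not part of this change; assumes a CUDA-capable PyTorch build):

import torch

# Warm up on a side stream before capture, per the usual CUDA graphs recipe.
x = torch.randn(64, 64, device="cuda")
side = torch.cuda.Stream()
with torch.cuda.stream(side):
    for _ in range(3):
        y = x @ x
torch.cuda.current_stream().wait_stream(side)

default_stream = torch.cuda.current_stream()
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Inside capture, PyTorch switches to a capture stream, so custom kernels
    # must launch on torch.cuda.current_stream(), not on stream 0.
    assert torch.cuda.current_stream() != default_stream
    y = x @ x  # captured: the matmul launches on the current (capture) stream
g.replay()
torch.cuda.synchronize()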
@@ -1,6 +1,8 @@
 #include <stddef.h>
 #include <torch/extension.h>
 
+#include <ATen/cuda/CUDAContext.h>
+
 // clang-format will break include orders
 // clang-format off
 #include "cute/tensor.hpp"
@@ -189,8 +191,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
   size_t workspace_size = gemm_op.get_workspace_size(args);
   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
 
+  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
+
   CUTLASS_CHECK(gemm_op.can_implement(args));
-  cutlass::Status status = gemm_op(args, workspace.get());
+  cutlass::Status status = gemm_op(args, workspace.get(), stream);
   CUTLASS_CHECK(status);
 }
@@ -1,5 +1,7 @@
 #include <torch/extension.h>
 
+#include <ATen/cuda/CUDAContext.h>
+
 #include <iostream>
 #include <sstream>
 #include <vector>
@@ -178,7 +180,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
   size_t workspace_size = gemm_op.get_workspace_size(args);
   TORCH_CHECK(workspace_size == 0);
 
-  cutlass::Status status = gemm_op.run(args);
+  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
+  cutlass::Status status = gemm_op.run(args, stream);
   CUTLASS_CHECK(status);
 }
 } // namespace
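Both launchers now follow the same pattern: query the stream PyTorch has active for the input's device with at::cuda::getCurrentCUDAStream(a.get_device()) and pass it to the CUTLASS op rather than letting the call default to stream 0. For orientation, the Python-side counterpart of that query is torch.cuda.current_stream(device); a small hypothetical check (not part of this change) showing the correspondence:

import torch

a = torch.zeros(1, device="cuda")
s = torch.cuda.Stream()
with torch.cuda.stream(s):
    # What at::cuda::getCurrentCUDAStream(a.get_device()) would observe in C++.
    current = torch.cuda.current_stream(a.device)
    assert current == s
    # Raw cudaStream_t handle (as an integer) that ends up passed to CUTLASS.
    print(hex(current.cuda_stream))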
@@ -190,3 +190,44 @@ def test_cutlass_subset():
                         b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
 
     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
+
+
+# Test to make sure cuda graphs work
+class CutlassLayer(torch.nn.Module):
+
+    def __init__(self, b, scale_a, scale_b, out_dtype):
+        super().__init__()
+        self.b = b
+        self.scale_a = scale_a
+        self.scale_b = scale_b
+        self.out_dtype = out_dtype
+
+    def forward(self, a):
+        return ops.cutlass_scaled_mm_dq(a, self.b, self.scale_a, self.scale_b,
+                                        self.out_dtype)
+
+
+def test_cutlass_cuda_graph():
+    m, n, k = 512, 512, 512
+
+    a = to_int8(torch.randn((m, k), device="cuda"))
+    b = to_int8(torch.randn((n, k), device="cuda").t())
+
+    scale_a = (torch.randn((m, 1), device="cuda", dtype=torch.float32) / 10)
+    scale_b = (torch.randn((1, n), device="cuda", dtype=torch.float32) / 10)
+
+    # Construct a trivial model with a single layer that calls a CUTLASS kernel
+    model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)
+
+    # Run the model with a cuda graph
+    stream = torch.cuda.Stream()
+    with torch.cuda.stream(stream):
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            out = model(a)
+        out.zero_()
+        g.replay()
+
+    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
+                        scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
+    assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
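As a usage note, the captured graph bakes in the addresses of a and out, so replaying it with fresh data follows the standard static-tensor convention of copying new values into the captured input before g.replay(). A sketch under that convention (new_a is hypothetical; the test above does not exercise this):

# Hypothetical follow-on using the objects from test_cutlass_cuda_graph.
new_a = to_int8(torch.randn((512, 512), device="cuda"))
a.copy_(new_a)   # write into the tensor whose address the graph captured
g.replay()       # relaunch the recorded CUTLASS kernel on the new data
torch.cuda.synchronize()
print(out)       # out now holds results for new_a (same scales as before)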