[Core][Distributed] Refactor ipc buffer init in CustomAllreduce (#10030)

Signed-off-by: Hanzhi Zhou <hanzhi713@gmail.com>
Hanzhi Zhou authored on 2024-11-06 23:50:47 -08:00, committed by GitHub
parent d7263a1bb8
commit 6192e9b8fe
9 changed files with 218 additions and 260 deletions

View File

@@ -5,32 +5,29 @@
 #include "custom_all_reduce.cuh"
-// fake pointer type, must match fptr_t type in ops.h
+// Fake pointer type, must match fptr_t type in ops.h.
+// We use this type alias to indicate when pointers are passed in as int64_t.
 using fptr_t = int64_t;
 static_assert(sizeof(void*) == sizeof(fptr_t));
-fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
-                      const std::vector<std::string>& handles,
-                      const std::vector<int64_t>& offsets, int64_t rank,
-                      bool full_nvlink) {
-  int world_size = offsets.size();
+fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
+                      torch::Tensor& rank_data, int64_t rank,
+                      bool full_nvlink) {
+  int world_size = fake_ipc_ptrs.size();
   if (world_size > 8)
     throw std::invalid_argument("world size > 8 is not supported");
   if (world_size % 2 != 0)
     throw std::invalid_argument("Odd num gpus is not supported for now");
-  if (world_size != handles.size())
-    throw std::invalid_argument(
-        "handles length should equal to offsets length");
   if (rank < 0 || rank >= world_size)
     throw std::invalid_argument("invalid rank passed in");
-  cudaIpcMemHandle_t ipc_handles[8];
+  vllm::Signal* ipc_ptrs[8];
   for (int i = 0; i < world_size; i++) {
-    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
+    ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
   }
-  return (fptr_t) new vllm::CustomAllreduce(
-      reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
-      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
+  return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
+                                            rank_data.numel(), rank, world_size,
+                                            full_nvlink);
 }
 /**
@@ -55,26 +52,48 @@ bool _is_weak_contiguous(torch::Tensor& t) {
                          t.numel() * t.element_size());
 }
-void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
-                 cudaStream_t stream) {
+/**
+ * Performs an out-of-place allreduce and stores result in out.
+ *
+ * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered.
+ * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
+ * copied into _reg_buffer.
+ */
+void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
+                fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
+  auto stream = c10::cuda::getCurrentCUDAStream().stream();
+  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
+  TORCH_CHECK_EQ(inp.numel(), out.numel());
   TORCH_CHECK(_is_weak_contiguous(out));
+  TORCH_CHECK(_is_weak_contiguous(inp));
+  auto input_size = inp.numel() * inp.element_size();
+  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
+  if (reg_buffer) {
+    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
+    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
+                                  cudaMemcpyDeviceToDevice, stream));
+  } else {
+    reg_buffer = inp.data_ptr();
+  }
   switch (out.scalar_type()) {
     case at::ScalarType::Float: {
-      fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
+      fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
                            reinterpret_cast<float*>(out.data_ptr()),
                            out.numel());
       break;
     }
     case at::ScalarType::Half: {
-      fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
+      fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
                           reinterpret_cast<half*>(out.data_ptr()), out.numel());
       break;
     }
 #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
     case at::ScalarType::BFloat16: {
       fa->allreduce<nv_bfloat16>(
-          stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
+          stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
           reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
       break;
     }
@@ -85,57 +104,41 @@ void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
   }
 }
-void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
-  auto stream = c10::cuda::getCurrentCUDAStream().stream();
-  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
-  TORCH_CHECK_EQ(inp.numel(), out.numel());
-  _all_reduce(_fa, inp, out, stream);
-}
-void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
-                      torch::Tensor& out) {
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
-  auto stream = c10::cuda::getCurrentCUDAStream().stream();
-  auto input_size = inp.numel() * inp.element_size();
-  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
-  TORCH_CHECK_EQ(inp.numel(), out.numel());
-  TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
-              "registered buffer is too small to contain the input");
-  AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
-                                input_size, cudaMemcpyDeviceToDevice, stream));
-  _all_reduce(_fa, reg_buffer, out, stream);
-}
 void dispose(fptr_t _fa) {
-  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
-  delete fa;
+  delete reinterpret_cast<vllm::CustomAllreduce*>(_fa);
 }
 int64_t meta_size() { return sizeof(vllm::Signal); }
-void register_buffer(fptr_t _fa, torch::Tensor& t,
-                     const std::vector<std::string>& handles,
-                     const std::vector<int64_t>& offsets) {
+void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
-  fa->register_buffer(handles, offsets, t.data_ptr());
+  TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
+  void* ipc_ptrs[8];
+  for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
+    ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
+  }
+  fa->register_buffer(ipc_ptrs);
 }
-std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
-    fptr_t _fa) {
+// Use vector<int64_t> to represent byte data for python binding compatibility.
+std::tuple<std::vector<int64_t>, std::vector<int64_t>>
+get_graph_buffer_ipc_meta(fptr_t _fa) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
-  auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
-  auto options =
-      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
-  auto handles =
-      torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
-  std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
-  return {handles, std::move(offsets)};
+  auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
+  std::vector<int64_t> bytes(handle.begin(), handle.end());
+  return std::make_tuple(bytes, offsets);
 }
-void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
+// Use vector<int64_t> to represent byte data for python binding compatibility.
+void register_graph_buffers(fptr_t _fa,
+                            const std::vector<std::vector<int64_t>>& handles,
                             const std::vector<std::vector<int64_t>>& offsets) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
-  fa->register_graph_buffers(handles, offsets);
+  std::vector<std::string> bytes;
+  bytes.reserve(handles.size());
+  for (int i = 0; i < handles.size(); i++) {
+    bytes.emplace_back(handles[i].begin(), handles[i].end());
+  }
+  bytes.reserve(handles.size());
+  fa->register_graph_buffers(bytes, offsets);
 }
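Note on the API change above: the old all_reduce_reg/all_reduce_unreg pair collapses into a single all_reduce op whose registered and unregistered modes are selected by the reg_buffer argument. A minimal Python-side sketch of the two calling conventions, assuming a handle fa returned by init_custom_ar and a pre-registered staging buffer of max_size bytes at address buffer_ptr (the helper name fused_all_reduce is illustrative, not part of this commit):

import torch
import vllm._custom_ops as ops


def fused_all_reduce(fa: int, inp: torch.Tensor, buffer_ptr: int,
                     max_size: int, registered: bool) -> torch.Tensor:
    # The custom kernel is out-of-place: results always land in a fresh tensor.
    out = torch.empty_like(inp)
    if registered:
        # inp's pointer was registered via register_buffer (or during graph
        # capture), so no staging copy is needed: pass a null reg_buffer.
        ops.all_reduce(fa, inp, out, 0, 0)
    else:
        # Unregistered input: the C++ side copies inp into the registered
        # staging buffer before launching the allreduce kernel.
        ops.all_reduce(fa, inp, out, buffer_ptr, max_size)
    return out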

View File

@@ -285,46 +285,52 @@ class CustomAllreduce {
   int world_size_;
   bool full_nvlink_;
-  // below are device pointers
   RankSignals sg_;
+  // Stores a map from a pointer to its peer pointers from all ranks.
   std::unordered_map<void*, RankData*> buffers_;
   Signal* self_sg_;
-  // stores the registered device pointers from all ranks
+  // Stores rank data from all ranks. This is mainly for cuda graph purposes.
+  // For cuda graph to work, all kernel arguments must be fixed during graph
+  // capture time. However, the peer pointers are not known during graph capture
+  // time. Therefore, during capture, we increment the rank data pointer and use
+  // that as the argument to the kernel. The kernel arguments are stored in
+  // graph_unreg_buffers_. The actual peer pointers will be filled in at the
+  // memory pointed to by the pointers in graph_unreg_buffers_ when
+  // the IPC handles are exchanged between ranks.
+  //
+  // The overall process looks like this:
+  // 1. Graph capture.
+  // 2. Each rank obtains the IPC handles for each address used during cuda
+  //    graph capture using get_graph_buffer_ipc_meta.
+  // 3. (In Python) all gather the IPC handles.
+  // 4. Obtain the peer pointers by opening the IPC handles, and store them in
+  //    the rank data array at corresponding positions.
   RankData *d_rank_data_base_, *d_rank_data_end_;
   std::vector<void*> graph_unreg_buffers_;
   // a map from IPC handles to opened IPC pointers
   std::map<IPC_KEY, char*> ipc_handles_;
   /**
-   * meta is a pointer to device metadata and temporary buffer for allreduce.
+   * Signals are an array of ipc-enabled buffers from all ranks.
+   * For each of the buffers, the layout is as follows:
+   * | -- sizeof(Signal) -- | ------ a few MB ----- |
+   * The first section is for allreduce synchronization, and the second section
+   * is for storing the intermediate results required by some allreduce algos.
    *
-   * There's a total of sizeof(Signal) of prefix before the actual data,
-   * so meta + 1 points to actual temporary buffer.
-   *
-   * note: this class does not own any device memory. Any required buffers
-   * are passed in from the constructor
+   * Note: this class does not own any device memory. Any required buffers
+   * are passed in from the constructor.
    */
-  CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
-                  const cudaIpcMemHandle_t* handles,
-                  const std::vector<int64_t>& offsets, int rank,
-                  bool full_nvlink = true)
+  CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
+                  int rank, int world_size, bool full_nvlink = true)
       : rank_(rank),
-        world_size_(offsets.size()),
+        world_size_(world_size),
         full_nvlink_(full_nvlink),
-        self_sg_(meta),
+        self_sg_(signals[rank]),
         d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
         d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
     for (int i = 0; i < world_size_; i++) {
-      Signal* rank_sg;
-      if (i != rank_) {
-        char* handle = open_ipc_handle(&handles[i]);
-        handle += offsets[i];
-        rank_sg = (Signal*)handle;
-      } else {
-        rank_sg = self_sg_;
-      }
-      sg_.signals[i] = rank_sg;
+      sg_.signals[i] = signals[i];
     }
   }
@@ -341,11 +347,10 @@ class CustomAllreduce {
     return it->second;
   }
-  std::pair<std::vector<uint8_t>, std::vector<int64_t>>
-  get_graph_buffer_ipc_meta() {
+  std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
     auto num_buffers = graph_unreg_buffers_.size();
     auto handle_sz = sizeof(cudaIpcMemHandle_t);
-    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
+    std::string handles(handle_sz * num_buffers, static_cast<char>(0));
     std::vector<int64_t> offsets(num_buffers);
     for (int i = 0; i < num_buffers; i++) {
       auto ptr = graph_unreg_buffers_[i];
@@ -370,26 +375,22 @@ class CustomAllreduce {
                     std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
   }
-  void register_buffer(const std::vector<std::string>& handles,
-                       const std::vector<int64_t>& offsets, void* self) {
+  /**
+   * Register already-shared IPC pointers.
+   */
+  void register_buffer(void** ptrs) {
     check_rank_data_capacity();
     RankData data;
     for (int i = 0; i < world_size_; i++) {
-      if (i != rank_) {
-        char* handle = open_ipc_handle(handles[i].data());
-        handle += offsets[i];
-        data.ptrs[i] = handle;
-      } else {
-        data.ptrs[i] = self;
-      }
+      data.ptrs[i] = ptrs[i];
     }
     auto d_data = d_rank_data_base_++;
     CUDACHECK(
         cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
-    buffers_[self] = d_data;
+    buffers_[ptrs[rank_]] = d_data;
   }
-  // note: when registering graph buffers, we intentionally choose to not
+  // Note: when registering graph buffers, we intentionally choose to not
   // deduplicate the addresses. That means if the allocator reuses some
   // addresses, they will be registered again. This is to account for the remote
   // possibility of different allocation patterns between ranks. For example,
@@ -424,11 +425,13 @@ class CustomAllreduce {
   }
   /**
-   * This is the result after careful grid search. Using 36 blocks give the best
-   * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
-   * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
-   * Not quite sure the underlying reason, but my guess is that too many SMs
-   * will cause contention on NVLink bus.
+   * Performs allreduce, assuming input has already been registered.
+   *
+   * Block and grid default configs are results after careful grid search. Using
+   * 36 blocks give the best or close to the best runtime on the devices I
+   * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
+   * take a small amount of SMs. Not quite sure the underlying reason, but my
+   * guess is that too many SMs will cause contention on NVLink bus.
    */
   template <typename T>
   void allreduce(cudaStream_t stream, T* input, T* output, int size,
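The four-step graph-buffer registration described in the comment above is driven from Python. A condensed sketch of steps 2-4, assuming a custom-allreduce handle _ptr and a torch.distributed process group group (this mirrors the register_graph_buffers method that appears later in this diff; the function name here is illustrative):

import torch.distributed as dist
import vllm._custom_ops as ops


def exchange_and_register_graph_buffers(_ptr: int, group, rank: int) -> None:
    # Step 2: serialize IPC handles for every buffer captured by this rank.
    handle, offset = ops.get_graph_buffer_ipc_meta(_ptr)
    # Step 3: all-gather the (handle, offset) pairs. broadcast_object_list is
    # used because all_gather_object is incompatible with the gloo backend
    # under inference mode.
    world_size = dist.get_world_size(group=group)
    all_data = [[None, None] for _ in range(world_size)]
    all_data[rank] = [handle, offset]
    for i, src in enumerate(sorted(dist.get_process_group_ranks(group=group))):
        dist.broadcast_object_list(all_data[i], src=src, group=group,
                                   device="cpu")
    # Step 4: the C++ side opens the peer handles and fills in the rank data
    # slots that were reserved during graph capture.
    handles = [d[0] for d in all_data]
    offsets = [d[1] for d in all_data]
    ops.register_graph_buffers(_ptr, handles, offsets)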

View File

@@ -135,24 +135,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
   void* rank_data;
   size_t rank_data_sz = 16 * 1024 * 1024;
   CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
-  std::vector<int64_t> offsets(nRanks, 0);
-  vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
-                           offsets, myRank);
+  vllm::Signal* ipc_ptrs[8];
+  for (int i = 0; i < nRanks; i++) {
+    if (i == myRank)
+      ipc_ptrs[i] = buffer;
+    else
+      CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptrs[i], data_handles[i],
+                                     cudaIpcMemLazyEnablePeerAccess));
+  }
+  vllm::CustomAllreduce fa(ipc_ptrs, rank_data, rank_data_sz, myRank, nRanks);
   auto* self_data =
       reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
                            sizeof(vllm::Signal) + data_size * sizeof(T));
   // hack buffer registration
   {
-    std::vector<std::string> handles;
-    handles.reserve(nRanks);
+    void* data[8];
     for (int i = 0; i < nRanks; i++) {
-      char* begin = (char*)&data_handles[i];
-      char* end = (char*)&data_handles[i + 1];
-      handles.emplace_back(begin, end);
+      data[i] =
+          ((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T);
     }
-    std::vector<int64_t> offsets(nRanks,
-                                 sizeof(vllm::Signal) + data_size * sizeof(T));
-    fa.register_buffer(handles, offsets, self_data);
+    fa.register_buffer(data);
   }
   double* ground_truth;
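The test now does on the C++ side what the Python path does with create_shared_buffer: export a local allocation as a CUDA IPC handle, exchange handles across ranks, and open each peer's handle into a raw device pointer. A rough Python analogue using vLLM's CudaRTLibrary wrapper (the all_gather_object exchange and the function name are illustrative assumptions, not part of this commit):

from typing import List

import torch.distributed as dist

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary


def open_peer_pointers(size_in_bytes: int, group, rank: int) -> List[int]:
    lib = CudaRTLibrary()
    # Allocate the local region and export an IPC handle for it.
    local_ptr = lib.cudaMalloc(size_in_bytes)
    handle = lib.cudaIpcGetMemHandle(local_ptr)
    # Exchange handles so every rank holds one handle per peer.
    world_size = dist.get_world_size(group=group)
    all_handles = [None] * world_size
    dist.all_gather_object(all_handles, handle, group=group)
    # Open each peer handle into a local device pointer; keep our own as-is.
    # .value converts the ctypes c_void_p results into plain ints.
    pointers: List[int] = []
    for r, h in enumerate(all_handles):
        if r == rank:
            pointers.append(local_ptr.value)
        else:
            pointers.append(lib.cudaIpcOpenMemHandle(h).value)
    return pointers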

View File

@@ -199,20 +199,16 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
 #ifndef USE_ROCM
 using fptr_t = int64_t;
-fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
-                      const std::vector<std::string>& handles,
-                      const std::vector<int64_t>& offsets, int64_t rank,
-                      bool full_nvlink);
-void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
-void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
-                      torch::Tensor& out);
+fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
+                      torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
+void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
+                fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
 void dispose(fptr_t _fa);
 int64_t meta_size();
-void register_buffer(fptr_t _fa, torch::Tensor& t,
-                     const std::vector<std::string>& handles,
-                     const std::vector<int64_t>& offsets);
-std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
-    fptr_t _fa);
-void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
+void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
+std::tuple<std::vector<int64_t>, std::vector<int64_t>>
+get_graph_buffer_ipc_meta(fptr_t _fa);
+void register_graph_buffers(fptr_t _fa,
+                            const std::vector<std::vector<int64_t>>& handles,
                             const std::vector<std::vector<int64_t>>& offsets);
 #endif

View File

@@ -411,27 +411,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
   // Custom all-reduce kernels
   custom_ar.def(
-      "init_custom_ar(Tensor meta, Tensor rank_data, "
-      "str[] handles, int[] offsets, int rank, "
-      "bool full_nvlink) -> int");
+      "init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
+      "int rank, bool full_nvlink) -> int");
   custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
-  custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
-  custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
   custom_ar.def(
-      "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
-      "()");
-  custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
+      "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
+      "int reg_buffer_sz_bytes) -> ()");
+  custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);
   custom_ar.def("dispose", &dispose);
   custom_ar.def("meta_size", &meta_size);
-  custom_ar.def(
-      "register_buffer(int fa, Tensor t, str[] handles, "
-      "int[] offsets) -> ()");
-  custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
+  custom_ar.def("register_buffer", &register_buffer);
   custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
   custom_ar.def("register_graph_buffers", &register_graph_buffers);
 }

View File

@@ -95,13 +95,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
     inp = torch.ones(sz, dtype=torch.float32, device=device)
     out = inp
     for _ in range(num_communication):
-        out = fa.all_reduce_unreg(out)
+        out = fa.all_reduce(out, registered=False)
     torch.testing.assert_close(out, inp * (tp_size**num_communication))
     inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
     out = inp
     for _ in range(num_communication):
-        out = fa.all_reduce_unreg(out)
+        out = fa.all_reduce(out, registered=False)
     torch.testing.assert_close(out, inp * (tp_size**num_communication))

View File

@@ -196,8 +196,8 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
     def is_cross_device_reduce_2stage(op_name: str):
         return "cross_device_reduce_2stage" in op_name
-    def is_custom_ar_all_reduce_unreg(op_name: str):
-        return "_C_custom_ar::all_reduce_unreg" in op_name
+    def is_custom_ar_all_reduce(op_name: str):
+        return "_C_custom_ar::all_reduce" in op_name
     def is_reduce_kernel(op_name: str):
         return "reduce_kernel" in op_name
@@ -246,9 +246,9 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
         filter(lambda x: is_cross_device_reduce_2stage(x), ops))
     ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops))
-    custom_ar_all_reduce_unreg_ops = list(
-        filter(lambda x: is_custom_ar_all_reduce_unreg(x), ops))
-    ops = list(filter(lambda x: x not in custom_ar_all_reduce_unreg_ops, ops))
+    custom_ar_all_reduce_ops = list(
+        filter(lambda x: is_custom_ar_all_reduce(x), ops))
+    ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops))
     reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops))
     ops = list(filter(lambda x: x not in reduce_kernel_ops, ops))
@@ -289,21 +289,21 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
     if len(cross_device_reduce_2stage_ops):
         trace_df['cross_device_reduce_2stage_ops'] = trace_df[
             cross_device_reduce_2stage_ops].agg("sum", axis=1)
-    if len(custom_ar_all_reduce_unreg_ops):
-        trace_df['custom_ar_all_reduce_unreg_ops'] = trace_df[
-            custom_ar_all_reduce_unreg_ops].agg("sum", axis=1)
+    if len(custom_ar_all_reduce_ops):
+        trace_df['custom_ar_all_reduce_ops'] = trace_df[
+            custom_ar_all_reduce_ops].agg("sum", axis=1)
     if len(reduce_kernel_ops):
         trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
                                                                         axis=1)
-    trace_df.drop(
-        attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops +
-        mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops +
-        nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops +
-        cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops +
-        reduce_kernel_ops,
-        axis=1,
-        inplace=True)
+    trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops +
+                  vocab_embed_ops + mem_ops + elementwise_ops +
+                  nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops +
+                  nccl_other_ops + cross_device_reduce_1stage_ops +
+                  cross_device_reduce_2stage_ops + custom_ar_all_reduce_ops +
+                  reduce_kernel_ops,
+                  axis=1,
+                  inplace=True)
     return trace_df

View File

@@ -912,20 +912,16 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
 # custom ar
-def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
-                   handles: List[str], offsets: List[int], rank: int,
-                   full_nvlink: bool) -> int:
-    return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
-                                                 offsets, rank, full_nvlink)
+def init_custom_ar(ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor,
+                   rank: int, full_nvlink: bool) -> int:
+    return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
+                                                 full_nvlink)
-def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-    torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)
-def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
-                     out: torch.Tensor) -> None:
-    torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)
+def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
+               reg_buffer_sz_bytes: int) -> None:
+    torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer,
+                                      reg_buffer_sz_bytes)
 def dispose(fa: int) -> None:
@@ -936,16 +932,15 @@ def meta_size() -> int:
     return torch.ops._C_custom_ar.meta_size()
-def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
-                    offsets: List[int]) -> None:
-    return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)
+def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
+    return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
-def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
+def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
     return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
-def register_graph_buffers(fa: int, handles: List[str],
+def register_graph_buffers(fa: int, handles: List[List[int]],
                            offsets: List[List[int]]) -> None:
     torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)

View File

@@ -1,6 +1,6 @@
 import ctypes
 from contextlib import contextmanager
-from typing import Any, List, Optional, Union
+from typing import List, Optional, Union
 import torch
 import torch.distributed as dist
@@ -147,18 +147,14 @@ class CustomAllreduce:
             return
         self.disabled = False
-        # buffers memory are owned by this Python class and passed to C++
-        # meta data composes of two parts: meta data for synchronization
-        # (256 bytes) and a temporary buffer for storing intermediate
-        # allreduce results.
-        self.meta = torch.zeros(ops.meta_size() + max_size,
-                                dtype=torch.uint8,
-                                device=self.device)
+        # Buffers memory are owned by this Python class and passed to C++.
+        # Meta data composes of two parts: meta data for synchronization and a
+        # temporary buffer for storing intermediate allreduce results.
+        self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
+                                                   group=group)
         # This is a pre-registered IPC buffer. In eager mode, input tensors
         # are first copied into this buffer before allreduce is performed
-        self.buffer = torch.empty(max_size,
-                                  dtype=torch.uint8,
-                                  device=self.device)
+        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
         # This is a buffer for storing the tuples of pointers pointing to
         # IPC buffers from all ranks. Each registered tuple has size of
         # 8*world_size bytes where world_size is at most 8. Allocating 8MB
@@ -170,16 +166,19 @@ class CustomAllreduce:
         self.max_size = max_size
         self.rank = rank
         self.world_size = world_size
-        handles, offsets = self._get_ipc_meta(self.meta)
         self.full_nvlink = full_nvlink
-        self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles,
-                                       offsets, rank, self.full_nvlink)
-        self.register_buffer(self.buffer)
+        self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
+                                       self.full_nvlink)
+        ops.register_buffer(self._ptr, self.buffer_ptrs)
     @staticmethod
     def create_shared_buffer(
             size_in_bytes: int,
             group: Optional[ProcessGroup] = None) -> List[int]:
+        """
+        Creates a shared buffer and returns a list of pointers
+        representing the buffer on all processes in the group.
+        """
         lib = CudaRTLibrary()
         pointer = lib.cudaMalloc(size_in_bytes)
         handle = lib.cudaIpcGetMemHandle(pointer)
@@ -220,60 +219,24 @@ class CustomAllreduce:
         if not self.disabled:
             self.register_graph_buffers()
-    def _get_ipc_meta(self, inp: torch.Tensor):
-        data = inp.untyped_storage()._share_cuda_()
-        handle = data[1]
-        # https://github.com/pytorch/pytorch/pull/130890 changes
-        # the binary format of the ipc handle
-        # it starts from pytorch 2.5
-        if len(handle) > 64:
-            assert len(handle) == 66
-            # only support SHAREABLE_HANDLE_VERSION = 1
-            assert int(handle[0]) == 1
-            # only support SHAREABLE_CUDA_MALLOC = 'c'
-            assert handle[1] == ord("c")
-            handle = handle[2:]
-            # TODO: support expandable segment
-        shard_data = (
-            handle,  # ipc handle to base ptr
-            data[3],  # offset of base ptr
-        )
-        return self._gather_ipc_meta(shard_data)
-    def _gather_ipc_meta(self, shard_data):
-        # Note: don't use `[[None]] * self.world_size` here
-        # because it will create a list of the same reference
-        all_data: List[Optional[Any]] = [[None]
-                                         for i in range(self.world_size)]
-        all_data[self.rank][0] = shard_data
-        ranks = dist.get_process_group_ranks(group=self.group)
-        ranks.sort()
+    def register_graph_buffers(self):
+        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
+        logger.info("Registering %d cuda graph addresses", len(offset))
+        # We cannot directly use `dist.all_gather_object` here
+        # because it is incompatible with `gloo` backend under inference mode.
+        # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        all_data = [[None, None]
+                    for _ in range(dist.get_world_size(group=self.group))]
+        all_data[self.rank] = [handle, offset]
+        ranks = sorted(dist.get_process_group_ranks(group=self.group))
         for i, rank in enumerate(ranks):
             dist.broadcast_object_list(all_data[i],
                                        src=rank,
                                        group=self.group,
                                        device="cpu")
-        # we cannot directly use `dist.all_gather_object` here
-        # because it is incompatible with `gloo` backend under inference mode.
-        # see https://github.com/pytorch/pytorch/issues/126032 for details.
-        handles = []
-        offsets = []
-        for i in range(len(all_data)):
-            handles.append(all_data[i][0][0])  # type: ignore
-            offsets.append(all_data[i][0][1])  # type: ignore
-        return handles, offsets
-    def register_buffer(self, inp: torch.Tensor):
-        handles, offsets = self._get_ipc_meta(inp)
-        ops.register_buffer(self._ptr, inp, handles, offsets)
-    def register_graph_buffers(self):
-        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
-        handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
-        logger.info("Registering %d cuda graph addresses", len(offset))
+        # Unpack list of tuples to tuple of lists.
+        handles = [d[0] for d in all_data]  # type: ignore
+        offsets = [d[1] for d in all_data]  # type: ignore
         ops.register_graph_buffers(self._ptr, handles, offsets)
     def should_custom_ar(self, inp: torch.Tensor):
@@ -291,45 +254,50 @@ class CustomAllreduce:
             return inp_size < self.max_size
         return False
-    # all reduce, assuming inp tensor is IPC registered with register_buffer,
-    # or, in the context of cuda graphs, register_graph_buffers
-    def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
+    def all_reduce(self,
+                   inp: torch.Tensor,
+                   *,
+                   out: torch.Tensor = None,
+                   registered: bool = False):
+        """Performs an out-of-place all reduce.
+        If registered is True, this assumes inp's pointer is already
+        IPC-registered. Otherwise, inp is first copied into a pre-registered
+        buffer.
+        """
         if out is None:
             out = torch.empty_like(inp)
-        ops.all_reduce_reg(self._ptr, inp, out)
-        return out
-    # all reduce, assuming inp tensor is NOT IPC registered
-    def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
-        if out is None:
-            out = torch.empty_like(inp)
-        ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
+        if registered:
+            ops.all_reduce(self._ptr, inp, out, 0, 0)
+        else:
+            ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
+                           self.max_size)
         return out
     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
-        # when custom allreduce is disabled, this will be None
+        """The main allreduce API that provides support for cuda graph."""
+        # When custom allreduce is disabled, this will be None.
         if self.disabled or not self.should_custom_ar(input):
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce_reg(input)
+                return self.all_reduce(input, registered=True)
             else:
-                # if warm up, mimic the allocation pattern
-                # since custom allreduce is out-of-place
+                # If warm up, mimic the allocation pattern since custom
+                # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            # note: outside of cuda graph context,
-            # custom allreduce incurs a cost of cudaMemcpy, which should
-            # be small(<=1% of overall latency) compared to the performance
-            # gains of using custom kernels
-            return self.all_reduce_unreg(input)
-        return None
+            # Note: outside of cuda graph context, custom allreduce incurs a
+            # cost of cudaMemcpy, which should be small (<=1% of overall
+            # latency) compared to the performance gain of using custom kernels
+            return self.all_reduce(input, registered=False)
     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
             self._ptr = 0
+            self.free_shared_buffer(self.meta_ptrs)
+            self.free_shared_buffer(self.buffer_ptrs)
    def __del__(self):
        self.close()
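Taken together, the refactor reduces initialization and teardown to pointer-list plumbing. A condensed, illustrative walk-through of the lifecycle under the new API (the name build_and_teardown and the 8 MB rank_data size follow the constructor shown above; free_shared_buffer is assumed to accept the pointer list, as its use in close() suggests):

from typing import List

import torch

import vllm._custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce


def build_and_teardown(rank: int, max_size: int, group, device) -> None:
    # Construction: allocate IPC-shared regions and exchange raw pointers.
    meta_ptrs: List[int] = CustomAllreduce.create_shared_buffer(
        ops.meta_size() + max_size, group=group)
    buffer_ptrs: List[int] = CustomAllreduce.create_shared_buffer(
        max_size, group=group)
    # rank_data backs the per-registration pointer tuples on the device.
    rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
    fa = ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
    # Register the eager-mode staging buffers once, up front.
    ops.register_buffer(fa, buffer_ptrs)
    # ... allreduces go through ops.all_reduce, as sketched earlier ...
    # Teardown: dispose the C++ object, then free the shared regions.
    ops.dispose(fa)
    CustomAllreduce.free_shared_buffer(meta_ptrs)
    CustomAllreduce.free_shared_buffer(buffer_ptrs)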