[Core][Distributed] Refactor ipc buffer init in CustomAllreduce (#10030)
Signed-off-by: Hanzhi Zhou <hanzhi713@gmail.com>
This commit is contained in:
parent d7263a1bb8
commit 6192e9b8fe
@@ -5,32 +5,29 @@
#include "custom_all_reduce.cuh"

// fake pointer type, must match fptr_t type in ops.h
// Fake pointer type, must match fptr_t type in ops.h.
// We use this type alias to indicate when pointers are passed in as int64_t.
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));

fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int64_t rank,
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
                      torch::Tensor& rank_data, int64_t rank,
                      bool full_nvlink) {
  int world_size = offsets.size();
  int world_size = fake_ipc_ptrs.size();
  if (world_size > 8)
    throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd num gpus is not supported for now");
  if (world_size != handles.size())
    throw std::invalid_argument(
        "handles length should equal to offsets length");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");

  cudaIpcMemHandle_t ipc_handles[8];
  vllm::Signal* ipc_ptrs[8];
  for (int i = 0; i < world_size; i++) {
    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
    ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
  }
  return (fptr_t) new vllm::CustomAllreduce(
      reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
  return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
                                            rank_data.numel(), rank, world_size,
                                            full_nvlink);
}

/**
@@ -55,26 +52,48 @@ bool _is_weak_contiguous(torch::Tensor& t) {
                         t.numel() * t.element_size());
}

void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                 cudaStream_t stream) {
/**
 * Performs an out-of-place allreduce and stores result in out.
 *
 * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered.
 * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
 * copied into _reg_buffer.
 */
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_is_weak_contiguous(out));
  TORCH_CHECK(_is_weak_contiguous(inp));
  auto input_size = inp.numel() * inp.element_size();
  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
      fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
                           reinterpret_cast<float*>(out.data_ptr()),
                           out.numel());
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
      fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
                          reinterpret_cast<half*>(out.data_ptr()), out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce<nv_bfloat16>(
          stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
          stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
          reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
      break;
    }
@@ -85,57 +104,41 @@ void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
  }
}

void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  _all_reduce(_fa, inp, out, stream);
}

void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  auto input_size = inp.numel() * inp.element_size();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
              "registered buffer is too small to contain the input");
  AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
                                input_size, cudaMemcpyDeviceToDevice, stream));
  _all_reduce(_fa, reg_buffer, out, stream);
}

void dispose(fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  delete fa;
  delete reinterpret_cast<vllm::CustomAllreduce*>(_fa);
}

int64_t meta_size() { return sizeof(vllm::Signal); }

void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets) {
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  fa->register_buffer(handles, offsets, t.data_ptr());
  TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
  void* ipc_ptrs[8];
  for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
    ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
  }
  fa->register_buffer(ipc_ptrs);
}

std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa) {
// Use vector<int64_t> to represent byte data for python binding compatibility.
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
  auto options =
      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
  auto handles =
      torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
  std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
  return {handles, std::move(offsets)};
  auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
  std::vector<int64_t> bytes(handle.begin(), handle.end());
  return std::make_tuple(bytes, offsets);
}

void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
// Use vector<int64_t> to represent byte data for python binding compatibility.
void register_graph_buffers(fptr_t _fa,
                            const std::vector<std::vector<int64_t>>& handles,
                            const std::vector<std::vector<int64_t>>& offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  fa->register_graph_buffers(handles, offsets);
  std::vector<std::string> bytes;
  bytes.reserve(handles.size());
  for (int i = 0; i < handles.size(); i++) {
    bytes.emplace_back(handles[i].begin(), handles[i].end());
  }
  bytes.reserve(handles.size());
  fa->register_graph_buffers(bytes, offsets);
}

@@ -285,46 +285,52 @@ class CustomAllreduce {
  int world_size_;
  bool full_nvlink_;

  // below are device pointers
  RankSignals sg_;
  // Stores a map from a pointer to its peer pointers from all ranks.
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // stores the registered device pointers from all ranks
  // Stores rank data from all ranks. This is mainly for cuda graph purposes.
  // For cuda graph to work, all kernel arguments must be fixed during graph
  // capture time. However, the peer pointers are not known during graph capture
  // time. Therefore, during capture, we increment the rank data pointer and use
  // that as the argument to the kernel. The kernel arguments are stored in
  // graph_unreg_buffers_. The actual peer pointers will be filled in at the
  // memory pointed to by the pointers in graph_unreg_buffers_ when
  // the IPC handles are exchanged between ranks.
  //
  // The overall process looks like this:
  // 1. Graph capture.
  // 2. Each rank obtains the IPC handles for each address used during cuda
  // graph capture using get_graph_buffer_ipc_meta.
  // 3. (In Python) all gather the IPC handles.
  // 4. Obtain the peer pointers by opening the IPC handles, and store them in
  // the rank data array at corresponding positions.
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  /**
   * meta is a pointer to device metadata and temporary buffer for allreduce.
   * Signals are an array of ipc-enabled buffers from all ranks.
   * For each buffer, the layout is as follows:
   * | -- sizeof(Signal) -- | ------ a few MB ----- |
   * The first section is for allreduce synchronization, and the second section
   * is for storing the intermediate results required by some allreduce algos.
   *
   * There's a total of sizeof(Signal) of prefix before the actual data,
   * so meta + 1 points to actual temporary buffer.
   *
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   * Note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor.
   */
  CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
                  const cudaIpcMemHandle_t* handles,
                  const std::vector<int64_t>& offsets, int rank,
                  bool full_nvlink = true)
  CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
                  int rank, int world_size, bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        world_size_(world_size),
        full_nvlink_(full_nvlink),
        self_sg_(meta),
        self_sg_(signals[rank]),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      Signal* rank_sg;
      if (i != rank_) {
        char* handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
        rank_sg = (Signal*)handle;
      } else {
        rank_sg = self_sg_;
      }
      sg_.signals[i] = rank_sg;
      sg_.signals[i] = signals[i];
    }
  }

@@ -341,11 +347,10 @@ class CustomAllreduce {
    return it->second;
  }

  std::pair<std::vector<uint8_t>, std::vector<int64_t>>
  get_graph_buffer_ipc_meta() {
  std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
    std::string handles(handle_sz * num_buffers, static_cast<char>(0));
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];

@@ -370,26 +375,22 @@ class CustomAllreduce {
                    std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  void register_buffer(const std::vector<std::string>& handles,
                       const std::vector<int64_t>& offsets, void* self) {
  /**
   * Register already-shared IPC pointers.
   */
  void register_buffer(void** ptrs) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
        char* handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {
        data.ptrs[i] = self;
      }
      data.ptrs[i] = ptrs[i];
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[self] = d_data;
    buffers_[ptrs[rank_]] = d_data;
  }

  // note: when registering graph buffers, we intentionally choose to not
  // Note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the remote
  // possibility of different allocation patterns between ranks. For example,

@@ -424,11 +425,13 @@ class CustomAllreduce {
  }

  /**
   * This is the result after careful grid search. Using 36 blocks give the best
   * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
   * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
   * Not quite sure the underlying reason, but my guess is that too many SMs
   * will cause contention on NVLink bus.
   * Performs allreduce, assuming input has already been registered.
   *
   * Block and grid default configs are results after careful grid search. Using
   * 36 blocks give the best or close to the best runtime on the devices I
   * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
   * take a small amount of SMs. Not quite sure the underlying reason, but my
   * guess is that too many SMs will cause contention on NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,

@@ -135,24 +135,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
  void* rank_data;
  size_t rank_data_sz = 16 * 1024 * 1024;
  CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
  std::vector<int64_t> offsets(nRanks, 0);
  vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
                           offsets, myRank);
  vllm::Signal* ipc_ptrs[8];
  for (int i = 0; i < nRanks; i++) {
    if (i == myRank)
      ipc_ptrs[i] = buffer;
    else
      CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptrs[i], data_handles[i],
                                     cudaIpcMemLazyEnablePeerAccess));
  }
  vllm::CustomAllreduce fa(ipc_ptrs, rank_data, rank_data_sz, myRank, nRanks);
  auto* self_data =
      reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
                           sizeof(vllm::Signal) + data_size * sizeof(T));
  // hack buffer registration
  {
    std::vector<std::string> handles;
    handles.reserve(nRanks);
    void* data[8];
    for (int i = 0; i < nRanks; i++) {
      char* begin = (char*)&data_handles[i];
      char* end = (char*)&data_handles[i + 1];
      handles.emplace_back(begin, end);
      data[i] =
          ((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T);
    }
    std::vector<int64_t> offsets(nRanks,
                                 sizeof(vllm::Signal) + data_size * sizeof(T));
    fa.register_buffer(handles, offsets, self_data);
    fa.register_buffer(data);
  }

  double* ground_truth;

csrc/ops.h
@@ -199,20 +199,16 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,

#ifndef USE_ROCM
using fptr_t = int64_t;
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int64_t rank,
                      bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out);
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
                      torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
                            const std::vector<std::vector<int64_t>>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
#endif

@@ -411,27 +411,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
  // Custom all-reduce kernels
  custom_ar.def(
      "init_custom_ar(Tensor meta, Tensor rank_data, "
      "str[] handles, int[] offsets, int rank, "
      "bool full_nvlink) -> int");
      "init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
      "int rank, bool full_nvlink) -> int");
  custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

  custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
  custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);

  custom_ar.def(
      "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
      "()");
  custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
      "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
      "int reg_buffer_sz_bytes) -> ()");
  custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);

  custom_ar.def("dispose", &dispose);
  custom_ar.def("meta_size", &meta_size);

  custom_ar.def(
      "register_buffer(int fa, Tensor t, str[] handles, "
      "int[] offsets) -> ()");
  custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);

  custom_ar.def("register_buffer", &register_buffer);
  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  custom_ar.def("register_graph_buffers", &register_graph_buffers);
}

@@ -95,13 +95,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
    inp = torch.ones(sz, dtype=torch.float32, device=device)
    out = inp
    for _ in range(num_communication):
        out = fa.all_reduce_unreg(out)
        out = fa.all_reduce(out, registered=False)
    torch.testing.assert_close(out, inp * (tp_size**num_communication))

    inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
    out = inp
    for _ in range(num_communication):
        out = fa.all_reduce_unreg(out)
        out = fa.all_reduce(out, registered=False)
    torch.testing.assert_close(out, inp * (tp_size**num_communication))

@@ -196,8 +196,8 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
    def is_cross_device_reduce_2stage(op_name: str):
        return "cross_device_reduce_2stage" in op_name

    def is_custom_ar_all_reduce_unreg(op_name: str):
        return "_C_custom_ar::all_reduce_unreg" in op_name
    def is_custom_ar_all_reduce(op_name: str):
        return "_C_custom_ar::all_reduce" in op_name

    def is_reduce_kernel(op_name: str):
        return "reduce_kernel" in op_name

@@ -246,9 +246,9 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
        filter(lambda x: is_cross_device_reduce_2stage(x), ops))
    ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops))

    custom_ar_all_reduce_unreg_ops = list(
        filter(lambda x: is_custom_ar_all_reduce_unreg(x), ops))
    ops = list(filter(lambda x: x not in custom_ar_all_reduce_unreg_ops, ops))
    custom_ar_all_reduce_ops = list(
        filter(lambda x: is_custom_ar_all_reduce(x), ops))
    ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops))

    reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops))
    ops = list(filter(lambda x: x not in reduce_kernel_ops, ops))

@@ -289,21 +289,21 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
    if len(cross_device_reduce_2stage_ops):
        trace_df['cross_device_reduce_2stage_ops'] = trace_df[
            cross_device_reduce_2stage_ops].agg("sum", axis=1)
    if len(custom_ar_all_reduce_unreg_ops):
        trace_df['custom_ar_all_reduce_unreg_ops'] = trace_df[
            custom_ar_all_reduce_unreg_ops].agg("sum", axis=1)
    if len(custom_ar_all_reduce_ops):
        trace_df['custom_ar_all_reduce_ops'] = trace_df[
            custom_ar_all_reduce_ops].agg("sum", axis=1)
    if len(reduce_kernel_ops):
        trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
                                                                        axis=1)

    trace_df.drop(
        attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops +
        mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops +
        nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops +
        cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops +
        reduce_kernel_ops,
        axis=1,
        inplace=True)
    trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops +
                  vocab_embed_ops + mem_ops + elementwise_ops +
                  nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops +
                  nccl_other_ops + cross_device_reduce_1stage_ops +
                  cross_device_reduce_2stage_ops + custom_ar_all_reduce_ops +
                  reduce_kernel_ops,
                  axis=1,
                  inplace=True)
    return trace_df

@@ -912,20 +912,16 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int:


# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
                                                 offsets, rank, full_nvlink)
def init_custom_ar(ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor,
                   rank: int, full_nvlink: bool) -> int:
    return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
                                                 full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)


def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
               reg_buffer_sz_bytes: int) -> None:
    torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer,
                                      reg_buffer_sz_bytes)


def dispose(fa: int) -> None:

@@ -936,16 +932,15 @@ def meta_size() -> int:
    return torch.ops._C_custom_ar.meta_size()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)
def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
    return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
    return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(fa: int, handles: List[str],
def register_graph_buffers(fa: int, handles: List[List[int]],
                           offsets: List[List[int]]) -> None:
    torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)

@@ -1,6 +1,6 @@
import ctypes
from contextlib import contextmanager
from typing import Any, List, Optional, Union
from typing import List, Optional, Union

import torch
import torch.distributed as dist

@@ -147,18 +147,14 @@ class CustomAllreduce:
            return

        self.disabled = False
        # buffers memory are owned by this Python class and passed to C++
        # meta data composes of two parts: meta data for synchronization
        # (256 bytes) and a temporary buffer for storing intermediate
        # allreduce results.
        self.meta = torch.zeros(ops.meta_size() + max_size,
                                dtype=torch.uint8,
                                device=self.device)
        # Buffers memory are owned by this Python class and passed to C++.
        # Meta data composes of two parts: meta data for synchronization and a
        # temporary buffer for storing intermediate allreduce results.
        self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
                                                   group=group)
        # This is a pre-registered IPC buffer. In eager mode, input tensors
        # are first copied into this buffer before allreduce is performed
        self.buffer = torch.empty(max_size,
                                  dtype=torch.uint8,
                                  device=self.device)
        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
        # This is a buffer for storing the tuples of pointers pointing to
        # IPC buffers from all ranks. Each registered tuple has size of
        # 8*world_size bytes where world_size is at most 8. Allocating 8MB

@@ -170,16 +166,19 @@ class CustomAllreduce:
        self.max_size = max_size
        self.rank = rank
        self.world_size = world_size
        handles, offsets = self._get_ipc_meta(self.meta)
        self.full_nvlink = full_nvlink
        self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles,
                                       offsets, rank, self.full_nvlink)
        self.register_buffer(self.buffer)
        self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
                                       self.full_nvlink)
        ops.register_buffer(self._ptr, self.buffer_ptrs)

    @staticmethod
    def create_shared_buffer(
            size_in_bytes: int,
            group: Optional[ProcessGroup] = None) -> List[int]:
        """
        Creates a shared buffer and returns a list of pointers
        representing the buffer on all processes in the group.
        """
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)

@@ -220,60 +219,24 @@ class CustomAllreduce:
        if not self.disabled:
            self.register_graph_buffers()

    def _get_ipc_meta(self, inp: torch.Tensor):
        data = inp.untyped_storage()._share_cuda_()
        handle = data[1]
        # https://github.com/pytorch/pytorch/pull/130890 changes
        # the binary format of the ipc handle
        # it starts from pytorch 2.5
        if len(handle) > 64:
            assert len(handle) == 66
            # only support SHAREABLE_HANDLE_VERSION = 1
            assert int(handle[0]) == 1
            # only support SHAREABLE_CUDA_MALLOC = 'c'
            assert handle[1] == ord("c")
            handle = handle[2:]
            # TODO: support expandable segment
        shard_data = (
            handle,  # ipc handle to base ptr
            data[3],  # offset of base ptr
        )
        return self._gather_ipc_meta(shard_data)

    def _gather_ipc_meta(self, shard_data):
        # Note: don't use `[[None]] * self.world_size` here
        # because it will create a list of the same reference
        all_data: List[Optional[Any]] = [[None]
                                         for i in range(self.world_size)]
        all_data[self.rank][0] = shard_data

        ranks = dist.get_process_group_ranks(group=self.group)
        ranks.sort()
    def register_graph_buffers(self):
        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
        logger.info("Registering %d cuda graph addresses", len(offset))
        # We cannot directly use `dist.all_gather_object` here
        # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.
        all_data = [[None, None]
                    for _ in range(dist.get_world_size(group=self.group))]
        all_data[self.rank] = [handle, offset]
        ranks = sorted(dist.get_process_group_ranks(group=self.group))
        for i, rank in enumerate(ranks):
            dist.broadcast_object_list(all_data[i],
                                       src=rank,
                                       group=self.group,
                                       device="cpu")

        # we cannot directly use `dist.all_gather_object` here
        # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.

        handles = []
        offsets = []
        for i in range(len(all_data)):
            handles.append(all_data[i][0][0])  # type: ignore
            offsets.append(all_data[i][0][1])  # type: ignore
        return handles, offsets

    def register_buffer(self, inp: torch.Tensor):
        handles, offsets = self._get_ipc_meta(inp)
        ops.register_buffer(self._ptr, inp, handles, offsets)

    def register_graph_buffers(self):
        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
        handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
        logger.info("Registering %d cuda graph addresses", len(offset))
        # Unpack list of tuples to tuple of lists.
        handles = [d[0] for d in all_data]  # type: ignore
        offsets = [d[1] for d in all_data]  # type: ignore
        ops.register_graph_buffers(self._ptr, handles, offsets)

    def should_custom_ar(self, inp: torch.Tensor):

@@ -291,45 +254,50 @@ class CustomAllreduce:
            return inp_size < self.max_size
        return False

    # all reduce, assuming inp tensor is IPC registered with register_buffer,
    # or, in the context of cuda graphs, register_graph_buffers
    def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
    def all_reduce(self,
                   inp: torch.Tensor,
                   *,
                   out: torch.Tensor = None,
                   registered: bool = False):
        """Performs an out-of-place all reduce.

        If registered is True, this assumes inp's pointer is already
        IPC-registered. Otherwise, inp is first copied into a pre-registered
        buffer.
        """
        if out is None:
            out = torch.empty_like(inp)
        ops.all_reduce_reg(self._ptr, inp, out)
        return out

    # all reduce, assuming inp tensor is NOT IPC registered
    def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
        if out is None:
            out = torch.empty_like(inp)
        ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
        if registered:
            ops.all_reduce(self._ptr, inp, out, 0, 0)
        else:
            ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
                           self.max_size)
        return out

    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
        # when custom allreduce is disabled, this will be None
        """The main allreduce API that provides support for cuda graph."""
        # When custom allreduce is disabled, this will be None.
        if self.disabled or not self.should_custom_ar(input):
            return None
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                return self.all_reduce_reg(input)
                return self.all_reduce(input, registered=True)
            else:
                # if warm up, mimic the allocation pattern
                # since custom allreduce is out-of-place
                # If warm up, mimic the allocation pattern since custom
                # allreduce is out-of-place.
                return torch.empty_like(input)
        else:
            # note: outside of cuda graph context,
            # custom allreduce incurs a cost of cudaMemcpy, which should
            # be small(<=1% of overall latency) compared to the performance
            # gains of using custom kernels
            return self.all_reduce_unreg(input)

        return None
            # Note: outside of cuda graph context, custom allreduce incurs a
            # cost of cudaMemcpy, which should be small (<=1% of overall
            # latency) compared to the performance gain of using custom kernels
            return self.all_reduce(input, registered=False)

    def close(self):
        if not self.disabled and self._ptr:
            ops.dispose(self._ptr)
            self._ptr = 0
            self.free_shared_buffer(self.meta_ptrs)
            self.free_shared_buffer(self.buffer_ptrs)

    def __del__(self):
        self.close()
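
For reference, a minimal sketch of how the single all_reduce op introduced by this commit can be driven, mirroring CustomAllreduce.all_reduce above. This is an illustration only and is not part of the commit; the helper name staged_all_reduce is hypothetical, and the fa handle, buffer_ptrs list, rank, and max_size are assumed to come from an already constructed CustomAllreduce instance (ops.init_custom_ar plus ops.register_buffer).

from typing import List

import torch
import vllm._custom_ops as ops


def staged_all_reduce(fa: int, inp: torch.Tensor, buffer_ptrs: List[int],
                      rank: int, max_size: int,
                      registered: bool = False) -> torch.Tensor:
    # Out-of-place: the kernel writes into a freshly allocated output tensor.
    out = torch.empty_like(inp)
    if registered:
        # inp's device pointer was already IPC-registered via
        # ops.register_buffer / register_graph_buffers, so pass a null
        # reg_buffer (0) and skip the staging copy.
        ops.all_reduce(fa, inp, out, 0, 0)
    else:
        # Eager path: inp is first copied into this rank's pre-registered
        # staging buffer (at most max_size bytes) inside the C++ op.
        ops.all_reduce(fa, inp, out, buffer_ptrs[rank], max_size)
    return out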