[Core][Distributed] add same-node detection (#5369)

youkaichao 2024-06-11 10:53:59 -07:00 committed by GitHub
parent dcbf4286af
commit c4bd03c7c5
4 changed files with 87 additions and 1 deletion


@@ -37,6 +37,7 @@ steps:
  working_dir: "/vllm-workspace/tests"
  num_gpus: 2
  commands:
  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py


@@ -0,0 +1,11 @@
import os

import torch

from vllm.distributed.parallel_state import is_in_the_same_node

torch.distributed.init_process_group(backend="gloo")
test_result = is_in_the_same_node(torch.distributed.group.WORLD)

expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
assert test_result == expected, f"Expected {expected}, got {test_result}"


@@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
    gpu_p2p_access_check)
from vllm.distributed.parallel_state import (
    get_local_rank, get_tensor_model_parallel_cpu_group)
    get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node)
from vllm.logger import init_logger

try:
@@ -113,6 +113,13 @@ class CustomAllreduce:
        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "CustomAllreduce should be attached to a non-NCCL group.")

        if not is_in_the_same_node(group):
            # No need to initialize custom allreduce for multi-node case.
            logger.warning(
                "Custom allreduce is disabled because this process group"
                " spans across nodes.")
            return

        rank = dist.get_rank(group=self.group)
        world_size = dist.get_world_size(group=self.group)
        if world_size == 1:
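To make the intent of the new guard explicit: is_in_the_same_node is a collective call on the CPU group, and custom allreduce (which relies on intra-node GPU peer access) is simply skipped when any rank lives on another host. A minimal usage sketch, not part of this commit, assuming torch.distributed has already been initialized with the gloo backend as in the test above:

import torch.distributed as dist

from vllm.distributed.parallel_state import is_in_the_same_node

# Every rank in the group must reach this call; it is a collective operation.
cpu_group = dist.group.WORLD  # assumes a gloo-backed world group
if is_in_the_same_node(cpu_group):
    # All ranks share one host: intra-node fast paths (shared memory,
    # CUDA IPC, custom allreduce) can be wired up safely.
    use_custom_allreduce = True
else:
    # The group spans hosts: fall back to the regular NCCL/gloo collectives.
    use_custom_allreduce = False

This mirrors the early return above: detection happens once, on the CPU group, before any intra-node resources are set up.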


@@ -3,6 +3,8 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib
from multiprocessing import resource_tracker, shared_memory
from typing import List, Optional

import torch
@@ -376,3 +378,68 @@ def destroy_model_parallel():
    _PP_DEVICE_GROUP = None
    global _PP_GLOBAL_RANKS
    _PP_GLOBAL_RANKS = None

def is_in_the_same_node(pg: ProcessGroup):
    """
    This is a collective operation that checks if all processes in the group
    are in the same node. It tests if all processes are attached to the same
    memory system (shared access to shared memory).
    """
    assert torch.distributed.get_backend(
        pg) != torch.distributed.Backend.NCCL, (
            "is_in_the_same_node should be tested with a non-NCCL group.")
    # local rank inside the group
    rank = torch.distributed.get_rank(group=pg)
    world_size = torch.distributed.get_world_size(group=pg)

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    # global ranks of the processes in the group
    ranks = torch.distributed.get_process_group_ranks(pg)

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == 0:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                torch.distributed.broadcast_object_list([shm.name],
                                                        src=ranks[0],
                                                        group=pg)
                is_in_the_same_node[0] = 1
            else:
                # try to open the shared memory segment
                recv = [None]
                torch.distributed.broadcast_object_list(recv,
                                                        src=ranks[0],
                                                        group=pg)
                name = recv[0]
                shm = shared_memory.SharedMemory(name=name)
                if shm.buf[:len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    torch.distributed.barrier(group=pg)

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == 0:
            if shm:
                shm.unlink()
        else:
            if shm:
                # fix to https://stackoverflow.com/q/62748654/9191338
                resource_tracker.unregister(
                    shm._name, "shared_memory")  # type: ignore[attr-defined]

    torch.distributed.all_reduce(is_in_the_same_node, group=pg)

    return is_in_the_same_node.sum().item() == world_size
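
The whole check hinges on one property of named shared memory: a segment created under a name by one process can only be attached by that name from processes on the same host. A standalone sketch, not part of the commit and collapsed into a single process purely for illustration, of that handshake:

from multiprocessing import shared_memory

MAGIC = b"magic_message"

# "rank 0" side: create a named segment and write the magic bytes into it.
owner = shared_memory.SharedMemory(create=True, size=128)
owner.buf[:len(MAGIC)] = MAGIC

# "other rank" side (same host): attach by name and compare the contents.
# On a different host the attach raises FileNotFoundError, which the real
# implementation swallows via contextlib.suppress(OSError), leaving that
# rank's flag at 0.
peer = shared_memory.SharedMemory(name=owner.name)
assert bytes(peer.buf[:len(MAGIC)]) == MAGIC

peer.close()
owner.close()
owner.unlink()

The final all_reduce then sums the per-rank flags, so the function returns True only when every rank in the group managed to attach to the same segment.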