[Core][Distributed] add same-node detection (#5369)
parent dcbf4286af
commit c4bd03c7c5
@@ -37,6 +37,7 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
   commands:
+  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
tests/distributed/test_same_node.py (new file, 11 lines)
@@ -0,0 +1,11 @@
+import os
+
+import torch
+
+from vllm.distributed.parallel_state import is_in_the_same_node
+
+torch.distributed.init_process_group(backend="gloo")
+test_result = is_in_the_same_node(torch.distributed.group.WORLD)
+
+expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
+assert test_result == expected, f"Expected {expected}, got {test_result}"
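The new pipeline entry runs this file under torchrun with several local workers and VLLM_TEST_SAME_HOST=1, so every rank is expected to land on the same machine. Below is a rough local equivalent, not part of the commit, that drives the same check with torch.multiprocessing instead of torchrun; the worker count and the rendezvous port are arbitrary choices for the sketch.

import os

import torch
import torch.multiprocessing as mp

from vllm.distributed.parallel_state import is_in_the_same_node


def _worker(rank: int, world_size: int) -> None:
    # Rendezvous settings for the default env:// init method (arbitrary port).
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    torch.distributed.init_process_group(backend="gloo",
                                         rank=rank,
                                         world_size=world_size)
    # All workers are spawned on this machine, so the check should pass.
    assert is_in_the_same_node(torch.distributed.group.WORLD)
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(_worker, args=(world_size, ), nprocs=world_size)

Launching the ranks on different hosts with VLLM_TEST_SAME_HOST=0 would flip the expected value in the test above.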
@@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import (
-    get_local_rank, get_tensor_model_parallel_cpu_group)
+    get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node)
 from vllm.logger import init_logger
 
 try:
@@ -113,6 +113,13 @@ class CustomAllreduce:
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "CustomAllreduce should be attached to a non-NCCL group.")
 
+        if not is_in_the_same_node(group):
+            # No need to initialize custom allreduce for multi-node case.
+            logger.warning(
+                "Custom allreduce is disabled because this process group"
+                " spans across nodes.")
+            return
+
         rank = dist.get_rank(group=self.group)
         world_size = dist.get_world_size(group=self.group)
         if world_size == 1:
@@ -3,6 +3,8 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
+import contextlib
+from multiprocessing import resource_tracker, shared_memory
 from typing import List, Optional
 
 import torch
@@ -376,3 +378,68 @@ def destroy_model_parallel():
     _PP_DEVICE_GROUP = None
     global _PP_GLOBAL_RANKS
     _PP_GLOBAL_RANKS = None
+
+
+def is_in_the_same_node(pg: ProcessGroup):
+    """
+    This is a collective operation that checks if all processes in the group
+    are in the same node. It tests if all processes are attached to the same
+    memory system (shared access to shared memory).
+    """
+    assert torch.distributed.get_backend(
+        pg) != torch.distributed.Backend.NCCL, (
+            "is_in_the_same_node should be tested with a non-NCCL group.")
+    # local rank inside the group
+    rank = torch.distributed.get_rank(group=pg)
+    world_size = torch.distributed.get_world_size(group=pg)
+
+    # local tensor in each process to store the result
+    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
+
+    # global ranks of the processes in the group
+    ranks = torch.distributed.get_process_group_ranks(pg)
+
+    magic_message = b"magic_message"
+    shm = None
+
+    try:
+        with contextlib.suppress(OSError):
+            if rank == 0:
+                # create a shared memory segment
+                shm = shared_memory.SharedMemory(create=True, size=128)
+                shm.buf[:len(magic_message)] = magic_message
+                torch.distributed.broadcast_object_list([shm.name],
+                                                         src=ranks[0],
+                                                         group=pg)
+                is_in_the_same_node[0] = 1
+            else:
+                # try to open the shared memory segment
+                recv = [None]
+                torch.distributed.broadcast_object_list(recv,
+                                                        src=ranks[0],
+                                                        group=pg)
+                name = recv[0]
+                shm = shared_memory.SharedMemory(name=name)
+                if shm.buf[:len(magic_message)] == magic_message:
+                    is_in_the_same_node[rank] = 1
+    except Exception as e:
+        logger.error("Error ignored in is_in_the_same_node: %s", e)
+    finally:
+        if shm:
+            shm.close()
+
+    torch.distributed.barrier(group=pg)
+
+    # clean up the shared memory segment
+    with contextlib.suppress(OSError):
+        if rank == 0:
+            if shm:
+                shm.unlink()
+        else:
+            if shm:
+                # fix to https://stackoverflow.com/q/62748654/9191338
+                resource_tracker.unregister(
+                    shm._name, "shared_memory")  # type: ignore[attr-defined]
+    torch.distributed.all_reduce(is_in_the_same_node, group=pg)
+
+    return is_in_the_same_node.sum().item() == world_size
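The helper is a small collective handshake: rank 0 of the group creates a named shared memory segment, writes a magic byte string into it, and broadcasts the segment name; every other rank tries to attach to that name and checks for the same bytes; the final all_reduce then confirms that every rank succeeded, so the result is True only when all processes share one memory system. The fragment below is a standalone sketch of that handshake on a POSIX machine using only the standard library; the writer/reader split and the queues are illustrative stand-ins for the rank-0/other-ranks roles and the broadcast, and are not part of the commit.

from multiprocessing import Process, Queue, resource_tracker, shared_memory

MAGIC = b"magic_message"


def writer(name_queue, done_queue):
    # "rank 0": create a named segment, write the magic bytes, publish the name.
    shm = shared_memory.SharedMemory(create=True, size=128)
    try:
        shm.buf[:len(MAGIC)] = MAGIC
        name_queue.put(shm.name)  # stands in for broadcast_object_list()
        done_queue.get()  # keep the segment alive until the reader is done
    finally:
        shm.close()
        shm.unlink()


def reader(name_queue, done_queue):
    # "rank != 0": attach by name and verify the magic bytes.
    shm = shared_memory.SharedMemory(name=name_queue.get())
    try:
        print("same node:", bytes(shm.buf[:len(MAGIC)]) == MAGIC)
    finally:
        shm.close()
        done_queue.put(True)
        # Same workaround the diff applies above: without it, this attaching
        # process would also try to unlink the segment at exit
        # (https://stackoverflow.com/q/62748654/9191338).
        resource_tracker.unregister(shm._name, "shared_memory")


if __name__ == "__main__":
    names, done = Queue(), Queue()
    procs = [
        Process(target=writer, args=(names, done)),
        Process(target=reader, args=(names, done)),
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

Attaching by name only works for processes that see the same kernel's shared memory namespace, which is what lets is_in_the_same_node tell "same host" apart from "same cluster".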