[Core][Distributed] add same-node detection (#5369)
parent dcbf4286af
commit c4bd03c7c5
@@ -37,6 +37,7 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
   commands:
+  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
tests/distributed/test_same_node.py (new file, 11 lines)
@@ -0,0 +1,11 @@
+import os
+
+import torch
+
+from vllm.distributed.parallel_state import is_in_the_same_node
+
+torch.distributed.init_process_group(backend="gloo")
+test_result = is_in_the_same_node(torch.distributed.group.WORLD)
+
+expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
+assert test_result == expected, f"Expected {expected}, got {test_result}"
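The new test is meant to be launched with torchrun (as in the pipeline change above), with VLLM_TEST_SAME_HOST=1 when all ranks share a host and 0 otherwise. As a rough sketch, not part of the commit, the same check can also be driven from a single script with torch.multiprocessing; the rendezvous port and the process count below are arbitrary assumptions:

import os

import torch
import torch.multiprocessing as mp

from vllm.distributed.parallel_state import is_in_the_same_node


def _worker(rank: int, world_size: int):
    # Hypothetical local rendezvous; any free port works.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29555"
    torch.distributed.init_process_group(backend="gloo",
                                         rank=rank,
                                         world_size=world_size)
    # All spawned ranks live on this machine, so the probe should return True.
    assert is_in_the_same_node(torch.distributed.group.WORLD)
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2, ), nprocs=2)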
@@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import (
-    get_local_rank, get_tensor_model_parallel_cpu_group)
+    get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node)
 from vllm.logger import init_logger
 
 try:
@@ -113,6 +113,13 @@ class CustomAllreduce:
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "CustomAllreduce should be attached to a non-NCCL group.")
 
+        if not is_in_the_same_node(group):
+            # No need to initialize custom allreduce for multi-node case.
+            logger.warning(
+                "Custom allreduce is disabled because this process group"
+                " spans across nodes.")
+            return
+
         rank = dist.get_rank(group=self.group)
         world_size = dist.get_world_size(group=self.group)
         if world_size == 1:
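Both CustomAllreduce and is_in_the_same_node assert a non-NCCL group because the probe exchanges Python objects (the shared-memory segment name) and CPU tensors, which the gloo backend provides; the imports above suggest vLLM passes a CPU group such as the tensor-model-parallel CPU group. A minimal sketch of building such a group by hand, assuming torch.distributed is already initialized; this is illustrative only, not how vLLM wires it up:

import torch.distributed as dist

from vllm.distributed.parallel_state import is_in_the_same_node

# A gloo group spanning all ranks, usable for CPU-side collectives.
cpu_group = dist.new_group(ranks=list(range(dist.get_world_size())),
                           backend="gloo")
same_node = is_in_the_same_node(cpu_group)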
@@ -3,6 +3,8 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
+import contextlib
+from multiprocessing import resource_tracker, shared_memory
 from typing import List, Optional
 
 import torch
@@ -376,3 +378,68 @@ def destroy_model_parallel():
     _PP_DEVICE_GROUP = None
     global _PP_GLOBAL_RANKS
     _PP_GLOBAL_RANKS = None
+
+
+def is_in_the_same_node(pg: ProcessGroup):
+    """
+    This is a collective operation that checks if all processes in the group
+    are in the same node. It tests if all processes are attached to the same
+    memory system (shared access to shared memory).
+    """
+    assert torch.distributed.get_backend(
+        pg) != torch.distributed.Backend.NCCL, (
+            "is_in_the_same_node should be tested with a non-NCCL group.")
+    # local rank inside the group
+    rank = torch.distributed.get_rank(group=pg)
+    world_size = torch.distributed.get_world_size(group=pg)
+
+    # local tensor in each process to store the result
+    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
+
+    # global ranks of the processes in the group
+    ranks = torch.distributed.get_process_group_ranks(pg)
+
+    magic_message = b"magic_message"
+    shm = None
+
+    try:
+        with contextlib.suppress(OSError):
+            if rank == 0:
+                # create a shared memory segment
+                shm = shared_memory.SharedMemory(create=True, size=128)
+                shm.buf[:len(magic_message)] = magic_message
+                torch.distributed.broadcast_object_list([shm.name],
+                                                        src=ranks[0],
+                                                        group=pg)
+                is_in_the_same_node[0] = 1
+            else:
+                # try to open the shared memory segment
+                recv = [None]
+                torch.distributed.broadcast_object_list(recv,
+                                                        src=ranks[0],
+                                                        group=pg)
+                name = recv[0]
+                shm = shared_memory.SharedMemory(name=name)
+                if shm.buf[:len(magic_message)] == magic_message:
+                    is_in_the_same_node[rank] = 1
+    except Exception as e:
+        logger.error("Error ignored in is_in_the_same_node: %s", e)
+    finally:
+        if shm:
+            shm.close()
+
+    torch.distributed.barrier(group=pg)
+
+    # clean up the shared memory segment
+    with contextlib.suppress(OSError):
+        if rank == 0:
+            if shm:
+                shm.unlink()
+        else:
+            if shm:
+                # fix to https://stackoverflow.com/q/62748654/9191338
+                resource_tracker.unregister(
+                    shm._name, "shared_memory")  # type: ignore[attr-defined]
+    torch.distributed.all_reduce(is_in_the_same_node, group=pg)
+
+    return is_in_the_same_node.sum().item() == world_size
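The probe works because named shared memory is visible only within a single host: rank 0 creates a segment containing a magic message and broadcasts its name, every other rank reports success only if it can attach and read the same bytes, and the final all_reduce requires every rank to have succeeded. The resource_tracker.unregister call works around Python's resource tracker unlinking segments the process merely attached to (the linked Stack Overflow question). A stripped-down, single-host sketch of the same handshake using only the standard library, with no torch.distributed and arbitrarily chosen names and sizes:

from multiprocessing import Process, Queue, resource_tracker, shared_memory

MAGIC = b"magic_message"


def _attach(name: str, result: Queue):
    # On the same machine, attaching to the named segment succeeds and the
    # magic bytes are visible; on a different machine no such segment exists.
    try:
        shm = shared_memory.SharedMemory(name=name)
        result.put(bytes(shm.buf[:len(MAGIC)]) == MAGIC)
        # This process did not create the segment, so keep the resource
        # tracker from unlinking it on exit (same workaround as the commit).
        resource_tracker.unregister(shm._name, "shared_memory")
        shm.close()
    except OSError:
        result.put(False)


if __name__ == "__main__":
    shm = shared_memory.SharedMemory(create=True, size=128)
    shm.buf[:len(MAGIC)] = MAGIC
    q: Queue = Queue()
    p = Process(target=_attach, args=(shm.name, q))
    p.start()
    p.join()
    print("same node:", q.get())
    shm.close()
    shm.unlink()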