diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index b48ef31b..6b12d19b 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -37,6 +37,7 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
   commands:
+  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py
new file mode 100644
index 00000000..4880bab7
--- /dev/null
+++ b/tests/distributed/test_same_node.py
@@ -0,0 +1,11 @@
+import os
+
+import torch
+
+from vllm.distributed.parallel_state import is_in_the_same_node
+
+torch.distributed.init_process_group(backend="gloo")
+test_result = is_in_the_same_node(torch.distributed.group.WORLD)
+
+expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
+assert test_result == expected, f"Expected {expected}, got {test_result}"
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 4a0e19bc..bbc2284f 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import (
-    get_local_rank, get_tensor_model_parallel_cpu_group)
+    get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node)
 from vllm.logger import init_logger
 
 try:
@@ -113,6 +113,13 @@ class CustomAllreduce:
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "CustomAllreduce should be attached to a non-NCCL group.")
 
+        if not is_in_the_same_node(group):
+            # No need to initialize custom allreduce for multi-node case.
+            logger.warning(
+                "Custom allreduce is disabled because this process group"
+                " spans across nodes.")
+            return
+
         rank = dist.get_rank(group=self.group)
         world_size = dist.get_world_size(group=self.group)
         if world_size == 1:
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 0ebd7a15..b6d1eeff 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -3,6 +3,8 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
+import contextlib
+from multiprocessing import resource_tracker, shared_memory
 from typing import List, Optional
 
 import torch
@@ -376,3 +378,68 @@
     _PP_DEVICE_GROUP = None
     global _PP_GLOBAL_RANKS
     _PP_GLOBAL_RANKS = None
+
+
+def is_in_the_same_node(pg: ProcessGroup):
+    """
+    This is a collective operation that checks if all processes in the group
+    are in the same node. It tests if all processes are attached to the same
+    memory system (shared access to shared memory).
+    """
+    assert torch.distributed.get_backend(
+        pg) != torch.distributed.Backend.NCCL, (
+            "is_in_the_same_node should be tested with a non-NCCL group.")
+    # local rank inside the group
+    rank = torch.distributed.get_rank(group=pg)
+    world_size = torch.distributed.get_world_size(group=pg)
+
+    # local tensor in each process to store the result
+    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
+
+    # global ranks of the processes in the group
+    ranks = torch.distributed.get_process_group_ranks(pg)
+
+    magic_message = b"magic_message"
+    shm = None
+
+    try:
+        with contextlib.suppress(OSError):
+            if rank == 0:
+                # create a shared memory segment
+                shm = shared_memory.SharedMemory(create=True, size=128)
+                shm.buf[:len(magic_message)] = magic_message
+                torch.distributed.broadcast_object_list([shm.name],
+                                                        src=ranks[0],
+                                                        group=pg)
+                is_in_the_same_node[0] = 1
+            else:
+                # try to open the shared memory segment
+                recv = [None]
+                torch.distributed.broadcast_object_list(recv,
+                                                        src=ranks[0],
+                                                        group=pg)
+                name = recv[0]
+                shm = shared_memory.SharedMemory(name=name)
+                if shm.buf[:len(magic_message)] == magic_message:
+                    is_in_the_same_node[rank] = 1
+    except Exception as e:
+        logger.error("Error ignored in is_in_the_same_node: %s", e)
+    finally:
+        if shm:
+            shm.close()
+
+    torch.distributed.barrier(group=pg)
+
+    # clean up the shared memory segment
+    with contextlib.suppress(OSError):
+        if rank == 0:
+            if shm:
+                shm.unlink()
+        else:
+            if shm:
+                # fix to https://stackoverflow.com/q/62748654/9191338
+                resource_tracker.unregister(
+                    shm._name, "shared_memory")  # type: ignore[attr-defined]
+    torch.distributed.all_reduce(is_in_the_same_node, group=pg)
+
+    return is_in_the_same_node.sum().item() == world_size