diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index 529e75fb..0218295a 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -3,6 +3,7 @@ import os
 
 import pytest
 import torch
+import torch.distributed
 
 from vllm.distributed.communication_op import (  # noqa
     graph_capture, tensor_model_parallel_all_reduce)
@@ -68,7 +69,7 @@ def test_pynccl():
 
 
 @worker_fn_wrapper
-def multiple_tp_worker_fn():
+def multiple_allreduce_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     groups = [
         torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
@@ -92,14 +93,14 @@ def multiple_tp_worker_fn():
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="Need at least 4 GPUs to run the test.")
-def test_pynccl_multiple_tp():
+def test_pynccl_multiple_allreduce():
     # this tests pynccl for multiple tp groups, in a standalone way
     # i.e. call `pynccl_comm.all_reduce` directly
-    distributed_run(multiple_tp_worker_fn, 4)
+    distributed_run(multiple_allreduce_worker_fn, 4)
 
 
 @worker_fn_wrapper
-def multiple_tp_with_vllm_worker_fn():
+def multiple_allreduce_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     ensure_model_parallel_initialized(2, 2)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
@@ -118,10 +119,10 @@ def multiple_tp_with_vllm_worker_fn():
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="Need at least 4 GPUs to run the test.")
-def test_pynccl_multiple_tp_with_vllm():
+def test_pynccl_multiple_allreduce_with_vllm():
     # this tests pynccl for multiple tp groups, together with vllm
     # i.e. call `tensor_model_parallel_all_reduce`
-    distributed_run(multiple_tp_with_vllm_worker_fn, 4)
+    distributed_run(multiple_allreduce_with_vllm_worker_fn, 4)
 
 
 @worker_fn_wrapper
@@ -151,6 +152,68 @@ def test_pynccl_with_cudagraph():
     distributed_run(worker_fn_with_cudagraph, 2)
 
 
+@worker_fn_wrapper
+def send_recv_worker_fn():
+    pynccl_comm = PyNcclCommunicator()
+    if pynccl_comm.rank == 0:
+        tensor = torch.ones(16, 1024, 1024,
+                            dtype=torch.float32).cuda(pynccl_comm.rank)
+    else:
+        tensor = torch.empty(16, 1024, 1024,
+                             dtype=torch.float32).cuda(pynccl_comm.rank)
+    with pynccl_comm.change_state(enable=True):
+        if pynccl_comm.rank == 0:
+            pynccl_comm.send(tensor)
+        else:
+            pynccl_comm.recv(tensor)
+    result = tensor.mean().cpu().item()
+    assert result == 1
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+def test_pynccl_send_recv():
+    distributed_run(send_recv_worker_fn, 2)
+
+
+@worker_fn_wrapper
+def multiple_send_recv_worker_fn():
+    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
+    groups = [
+        torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
+        torch.distributed.new_group(ranks=[1, 3], backend="gloo")
+    ]
+    group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
+    pynccl_comm = PyNcclCommunicator(group=group, device=device)
+    if torch.distributed.get_rank() == 0:
+        tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
+    elif torch.distributed.get_rank() == 1:
+        tensor = 2 * torch.ones(
+            16, 1024, 1024, dtype=torch.float32, device=device)
+    else:
+        tensor = torch.empty(16,
+                             1024,
+                             1024,
+                             dtype=torch.float32,
+                             device=device)
+    with pynccl_comm.change_state(enable=True):
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.send(tensor)
+        else:
+            pynccl_comm.recv(tensor)
+    result = tensor.mean().cpu().item()
+    if torch.distributed.get_rank() in [0, 2]:
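+        # ranks 0 and 2 exchanged the all-ones tensor within the first group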
+        assert result == 1
+    else:
+        assert result == 2
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_pynccl_multiple_send_recv():
+    distributed_run(multiple_send_recv_worker_fn, 4)
+
+
 def test_ncclGetUniqueId():
     lib = NCCLLibrary()
     unique_id = lib.ncclGetUniqueId()
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index 937fd4d3..2b38ec47 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 from torch.distributed import ProcessGroup
 
-from .parallel_state import (get_cpu_world_group,
+from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator,
                              get_tensor_model_parallel_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
@@ -54,13 +54,19 @@ def graph_capture():
         # graph, we use either custom all-reduce kernel or PyTorch NCCL.
         # We always prioritize using custom all-reduce kernel but fall back
         # to PyTorch or pynccl if it is disabled or not supported.
-        pynccl_comm = get_tp_pynccl_communicator()
-        if pynccl_comm is None:
-            maybe_pynccl_context = nullcontext()
+        tp_pynccl_comm = get_tp_pynccl_communicator()
+        pp_pynccl_comm = get_pp_pynccl_communicator()
+        if not tp_pynccl_comm:
+            maybe_tp_pynccl_context = nullcontext()
         else:
-            maybe_pynccl_context = pynccl_comm.change_state(
+            maybe_tp_pynccl_context = tp_pynccl_comm.change_state(
                 enable=True, stream=torch.cuda.current_stream())
-        with maybe_pynccl_context:
+        if not pp_pynccl_comm:
+            maybe_pp_pynccl_context = nullcontext()
+        else:
+            maybe_pp_pynccl_context = pp_pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream())
+        with maybe_tp_pynccl_context, maybe_pp_pynccl_context:
             yield graph_capture_context
 
 
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 092a0910..f5f1de0c 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -126,6 +126,40 @@ class PyNcclCommunicator:
                                 ncclRedOpTypeEnum.from_torch(op), self.comm,
                                 cudaStream_t(stream.cuda_stream))
 
+    def send(self,
+             tensor: torch.Tensor,
+             dst: Optional[int] = None,
+             stream=None):
+        if self.disabled:
+            return
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}")
+        if stream is None:
+            stream = self.stream
+        if dst is None:
+            dst = (self.rank + 1) % self.world_size
+        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
+                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
+                           self.comm, cudaStream_t(stream.cuda_stream))
+
+    def recv(self,
+             tensor: torch.Tensor,
+             src: Optional[int] = None,
+             stream=None):
+        if self.disabled:
+            return
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}")
+        if stream is None:
+            stream = self.stream
+        if src is None:
+            src = (self.rank - 1) % self.world_size
+        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
+                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
+                           self.comm, cudaStream_t(stream.cuda_stream))
+
     @contextmanager
     def change_state(self,
                      enable: Optional[bool] = None,
diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py
index 43d85674..3aa3744d 100644
--- a/vllm/distributed/device_communicators/pynccl_wrapper.py
+++ b/vllm/distributed/device_communicators/pynccl_wrapper.py
@@ -151,6 +151,22 @@ class NCCLLibrary:
             ncclRedOp_t, ncclComm_t, cudaStream_t
         ]),
 
+        # ncclResult_t ncclSend(
+        #   const void* sendbuff, size_t count, ncclDataType_t datatype,
+        #   int dest, ncclComm_t comm, cudaStream_t stream);
+        Function("ncclSend", ncclResult_t, [
+            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
+            ncclComm_t, cudaStream_t
+        ]),
+
+        # ncclResult_t ncclRecv(
+        #   void* recvbuff, size_t count, ncclDataType_t datatype,
+        #   int src, ncclComm_t comm, cudaStream_t stream);
+        Function("ncclRecv", ncclResult_t, [
+            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
+            ncclComm_t, cudaStream_t
+        ]),
+
         # be cautious! this is a collective call, it will block until all
         # processes in the communicator have called this function.
         # because Python object destruction can happen in random order,
@@ -248,6 +264,16 @@ class NCCLLibrary:
                                                      datatype, op, comm,
                                                      stream))
 
+    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
+                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
+                                                dest, comm, stream))
+
+    def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
+                 src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
+                                                comm, stream))
+
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
 
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index d24104e3..0ebd7a15 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -22,6 +22,8 @@ _TP_PYNCCL_COMMUNICATOR = None
 _TP_CA_COMMUNICATOR = None
 # Pipeline model parallel group that the current rank belongs to.
 _PP_DEVICE_GROUP: Optional[ProcessGroup] = None
+_PP_CPU_GROUP: Optional[ProcessGroup] = None
+_PP_PYNCCL_COMMUNICATOR = None
 
 # when people blindly call `torch.distributed.all_reduce` etc,
 # it will use this group. It is initialized with the `backend`
@@ -55,6 +57,11 @@ def set_custom_all_reduce(enable: bool):
     _ENABLE_CUSTOM_ALL_REDUCE = enable
 
 
+def get_pp_pynccl_communicator():
+    global _PP_PYNCCL_COMMUNICATOR
+    return _PP_PYNCCL_COMMUNICATOR
+
+
 def get_tp_pynccl_communicator():
     global _TP_PYNCCL_COMMUNICATOR
     return _TP_PYNCCL_COMMUNICATOR
@@ -180,10 +187,11 @@ def initialize_model_parallel(
         _TP_CPU_GROUP = cpu_group
 
     from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
-    _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
-        group=_TP_CPU_GROUP,
-        device=_LOCAL_RANK,
-    )
+    if tensor_model_parallel_size > 1:
+        _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
+            group=_TP_CPU_GROUP,
+            device=_LOCAL_RANK,
+        )
 
     # Initialize a custom fast all-reduce implementation.
     if _ENABLE_CUSTOM_ALL_REDUCE:
@@ -195,17 +203,26 @@ def initialize_model_parallel(
         )
 
     # Build the pipeline model-parallel groups.
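+    # Also create a gloo CPU group for each pipeline group; when
+    # pipeline_model_parallel_size > 1, a PyNcclCommunicator is built on it.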
-    global _PP_DEVICE_GROUP
+    global _PP_DEVICE_GROUP, _PP_CPU_GROUP
+    global _PP_PYNCCL_COMMUNICATOR
     global _PP_GLOBAL_RANKS
     assert _PP_DEVICE_GROUP is None, (
         "pipeline model parallel group is already initialized")
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group = torch.distributed.new_group(ranks, backend=backend)
+        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
         if rank in ranks:
             _PP_DEVICE_GROUP = group
+            _PP_CPU_GROUP = cpu_group
             _PP_GLOBAL_RANKS = ranks
 
+    if pipeline_model_parallel_size > 1:
+        _PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
+            group=_PP_CPU_GROUP,
+            device=_LOCAL_RANK,
+        )
+
 
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
@@ -267,6 +284,13 @@ def get_pipeline_model_parallel_group():
     return _PP_DEVICE_GROUP
 
 
+def get_pipeline_model_parallel_cpu_group():
+    """Get the pipeline model parallel cpu group the caller rank belongs to."""
+    assert _PP_CPU_GROUP is not None, (
+        "pipeline model parallel cpu group is not initialized")
+    return _PP_CPU_GROUP
+
+
 def get_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
     return torch.distributed.get_world_size(
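
For reference, a minimal usage sketch of the new PyNcclCommunicator.send/recv pair (not part of the patch). It mirrors send_recv_worker_fn above and assumes the script is launched with two ranks, one GPU per rank, with vLLM's distributed environment already initialized (e.g. via the distributed_run helper used in the tests):

    import torch

    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

    comm = PyNcclCommunicator()  # no group argument, as in send_recv_worker_fn
    tensor = torch.ones(1024, dtype=torch.float32).cuda(comm.rank)
    with comm.change_state(enable=True):
        if comm.rank == 0:
            comm.send(tensor)  # dst defaults to (rank + 1) % world_size
        else:
            comm.recv(tensor)  # src defaults to (rank - 1) % world_size
    assert tensor.mean().item() == 1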