From cc466a32903d53d0ceca459b766d74ad668c8f87 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 7 May 2024 19:34:47 -0700
Subject: [PATCH] [Core][Distributed] support cpu&device in broadcast tensor dict (#4660)

[Core][Distributed] support both cpu and device tensor in broadcast tensor dict (#4660)
---
 tests/distributed/test_comm_ops.py   |  7 +++-
 vllm/distributed/communication_op.py | 56 +++++++++++++++++-----------
 2 files changed, 41 insertions(+), 22 deletions(-)

diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py
index aa9e0537..9a7a1f07 100644
--- a/tests/distributed/test_comm_ops.py
+++ b/tests/distributed/test_comm_ops.py
@@ -77,14 +77,18 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
     init_test_distributed_environment(1, tensor_parallel_size, rank,
                                       distributed_init_port)
     test_dict = {
+        # device tensor
         "a": torch.arange(8, dtype=torch.float32, device="cuda"),
-        "b": torch.arange(16, dtype=torch.int8, device="cuda"),
+        # CPU tensor
+        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
         "c": "test",
         "d": [1, 2, 3],
         "e": {
             "a": 1,
             "b": 2
         },
+        # empty tensor
+        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
     }
 
     if rank == 0:
@@ -97,6 +101,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
     assert recv_dict["c"] == test_dict["c"]
     assert recv_dict["d"] == test_dict["d"]
     assert recv_dict["e"] == test_dict["e"]
+    assert torch.allclose(recv_dict["f"], test_dict["f"])
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index 817bd6d8..80d03129 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -137,7 +137,7 @@ def broadcast_object_list(obj_list: List[Any],
     return obj_list
 
 
-TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
+TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
 
 
 def _split_tensor_dict(
@@ -152,15 +152,13 @@ def _split_tensor_dict(
     tensor_list = []
     for key, value in tensor_dict.items():
         if isinstance(value, torch.Tensor):
-            # Note(youkaichao): currently this only supports broadcasting
-            # tensors on cuda. In the future, we can add device as a field in
-            # TensorMetadata to support broadcasting tensors on different
-            # devices.
-            assert value.is_cuda, (
-                f"Tensor {key}: {value} is not on cuda. Currently we only "
-                f"support broadcasting tensors on cuda.")
-            metadata_list.append((key, TensorMetadata(value.dtype,
-                                                      value.size())))
+            # Note: we cannot use `value.device` here,
+            # because it contains not only the device type but also the device
+            # index (e.g. "cuda:0"). We only need the device type;
+            # the receiving side will set the device index.
+            device = "cpu" if value.is_cpu else "cuda"
+            metadata_list.append(
+                (key, TensorMetadata(device, value.dtype, value.size())))
             tensor_list.append(value)
         else:
             metadata_list.append((key, value))
@@ -206,11 +204,19 @@ def broadcast_tensor_dict(
             if tensor.numel() == 0:
                 # Skip broadcasting empty tensors.
                 continue
-            async_handles.append(
-                torch.distributed.broadcast(tensor,
-                                            src=src,
-                                            group=group,
-                                            async_op=True))
+            if tensor.is_cpu:
+                # use metadata_group for CPU tensors
+                handle = torch.distributed.broadcast(tensor,
+                                                     src=src,
+                                                     group=metadata_group,
+                                                     async_op=True)
+            else:
+                # use group for GPU tensors
+                handle = torch.distributed.broadcast(tensor,
+                                                     src=src,
+                                                     group=group,
+                                                     async_op=True)
+            async_handles.append(handle)
         for async_handle in async_handles:
             async_handle.wait()
 
@@ -226,16 +232,24 @@ def broadcast_tensor_dict(
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
                                      dtype=value.dtype,
-                                     device="cuda")
+                                     device=value.device)
                 if tensor.numel() == 0:
                     # Skip broadcasting empty tensors.
                     tensor_dict[key] = tensor
                     continue
-                async_handle = torch.distributed.broadcast(tensor,
-                                                           src=src,
-                                                           async_op=True,
-                                                           group=group)
-                async_handles.append(async_handle)
+                if tensor.is_cpu:
+                    # use metadata_group for CPU tensors
+                    handle = torch.distributed.broadcast(tensor,
+                                                         src=src,
+                                                         group=metadata_group,
+                                                         async_op=True)
+                else:
+                    # use group for GPU tensors
+                    handle = torch.distributed.broadcast(tensor,
+                                                         src=src,
+                                                         group=group,
+                                                         async_op=True)
+                async_handles.append(handle)
                 tensor_dict[key] = tensor
             else:
                 tensor_dict[key] = value
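
Note (illustrative sketch, not part of the patch): the snippet below shows how a caller might exercise the updated broadcast_tensor_dict with a mix of CPU and GPU tensors, mirroring the test worker above. It assumes the distributed environment is already initialized (as init_test_distributed_environment does in tests/distributed/test_comm_ops.py), that broadcast_tensor_dict is importable from vllm.distributed.communication_op (the module changed here), and that the sender passes the dict while receivers pass only src, which is the convention used in the test; exchange_mixed_tensors is a hypothetical helper name.

import torch

from vllm.distributed.communication_op import broadcast_tensor_dict


def exchange_mixed_tensors(rank: int) -> None:
    # Hypothetical helper for illustration only; assumes the process group
    # setup (device group + CPU-capable metadata group) is already done.
    if rank == 0:
        payload = {
            # GPU tensor: broadcast over the device group.
            "a": torch.arange(8, dtype=torch.float32, device="cuda"),
            # CPU tensor: broadcast over the metadata group.
            "b": torch.arange(16, dtype=torch.int8, device="cpu"),
            # Non-tensor values travel with the metadata.
            "c": [1, 2, 3],
        }
        broadcast_tensor_dict(payload, src=0)
    else:
        recv = broadcast_tensor_dict(src=0)
        # Each tensor is reconstructed on the device type recorded in its
        # TensorMetadata, so "a" arrives on GPU and "b" stays on CPU.
        assert recv["a"].is_cuda and recv["b"].is_cpu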