[Core][Distributed] support cpu&device in broadcast tensor dict (#4660)

[Core][Distributed] support both cpu and device tensor in broadcast tensor dict (#4660)
This commit is contained in:
youkaichao 2024-05-07 19:34:47 -07:00 committed by GitHub
parent 8344f7742b
commit cc466a3290
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 22 deletions

View File

@ -77,14 +77,18 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
test_dict = {
# device tensor
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
"b": torch.arange(16, dtype=torch.int8, device="cuda"),
# CPU tensor
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
"c": "test",
"d": [1, 2, 3],
"e": {
"a": 1,
"b": 2
},
# empty tensor
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
}
if rank == 0:
@ -97,6 +101,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"])
@pytest.mark.skipif(torch.cuda.device_count() < 2,

View File

@ -137,7 +137,7 @@ def broadcast_object_list(obj_list: List[Any],
return obj_list
TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict(
@ -152,15 +152,13 @@ def _split_tensor_dict(
tensor_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
# Note(youkaichao): currently this only supports broadcasting
# tensors on cuda. In the future, we can add device as a field in
# TensorMetadata to support broadcasting tensors on different
# devices.
assert value.is_cuda, (
f"Tensor {key}: {value} is not on cuda. Currently we only "
f"support broadcasting tensors on cuda.")
metadata_list.append((key, TensorMetadata(value.dtype,
value.size())))
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device = "cpu" if value.is_cpu else "cuda"
metadata_list.append(
(key, TensorMetadata(device, value.dtype, value.size())))
tensor_list.append(value)
else:
metadata_list.append((key, value))
@ -206,11 +204,19 @@ def broadcast_tensor_dict(
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
continue
async_handles.append(
torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True))
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True)
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()
@ -226,16 +232,24 @@ def broadcast_tensor_dict(
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device="cuda")
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
async_handle = torch.distributed.broadcast(tensor,
src=src,
async_op=True,
group=group)
async_handles.append(async_handle)
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True)
async_handles.append(handle)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value