[Core][Distributed] support both cpu and device tensor in broadcast tensor dict (#4660)
parent 8344f7742b
commit cc466a3290
@@ -77,14 +77,18 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
     init_test_distributed_environment(1, tensor_parallel_size, rank,
                                       distributed_init_port)
     test_dict = {
+        # device tensor
         "a": torch.arange(8, dtype=torch.float32, device="cuda"),
-        "b": torch.arange(16, dtype=torch.int8, device="cuda"),
+        # CPU tensor
+        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
         "c": "test",
         "d": [1, 2, 3],
         "e": {
             "a": 1,
             "b": 2
         },
+        # empty tensor
+        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
     }

     if rank == 0:
@@ -97,6 +101,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
         assert recv_dict["c"] == test_dict["c"]
         assert recv_dict["d"] == test_dict["d"]
         assert recv_dict["e"] == test_dict["e"]
+        assert torch.allclose(recv_dict["f"], test_dict["f"])


 @pytest.mark.skipif(torch.cuda.device_count() < 2,
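For readers without a multi-GPU machine, here is a single-process sanity sketch of the checks this test performs, assuming the dict simply round-trips unchanged. `recv_dict` below stands in for what `broadcast_tensor_dict` would return on a non-source rank, and everything is kept on CPU so the sketch runs anywhere; it is illustrative only, not the test itself.

import torch

# Stand-in for the broadcast round-trip: clone tensors, copy everything else.
test_dict = {
    "a": torch.arange(8, dtype=torch.float32),               # device tensor in the real test
    "b": torch.arange(16, dtype=torch.int8, device="cpu"),   # CPU tensor
    "c": "test",
    "d": [1, 2, 3],
    "e": {"a": 1, "b": 2},
    "f": torch.tensor([], dtype=torch.float32),              # empty tensor
}
recv_dict = {k: (v.clone() if isinstance(v, torch.Tensor) else v)
             for k, v in test_dict.items()}

assert torch.allclose(recv_dict["a"], test_dict["a"])
assert torch.equal(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"])  # empty tensors compare equal
print("all checks passed")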
@@ -137,7 +137,7 @@ def broadcast_object_list(obj_list: List[Any],
     return obj_list


-TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
+TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


 def _split_tensor_dict(
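Since the metadata list is shipped as plain Python objects (see the `broadcast_object_list` helper above), the extended tuple has to stay picklable. A tiny sketch with illustrative values, showing the new three-field tuple and that it survives a pickle round-trip:

import pickle
from collections import namedtuple

import torch

# The extended metadata carries the device *type* as a plain string.
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

meta = TensorMetadata("cpu", torch.int8, torch.Size([16]))
print(meta)  # TensorMetadata(device='cpu', dtype=torch.int8, size=torch.Size([16]))
assert pickle.loads(pickle.dumps(meta)) == meta  # safe to send as a Python object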
@@ -152,15 +152,13 @@ def _split_tensor_dict(
     tensor_list = []
     for key, value in tensor_dict.items():
         if isinstance(value, torch.Tensor):
-            # Note(youkaichao): currently this only supports broadcasting
-            # tensors on cuda. In the future, we can add device as a field in
-            # TensorMetadata to support broadcasting tensors on different
-            # devices.
-            assert value.is_cuda, (
-                f"Tensor {key}: {value} is not on cuda. Currently we only "
-                f"support broadcasting tensors on cuda.")
-            metadata_list.append((key, TensorMetadata(value.dtype,
-                                                      value.size())))
+            # Note: we cannot use `value.device` here, because it contains
+            # not only the device type but also the device index (e.g.
+            # "cuda:0"). We only need the device type; the receiving side
+            # will set the device index.
+            device = "cpu" if value.is_cpu else "cuda"
+            metadata_list.append(
+                (key, TensorMetadata(device, value.dtype, value.size())))
             tensor_list.append(value)
         else:
             metadata_list.append((key, value))
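To see what the updated helper produces, here is a simplified single-process sketch of the same splitting step. The function name and the sample dict are illustrative; only the core logic mirrors the diff above.

from collections import namedtuple
from typing import Any, Dict, List, Tuple

import torch

TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def split_tensor_dict(
    tensor_dict: Dict[Any, Any]
) -> Tuple[List[Tuple[Any, Any]], List[torch.Tensor]]:
    """Separate picklable metadata from raw tensor payloads (sketch)."""
    metadata_list: List[Tuple[Any, Any]] = []
    tensor_list: List[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Keep only the device type; the receiver chooses its own index.
            device = "cpu" if value.is_cpu else "cuda"
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list


metadata, tensors = split_tensor_dict({
    "b": torch.arange(16, dtype=torch.int8, device="cpu"),
    "c": "test",
})
print(metadata)      # [('b', TensorMetadata(device='cpu', ...)), ('c', 'test')]
print(len(tensors))  # 1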
@@ -206,11 +204,19 @@ def broadcast_tensor_dict(
             if tensor.numel() == 0:
                 # Skip broadcasting empty tensors.
                 continue
-            async_handles.append(
-                torch.distributed.broadcast(tensor,
-                                            src=src,
-                                            group=group,
-                                            async_op=True))
+            if tensor.is_cpu:
+                # use metadata_group for CPU tensors
+                handle = torch.distributed.broadcast(tensor,
+                                                     src=src,
+                                                     group=metadata_group,
+                                                     async_op=True)
+            else:
+                # use group for GPU tensors
+                handle = torch.distributed.broadcast(tensor,
+                                                     src=src,
+                                                     group=group,
+                                                     async_op=True)
+            async_handles.append(handle)
         for async_handle in async_handles:
             async_handle.wait()

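The branch above relies on two process groups: the change assumes `group` is the device group (e.g. NCCL) and `metadata_group` a CPU-capable group (e.g. gloo), so CPU tensors never hit the device backend. Below is a minimal single-process sketch of that selection pattern, assuming no process group has been initialized yet; the port and group names are illustrative, and the sketch falls back to gloo when NCCL or CUDA is unavailable so it runs anywhere.

import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

metadata_group = dist.group.WORLD  # gloo: fine for CPU tensors
# On a CUDA machine with NCCL available, a separate device group would be
# created; otherwise reuse the gloo group so the sketch still runs.
device_group = (dist.new_group(backend="nccl")
                if torch.cuda.is_available() and dist.is_nccl_available()
                else metadata_group)

handles = []
for tensor in (torch.arange(8, dtype=torch.float32),):  # CPU tensors only here
    chosen_group = metadata_group if tensor.is_cpu else device_group
    handles.append(dist.broadcast(tensor, src=0, group=chosen_group,
                                  async_op=True))
for handle in handles:
    handle.wait()
dist.destroy_process_group()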
@@ -226,16 +232,24 @@ def broadcast_tensor_dict(
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
                                      dtype=value.dtype,
-                                     device="cuda")
+                                     device=value.device)
                 if tensor.numel() == 0:
                     # Skip broadcasting empty tensors.
                     tensor_dict[key] = tensor
                     continue
-                async_handle = torch.distributed.broadcast(tensor,
-                                                           src=src,
-                                                           async_op=True,
-                                                           group=group)
-                async_handles.append(async_handle)
+                if tensor.is_cpu:
+                    # use metadata_group for CPU tensors
+                    handle = torch.distributed.broadcast(tensor,
+                                                         src=src,
+                                                         group=metadata_group,
+                                                         async_op=True)
+                else:
+                    # use group for GPU tensors
+                    handle = torch.distributed.broadcast(tensor,
+                                                         src=src,
+                                                         group=group,
+                                                         async_op=True)
+                async_handles.append(handle)
                 tensor_dict[key] = tensor
             else:
                 tensor_dict[key] = value
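On the receiving side the tensor is now allocated on the device type carried in the metadata, and the matching broadcast then fills it in. A small allocation-only sketch of that step, with illustrative metadata and no process group involved:

from collections import namedtuple

import torch

TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

# Metadata as a non-source rank might receive it (illustrative values).
metadata_list = [
    ("a", TensorMetadata("cpu", torch.float32, torch.Size([8]))),
    ("c", "test"),
]

tensor_dict = {}
for key, value in metadata_list:
    if isinstance(value, TensorMetadata):
        # Allocate on the recorded device type; the broadcast would fill it.
        tensor_dict[key] = torch.empty(value.size,
                                       dtype=value.dtype,
                                       device=value.device)
    else:
        tensor_dict[key] = value

print({k: (v.device, v.dtype) if isinstance(v, torch.Tensor) else v
       for k, v in tensor_dict.items()})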