[Core][Distributed] support both cpu and device tensor in broadcast tensor dict (#4660)

parent 8344f7742b
commit cc466a3290
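The diff below touches two areas: the distributed broadcast test (a CPU tensor and an empty tensor join the broadcast dictionary) and the broadcast implementation itself (`TensorMetadata` gains a `device` field, and each tensor is broadcast over the process group matching its device). A hedged usage sketch of the resulting behavior; the import path and exact call signature are assumptions based on the hunks, not confirmed by this page:

```python
import torch
import torch.distributed

# Assumed import path; at the time of this PR the function lived in a
# communication-op module under vllm.distributed.
from vllm.distributed.communication_op import broadcast_tensor_dict

if torch.distributed.get_rank() == 0:
    # Sender: the dict may now mix device and CPU tensors.
    broadcast_tensor_dict(
        {
            "a": torch.arange(8, dtype=torch.float32, device="cuda"),
            "b": torch.arange(16, dtype=torch.int8, device="cpu"),
        },
        src=0)
else:
    # Receiver: each tensor is reconstructed on the device type recorded
    # in its TensorMetadata entry.
    recv_dict = broadcast_tensor_dict(src=0)
```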
@@ -77,14 +77,18 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
     init_test_distributed_environment(1, tensor_parallel_size, rank,
                                       distributed_init_port)
     test_dict = {
+        # device tensor
         "a": torch.arange(8, dtype=torch.float32, device="cuda"),
-        "b": torch.arange(16, dtype=torch.int8, device="cuda"),
+        # CPU tensor
+        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
         "c": "test",
         "d": [1, 2, 3],
         "e": {
             "a": 1,
             "b": 2
         },
+        # empty tensor
+        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
     }

     if rank == 0:
@@ -97,6 +101,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
         assert recv_dict["c"] == test_dict["c"]
         assert recv_dict["d"] == test_dict["d"]
         assert recv_dict["e"] == test_dict["e"]
+        assert torch.allclose(recv_dict["f"], test_dict["f"])


 @pytest.mark.skipif(torch.cuda.device_count() < 2,
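One detail worth noting in the new assertion: `recv_dict["f"]` is never actually transmitted, since empty tensors are skipped on both the send and receive paths (see the hunks below); the receiver reconstructs it from metadata alone, and `torch.allclose` on two empty tensors of the same shape is vacuously true. A standalone check:

```python
import torch

a = torch.tensor([], dtype=torch.float32)
b = torch.empty((0,), dtype=torch.float32)
assert a.shape == b.shape == (0,)
assert torch.allclose(a, b)  # no elements to compare, so this holds
```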
@@ -137,7 +137,7 @@ def broadcast_object_list(obj_list: List[Any],
     return obj_list


-TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
+TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


 def _split_tensor_dict(
@@ -152,15 +152,13 @@ def _split_tensor_dict(
     tensor_list = []
     for key, value in tensor_dict.items():
         if isinstance(value, torch.Tensor):
-            # Note(youkaichao): currently this only supports broadcasting
-            # tensors on cuda. In the future, we can add device as a field in
-            # TensorMetadata to support broadcasting tensors on different
-            # devices.
-            assert value.is_cuda, (
-                f"Tensor {key}: {value} is not on cuda. Currently we only "
-                f"support broadcasting tensors on cuda.")
-            metadata_list.append((key, TensorMetadata(value.dtype,
-                                                      value.size())))
+            # Note: we cannot use `value.device` here, because it contains
+            # not only the device type but also the device index (e.g.
+            # "cuda:0"). We only need the device type; the receiving side
+            # will set the device index.
+            device = "cpu" if value.is_cpu else "cuda"
+            metadata_list.append(
+                (key, TensorMetadata(device, value.dtype, value.size())))
             tensor_list.append(value)
         else:
             metadata_list.append((key, value))
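For readers skimming the diff, here is a self-contained sketch of the splitting step as it stands after this change; `split_tensor_dict_sketch` is a hypothetical standalone name, and only the pieces visible in this hunk are assumed:

```python
from collections import namedtuple
from typing import Any, Dict, List, Tuple

import torch

TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def split_tensor_dict_sketch(
        tensor_dict: Dict[str, Any]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Separate tensors (sent via broadcast) from plain metadata."""
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list: List[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Keep only the device *type*; the receiver picks the index.
            device = "cpu" if value.is_cpu else "cuda"
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list
```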
@@ -206,11 +204,19 @@ def broadcast_tensor_dict(
             if tensor.numel() == 0:
                 # Skip broadcasting empty tensors.
                 continue
-            async_handles.append(
-                torch.distributed.broadcast(tensor,
-                                            src=src,
-                                            group=group,
-                                            async_op=True))
+            if tensor.is_cpu:
+                # use metadata_group for CPU tensors
+                handle = torch.distributed.broadcast(tensor,
+                                                     src=src,
+                                                     group=metadata_group,
+                                                     async_op=True)
+            else:
+                # use group for GPU tensors
+                handle = torch.distributed.broadcast(tensor,
+                                                     src=src,
+                                                     group=group,
+                                                     async_op=True)
+            async_handles.append(handle)
         for async_handle in async_handles:
             async_handle.wait()

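The `tensor.is_cpu` dispatch is what makes mixed dictionaries work: NCCL-backed groups can only move CUDA tensors, so CPU tensors have to travel over the CPU-capable `metadata_group` (presumably gloo-backed). A hypothetical helper distilling the dispatch in this hunk:

```python
import torch
import torch.distributed as dist


def broadcast_on_matching_group(tensor: torch.Tensor, src: int,
                                device_group: dist.ProcessGroup,
                                cpu_group: dist.ProcessGroup):
    """Hypothetical helper: pick the group that matches the tensor's device."""
    group = cpu_group if tensor.is_cpu else device_group
    return dist.broadcast(tensor, src=src, group=group, async_op=True)
```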
@@ -226,16 +232,24 @@ def broadcast_tensor_dict(
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
                                      dtype=value.dtype,
-                                     device="cuda")
+                                     device=value.device)
                 if tensor.numel() == 0:
                     # Skip broadcasting empty tensors.
                     tensor_dict[key] = tensor
                     continue
-                async_handle = torch.distributed.broadcast(tensor,
-                                                           src=src,
-                                                           async_op=True,
-                                                           group=group)
-                async_handles.append(async_handle)
+                if tensor.is_cpu:
+                    # use metadata_group for CPU tensors
+                    handle = torch.distributed.broadcast(tensor,
+                                                         src=src,
+                                                         group=metadata_group,
+                                                         async_op=True)
+                else:
+                    # use group for GPU tensors
+                    handle = torch.distributed.broadcast(tensor,
+                                                         src=src,
+                                                         group=group,
+                                                         async_op=True)
+                async_handles.append(handle)
                 tensor_dict[key] = tensor
             else:
                 tensor_dict[key] = value
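Note that handles from both groups land in a single `async_handles` list and are waited on together, so CPU and GPU broadcasts can be in flight concurrently rather than serialized. A toy sketch of that issue-then-wait pattern, assuming an already-initialized default process group:

```python
import torch
import torch.distributed as dist

# Issue several non-blocking broadcasts, then wait once at the end.
tensors = [torch.ones(4), torch.zeros(8)]
handles = [dist.broadcast(t, src=0, async_op=True) for t in tensors]
for handle in handles:
    handle.wait()  # all broadcasts are complete past this point
```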