[VLM][Bugfix] Make sure that multi_modal_kwargs is broadcast properly (#5880)
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
This commit is contained in:
parent
6eabc6cb0e
commit
d12af207d2
@ -27,7 +27,9 @@ steps:
|
|||||||
|
|
||||||
- label: Core Test
|
- label: Core Test
|
||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
command: pytest -v -s core
|
commands:
|
||||||
|
- pytest -v -s core
|
||||||
|
- pytest -v -s distributed/test_parallel_state.py
|
||||||
|
|
||||||
- label: Distributed Comm Ops Test
|
- label: Distributed Comm Ops Test
|
||||||
#mirror_hardwares: [amd]
|
#mirror_hardwares: [amd]
|
||||||
|
|||||||
49
tests/distributed/test_parallel_state.py
Normal file
49
tests/distributed/test_parallel_state.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.distributed.parallel_state import (_split_tensor_dict,
|
||||||
|
_update_nested_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_tensor_dict():
    """_split_tensor_dict flattens nested dicts ("%"-joined keys) and
    collects the tensors, in traversal order, into a separate list."""
    sample = {
        "key_a": "a",
        "key_b": torch.arange(8, dtype=torch.float32),
        "key_c": {
            "key_1": torch.arange(5, dtype=torch.float32),
            "key_2": torch.tensor([], dtype=torch.float32),
            "key_3": 123,
        },
        "key_d": {},
    }
    metadata, tensors = _split_tensor_dict(sample)
    # One metadata entry per leaf value, plus one for the empty dict.
    assert len(metadata) == 6
    expected_tensors = [
        sample["key_b"],
        sample["key_c"]["key_1"],
        sample["key_c"]["key_2"],
    ]
    for idx, expected in enumerate(expected_tensors):
        assert torch.allclose(tensors[idx], expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_nested_dict():
    """Flattened "%"-joined keys are rebuilt into a nested dictionary."""
    pairs = [
        ("key1%key2%key3", "value1"),
        ("key1%key2%key4", "value2"),
        ("key1%key5", "value3"),
        ("key6%key7", "value4"),
        ("key8", "value5"),
    ]
    rebuilt: Dict[str, Any] = {}
    # Apply every flattened key-value pair to the accumulator.
    for flat_key, val in pairs:
        _update_nested_dict(rebuilt, flat_key, val)
    expected = {
        "key1": {
            "key2": {
                "key3": "value1",
                "key4": "value2",
            },
            "key5": "value3",
        },
        "key6": {
            "key7": "value4",
        },
        "key8": "value5",
    }
    assert rebuilt == expected
|
||||||
@ -45,14 +45,17 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
|||||||
|
|
||||||
|
|
||||||
def _split_tensor_dict(
|
def _split_tensor_dict(
|
||||||
tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
|
tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
|
||||||
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
|
prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
|
||||||
"""Split the tensor dictionary into two parts:
|
"""Split the tensor dictionary into two parts:
|
||||||
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
|
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
|
||||||
by its metadata.
|
by its metadata.
|
||||||
2. A list of tensors.
|
2. A list of tensors.
|
||||||
|
|
||||||
|
If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
|
||||||
|
metadata will be "key1%key2".
|
||||||
"""
|
"""
|
||||||
metadata_list = []
|
metadata_list: List[Tuple[str, Any]] = []
|
||||||
tensor_list = []
|
tensor_list = []
|
||||||
for key, value in tensor_dict.items():
|
for key, value in tensor_dict.items():
|
||||||
if isinstance(value, torch.Tensor):
|
if isinstance(value, torch.Tensor):
|
||||||
@ -62,13 +65,31 @@ def _split_tensor_dict(
|
|||||||
# receiving side will set the device index.
|
# receiving side will set the device index.
|
||||||
device = value.device.type
|
device = value.device.type
|
||||||
metadata_list.append(
|
metadata_list.append(
|
||||||
(key, TensorMetadata(device, value.dtype, value.size())))
|
(prefix + key, TensorMetadata(device, value.dtype,
|
||||||
|
value.size())))
|
||||||
tensor_list.append(value)
|
tensor_list.append(value)
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
if len(value) == 0:
|
||||||
|
metadata_list.append((prefix + key, value))
|
||||||
|
inner_metadata_list, inner_tensor_list = _split_tensor_dict(
|
||||||
|
value, prefix + key + "%")
|
||||||
|
metadata_list.extend(inner_metadata_list)
|
||||||
|
tensor_list.extend(inner_tensor_list)
|
||||||
else:
|
else:
|
||||||
metadata_list.append((key, value))
|
metadata_list.append((prefix + key, value))
|
||||||
return metadata_list, tensor_list
|
return metadata_list, tensor_list
|
||||||
|
|
||||||
|
|
||||||
|
def _update_nested_dict(nested_dict, flattened_key, value):
|
||||||
|
key_splits = flattened_key.split("%")
|
||||||
|
cur_dict = nested_dict
|
||||||
|
for k in key_splits[:-1]:
|
||||||
|
if k not in cur_dict:
|
||||||
|
cur_dict[k] = {}
|
||||||
|
cur_dict = cur_dict[k]
|
||||||
|
cur_dict[key_splits[-1]] = value
|
||||||
|
|
||||||
|
|
||||||
class GroupCoordinator:
|
class GroupCoordinator:
|
||||||
"""
|
"""
|
||||||
PyTorch ProcessGroup wrapper for a group of processes.
|
PyTorch ProcessGroup wrapper for a group of processes.
|
||||||
@ -512,7 +533,7 @@ class GroupCoordinator:
|
|||||||
device=value.device)
|
device=value.device)
|
||||||
if tensor.numel() == 0:
|
if tensor.numel() == 0:
|
||||||
# Skip broadcasting empty tensors.
|
# Skip broadcasting empty tensors.
|
||||||
tensor_dict[key] = tensor
|
_update_nested_dict(tensor_dict, key, tensor)
|
||||||
continue
|
continue
|
||||||
if tensor.is_cpu:
|
if tensor.is_cpu:
|
||||||
# use metadata_group for CPU tensors
|
# use metadata_group for CPU tensors
|
||||||
@ -528,9 +549,9 @@ class GroupCoordinator:
|
|||||||
group=group,
|
group=group,
|
||||||
async_op=True)
|
async_op=True)
|
||||||
async_handles.append(handle)
|
async_handles.append(handle)
|
||||||
tensor_dict[key] = tensor
|
_update_nested_dict(tensor_dict, key, tensor)
|
||||||
else:
|
else:
|
||||||
tensor_dict[key] = value
|
_update_nested_dict(tensor_dict, key, value)
|
||||||
for async_handle in async_handles:
|
for async_handle in async_handles:
|
||||||
async_handle.wait()
|
async_handle.wait()
|
||||||
return tensor_dict
|
return tensor_dict
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user