[VLM][BugFix] Make sure that multi_modal_kwargs can broadcast properly with ring buffer. (#5905)
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent f136da15e1
commit 74d55c065b
@@ -45,7 +45,7 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
 
 
 def _split_tensor_dict(
-        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
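The body of `_split_tensor_dict` is not shown in this hunk, but the switch to `str` keys is what allows nested inputs such as `multi_modal_kwargs` to be flattened into prefixed string keys before broadcast. Below is a minimal sketch of that flattening idea; the recursion into sub-dicts, the "%" separator, and the placeholder metadata tuple are illustrative assumptions, not the code from this commit.

from typing import Any, Dict, List, Tuple, Union

import torch


def _split_tensor_dict_sketch(
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
        prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    # Sketch: walk the dict, recursing into nested dicts so that a tensor at
    # tensor_dict["multi_modal_kwargs"]["pixel_values"] is recorded under the
    # flattened key "multi_modal_kwargs%pixel_values" (separator assumed).
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list: List[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # The real code records TensorMetadata(device, dtype, size) here;
            # a plain tuple keeps this sketch self-contained.
            metadata_list.append(
                (prefix + key, (str(value.device), value.dtype, value.size())))
            tensor_list.append(value)
        elif isinstance(value, dict):
            inner_metadata, inner_tensors = _split_tensor_dict_sketch(
                value, prefix + key + "%")
            metadata_list += inner_metadata
            tensor_list += inner_tensors
        else:
            metadata_list.append((prefix + key, value))
    return metadata_list, tensor_list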
@@ -473,11 +473,11 @@ class GroupCoordinator:
 
     def broadcast_tensor_dict(
         self,
-        tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
         src: int = 0,
         group: Optional[ProcessGroup] = None,
         metadata_group: Optional[ProcessGroup] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Broadcast the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
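For context, this is the shape of tensor_dict the annotation change is about. The concrete keys, tensor shapes, and the `coordinator` object are illustrative assumptions, not part of this diff; only the `broadcast_tensor_dict(tensor_dict, src=...)` signature comes from the hunk above.

import torch

# Illustrative payload: top-level keys are now typed as str, and values may be
# tensors, plain Python objects, or nested dicts of tensors such as
# multi_modal_kwargs (the case this PR fixes).
tensor_dict = {
    "input_tokens": torch.ones(4, 16, dtype=torch.long),
    "num_seqs": 4,
    "multi_modal_kwargs": {
        "pixel_values": torch.zeros(1, 3, 224, 224),
        "image_sizes": torch.tensor([[224, 224]]),
    },
}
# On the source rank (local rank `src` within the group):
#     coordinator.broadcast_tensor_dict(tensor_dict, src=0)
# On every other rank:
#     received = coordinator.broadcast_tensor_dict(src=0)
# where `coordinator` is an already-initialized GroupCoordinator instance.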
@@ -558,9 +558,9 @@ class GroupCoordinator:
 
     def send_tensor_dict(
         self,
-        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
@@ -599,7 +599,7 @@ class GroupCoordinator:
     def recv_tensor_dict(
         self,
         src: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
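A hypothetical point-to-point round trip using the two methods above; `coordinator`, the tensor shapes, and the explicit `dst`/`src` ranks are assumptions for illustration. With this change, a nested dict of tensors arrives with its structure intact.

import torch

payload = {
    "hidden_states": torch.randn(4, 16, 1024),
    "multi_modal_kwargs": {"pixel_values": torch.zeros(1, 3, 224, 224)},
}
# Sending rank:
#     coordinator.send_tensor_dict(payload, dst=1)
# Receiving rank:
#     received = coordinator.recv_tensor_dict(src=0)
#     # received["multi_modal_kwargs"]["pixel_values"] is a tensor again.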
@@ -615,7 +615,7 @@ class GroupCoordinator:
         assert src < self.world_size, f"Invalid src rank ({src})"
 
         recv_metadata_list = self.recv_object(src=src)
-        tensor_dict = {}
+        tensor_dict: Dict[str, Any] = {}
         for key, value in recv_metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
@@ -623,7 +623,7 @@ class GroupCoordinator:
                                      device=value.device)
                 if tensor.numel() == 0:
                     # Skip broadcasting empty tensors.
-                    tensor_dict[key] = tensor
+                    _update_nested_dict(tensor_dict, key, tensor)
                     continue
                 if tensor.is_cpu:
                     # use metadata_group for CPU tensors
@@ -633,9 +633,9 @@ class GroupCoordinator:
                 else:
                     # use group for GPU tensors
                     torch.distributed.recv(tensor, src=src, group=group)
-                tensor_dict[key] = tensor
+                _update_nested_dict(tensor_dict, key, tensor)
             else:
-                tensor_dict[key] = value
+                _update_nested_dict(tensor_dict, key, value)
         return tensor_dict
 
     def barrier(self):
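`_update_nested_dict` is defined elsewhere in this commit and its body is not shown in the hunks above. A minimal sketch of what such a helper has to do, assuming the same "%" separator as the flattening sketch earlier:

from typing import Any, Dict


def _update_nested_dict_sketch(nested_dict: Dict[str, Any],
                               flattened_key: str, value: Any) -> None:
    # Sketch: split the flattened key back into its components and rebuild the
    # intermediate dicts on the receiving side, so that a value received under
    # "multi_modal_kwargs%pixel_values" lands at
    # nested_dict["multi_modal_kwargs"]["pixel_values"].
    key_parts = flattened_key.split("%")
    cur_dict = nested_dict
    for part in key_parts[:-1]:
        cur_dict = cur_dict.setdefault(part, {})
    cur_dict[key_parts[-1]] = value

For a plain top-level key with no separator this degenerates to `nested_dict[key] = value`, which is why the three assignment sites in the hunks above can be replaced uniformly without changing behavior for non-nested entries.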