[VLM][BugFix] Make sure that multi_modal_kwargs can broadcast properly with ring buffer. (#5905)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Author: xwjiang2010 — committed 2024-06-28 00:29:13 -07:00 (via GitHub)
parent f136da15e1
commit 74d55c065b
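For context: multi_modal_kwargs produced by VLM inputs can contain nested dictionaries of tensors, while the broadcast/send/recv helpers previously assumed a single flat level of keys. A purely illustrative example of the kind of nested payload this commit makes broadcastable (the field names here are hypothetical, not taken from the diff):

```python
import torch

# Hypothetical nested multi_modal_kwargs payload; the old flat
# Dict[Any, ...] handling could not rebuild this shape on receiving ranks.
multi_modal_kwargs = {
    "image": {
        "pixel_values": torch.zeros(1, 3, 336, 336),
        "image_sizes": torch.tensor([[336, 336]]),
    },
}
```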

vllm/distributed/parallel_state.py

@@ -45,7 +45,7 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
 def _split_tensor_dict(
-        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
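The hunk above only shows the updated signature; the `prefix` argument is what lets `_split_tensor_dict` recurse into nested values. A minimal sketch of that flattening scheme, assuming a reserved separator character (the helper's body and the separator choice are not part of this excerpt):

```python
from collections import namedtuple
from typing import Any, Dict, List, Tuple, Union

import torch

TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def split_tensor_dict_sketch(
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
        prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Flatten a possibly nested dict into (key, metadata) pairs plus the
    tensors to transmit, joining nested keys with an assumed "%" separator."""
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list: List[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Replace the tensor with its metadata; the tensor itself is
            # shipped separately via tensor_list.
            metadata_list.append((prefix + key,
                                  TensorMetadata(value.device, value.dtype,
                                                 value.size())))
            tensor_list.append(value)
        elif isinstance(value, dict):
            # Nested dict (e.g. multi_modal_kwargs): recurse, extending
            # the prefix so the final keys stay unique.
            sub_metadata, sub_tensors = split_tensor_dict_sketch(
                value, prefix + key + "%")
            metadata_list.extend(sub_metadata)
            tensor_list.extend(sub_tensors)
        else:
            metadata_list.append((prefix + key, value))
    return metadata_list, tensor_list
```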
@@ -473,11 +473,11 @@ class GroupCoordinator:
     def broadcast_tensor_dict(
         self,
-        tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
         src: int = 0,
         group: Optional[ProcessGroup] = None,
         metadata_group: Optional[ProcessGroup] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Broadcast the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
@@ -558,9 +558,9 @@ class GroupCoordinator:
     def send_tensor_dict(
         self,
-        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
@@ -599,7 +599,7 @@ class GroupCoordinator:
     def recv_tensor_dict(
         self,
         src: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
@@ -615,7 +615,7 @@ class GroupCoordinator:
         assert src < self.world_size, f"Invalid src rank ({src})"
         recv_metadata_list = self.recv_object(src=src)
-        tensor_dict = {}
+        tensor_dict: Dict[str, Any] = {}
         for key, value in recv_metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
@@ -623,7 +623,7 @@ class GroupCoordinator:
                                      device=value.device)
                 if tensor.numel() == 0:
                     # Skip broadcasting empty tensors.
-                    tensor_dict[key] = tensor
+                    _update_nested_dict(tensor_dict, key, tensor)
                     continue
                 if tensor.is_cpu:
                     # use metadata_group for CPU tensors
@@ -633,9 +633,9 @@ class GroupCoordinator:
                 else:
                     # use group for GPU tensors
                     torch.distributed.recv(tensor, src=src, group=group)
-                tensor_dict[key] = tensor
+                _update_nested_dict(tensor_dict, key, tensor)
             else:
-                tensor_dict[key] = value
+                _update_nested_dict(tensor_dict, key, value)
         return tensor_dict

     def barrier(self):
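The `_update_nested_dict` helper the new `+` lines call is defined elsewhere in this file and is not part of the excerpt. A plausible reconstruction under the same separator assumption as the sketch above, together with the round trip it enables:

```python
from typing import Any, Dict


def update_nested_dict_sketch(nested_dict: Dict[str, Any],
                              flattened_key: str, value: Any) -> None:
    """Rebuild nesting from a flattened key: walk/create an intermediate
    dict for each key segment, then store the value under the last one."""
    keys = flattened_key.split("%")  # assumed separator, see sketch above
    cur = nested_dict
    for k in keys[:-1]:
        cur = cur.setdefault(k, {})
    cur[keys[-1]] = value


# Round trip: a value flattened under "image%pixel_values" by the sender
# is restored to tensor_dict["image"]["pixel_values"] on the receiver,
# which is what lets nested multi_modal_kwargs survive the broadcast.
```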