[core][distributed] fix custom allreduce in pytorch 2.5 (#9815)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao, 2024-10-29 17:06:24 -07:00
Parent: bc73e9821c
Commit: 1ab6f6b4ad


@@ -191,8 +191,20 @@ class CustomAllreduce:
     def _get_ipc_meta(self, inp: torch.Tensor):
         data = inp.untyped_storage()._share_cuda_()
         handle = data[1]
+        # https://github.com/pytorch/pytorch/pull/130890 changes
+        # the binary format of the ipc handle
+        # it starts from pytorch 2.5
+        if len(handle) > 64:
+            assert len(handle) == 66
+            # only support SHAREABLE_HANDLE_VERSION = 1
+            assert int(handle[0]) == 1
+            # only support SHAREABLE_CUDA_MALLOC = 'c'
+            assert handle[1] == ord("c")
+            handle = handle[2:]
+            # TODO: support expandable segment
         shard_data = (
-            data[1],  # ipc handle to base ptr
+            handle,  # ipc handle to base ptr
             data[3],  # offset of base ptr
         )
         return self._gather_ipc_meta(shard_data)
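
Background for the fix: starting with PyTorch 2.5 (pytorch/pytorch#130890), the handle returned by _share_cuda_() carries a two-byte header ahead of the 64-byte cudaIpcMemHandle_t, namely a version byte (SHAREABLE_HANDLE_VERSION = 1) followed by a type byte (ord("c"), SHAREABLE_CUDA_MALLOC). A minimal standalone sketch of the same normalization is below; the helper name normalize_ipc_handle is hypothetical and not part of vLLM.

    import torch

    CUDA_IPC_HANDLE_SIZE = 64  # sizeof(cudaIpcMemHandle_t)

    def normalize_ipc_handle(handle: bytes) -> bytes:
        # Hypothetical helper: strip the 2-byte header that
        # PyTorch >= 2.5 prepends to the CUDA IPC handle, so that
        # callers always receive the raw 64-byte cudaIpcMemHandle_t.
        if len(handle) == CUDA_IPC_HANDLE_SIZE:
            return handle  # pre-2.5 format: raw handle, no header
        assert len(handle) == CUDA_IPC_HANDLE_SIZE + 2
        assert int(handle[0]) == 1    # SHAREABLE_HANDLE_VERSION = 1
        assert handle[1] == ord("c")  # SHAREABLE_CUDA_MALLOC = 'c'
        return handle[2:]

    # Usage: _share_cuda_() returns a tuple; per the diff above,
    # element 1 is the IPC handle and element 3 is the offset of
    # the base pointer within its allocation.
    inp = torch.zeros(16, device="cuda")
    data = inp.untyped_storage()._share_cuda_()
    handle, offset = normalize_ipc_handle(data[1]), data[3]

The assertions mirror the commit: any handle that is not version 1 with a cudaMalloc-backed allocation (for example, one from an expandable segment, which the TODO in the diff leaves unsupported) is rejected rather than silently mis-parsed.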