[core][distributed] fix custom allreduce in pytorch 2.5 (#9815)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
bc73e9821c
commit
1ab6f6b4ad
@ -191,8 +191,20 @@ class CustomAllreduce:
|
||||
|
||||
def _get_ipc_meta(self, inp: torch.Tensor):
    """Share *inp*'s CUDA storage over IPC and all-gather the metadata.

    Obtains the CUDA IPC sharing info for the tensor's underlying
    storage and forwards a ``(handle, offset)`` pair to
    ``self._gather_ipc_meta`` so every rank learns the others' handles.

    Args:
        inp: a CUDA tensor whose storage will be shared across ranks.

    Returns:
        Whatever ``self._gather_ipc_meta`` returns for the gathered
        ``(ipc_handle, offset_of_base_ptr)`` shard data.
    """
    data = inp.untyped_storage()._share_cuda_()
    handle = data[1]
    # https://github.com/pytorch/pytorch/pull/130890 changes
    # the binary format of the ipc handle
    # it starts from pytorch 2.5
    if len(handle) > 64:
        assert len(handle) == 66
        # only support SHAREABLE_HANDLE_VERSION = 1
        assert int(handle[0]) == 1
        # only support SHAREABLE_CUDA_MALLOC = 'c'
        assert handle[1] == ord("c")
        # strip the 2-byte (version, type) prefix so downstream code
        # sees the raw 64-byte cudaIpcMemHandle_t, as before 2.5
        handle = handle[2:]
        # TODO: support expandable segment
    shard_data = (
        handle,  # ipc handle to base ptr
        data[3],  # offset of base ptr
    )
    return self._gather_ipc_meta(shard_data)
Loading…
Reference in New Issue
Block a user