[core][distributed] fix custom allreduce in pytorch 2.5 (#9815)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent bc73e9821c
commit 1ab6f6b4ad
@@ -191,8 +191,20 @@ class CustomAllreduce:
 
     def _get_ipc_meta(self, inp: torch.Tensor):
         data = inp.untyped_storage()._share_cuda_()
+        handle = data[1]
+        # https://github.com/pytorch/pytorch/pull/130890 changes
+        # the binary format of the ipc handle
+        # it starts from pytorch 2.5
+        if len(handle) > 64:
+            assert len(handle) == 66
+            # only support SHAREABLE_HANDLE_VERSION = 1
+            assert int(handle[0]) == 1
+            # only support SHAREABLE_CUDA_MALLOC = 'c'
+            assert handle[1] == ord("c")
+            handle = handle[2:]
+            # TODO: support expandable segment
         shard_data = (
-            data[1],  # ipc handle to base ptr
+            handle,  # ipc handle to base ptr
             data[3],  # offset of base ptr
         )
         return self._gather_ipc_meta(shard_data)
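
For context, the format change this commit works around: per pytorch/pytorch#130890, from PyTorch 2.5 onward `UntypedStorage._share_cuda_()` returns a 66-byte shareable handle (1 version byte, 1 handle-type byte, then the 64-byte cudaIpcMemHandle_t) instead of the bare 64-byte handle. The sketch below isolates the same normalization the diff applies; the helper name `strip_ipc_handle_header` is hypothetical and not part of the commit.

def strip_ipc_handle_header(handle: bytes) -> bytes:
    """Illustrative sketch: normalize a CUDA IPC handle across PyTorch versions.

    Before PyTorch 2.5 the handle is the raw 64-byte cudaIpcMemHandle_t.
    From 2.5 on it carries a 2-byte header [version][handle type]; only
    version 1 and type b'c' (cudaMalloc-backed memory) are handled here,
    matching the asserts in the diff above.
    """
    if len(handle) <= 64:
        # legacy format: already a bare cudaIpcMemHandle_t
        return handle
    assert len(handle) == 66, "unexpected shareable-handle length"
    assert handle[0] == 1, "only SHAREABLE_HANDLE_VERSION = 1 is supported"
    assert handle[1] == ord("c"), "only cudaMalloc-backed handles are supported"
    # drop the 2-byte header, leaving the raw cudaIpcMemHandle_t
    return handle[2:]

# Usage sketch (assumes a CUDA tensor `inp`):
#   data = inp.untyped_storage()._share_cuda_()
#   raw_handle = strip_ipc_handle_header(data[1])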