[misc][distributed] fix benign error in is_in_the_same_node (#5512)
This commit is contained in:
parent
77490c6f2f
commit
d1c3d7d139
@ -23,8 +23,9 @@ import contextlib
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from multiprocessing import resource_tracker, shared_memory
|
from multiprocessing import shared_memory
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import Backend, ProcessGroup
|
from torch.distributed import Backend, ProcessGroup
|
||||||
@ -744,7 +745,12 @@ def is_in_the_same_node(pg: ProcessGroup):
|
|||||||
src=ranks[0],
|
src=ranks[0],
|
||||||
group=pg)
|
group=pg)
|
||||||
name = recv[0]
|
name = recv[0]
|
||||||
shm = shared_memory.SharedMemory(name=name)
|
# fix to https://stackoverflow.com/q/62748654/9191338
|
||||||
|
# Python incorrectly tracks shared memory even if it is not
|
||||||
|
# created by the process. The following patch is a workaround.
|
||||||
|
with patch("multiprocessing.resource_tracker.register",
|
||||||
|
lambda *args, **kwargs: None):
|
||||||
|
shm = shared_memory.SharedMemory(name=name)
|
||||||
if shm.buf[:len(magic_message)] == magic_message:
|
if shm.buf[:len(magic_message)] == magic_message:
|
||||||
is_in_the_same_node[rank] = 1
|
is_in_the_same_node[rank] = 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -757,14 +763,8 @@ def is_in_the_same_node(pg: ProcessGroup):
|
|||||||
|
|
||||||
# clean up the shared memory segment
|
# clean up the shared memory segment
|
||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
if rank == 0:
|
if rank == 0 and shm:
|
||||||
if shm:
|
shm.unlink()
|
||||||
shm.unlink()
|
|
||||||
else:
|
|
||||||
if shm:
|
|
||||||
# fix to https://stackoverflow.com/q/62748654/9191338
|
|
||||||
resource_tracker.unregister(
|
|
||||||
shm._name, "shared_memory") # type: ignore[attr-defined]
|
|
||||||
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
|
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
|
||||||
|
|
||||||
return is_in_the_same_node.sum().item() == world_size
|
return is_in_the_same_node.sum().item() == world_size
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user