[misc][distributed] fix benign error in is_in_the_same_node (#5512)
This commit is contained in:
parent
77490c6f2f
commit
d1c3d7d139
@ -23,8 +23,9 @@ import contextlib
|
||||
from collections import namedtuple
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import resource_tracker, shared_memory
|
||||
from multiprocessing import shared_memory
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from torch.distributed import Backend, ProcessGroup
|
||||
@ -744,7 +745,12 @@ def is_in_the_same_node(pg: ProcessGroup):
|
||||
src=ranks[0],
|
||||
group=pg)
|
||||
name = recv[0]
|
||||
shm = shared_memory.SharedMemory(name=name)
|
||||
# fix to https://stackoverflow.com/q/62748654/9191338
|
||||
# Python incorrectly tracks shared memory even if it is not
|
||||
# created by the process. The following patch is a workaround.
|
||||
with patch("multiprocessing.resource_tracker.register",
|
||||
lambda *args, **kwargs: None):
|
||||
shm = shared_memory.SharedMemory(name=name)
|
||||
if shm.buf[:len(magic_message)] == magic_message:
|
||||
is_in_the_same_node[rank] = 1
|
||||
except Exception as e:
|
||||
@ -757,14 +763,8 @@ def is_in_the_same_node(pg: ProcessGroup):
|
||||
|
||||
# clean up the shared memory segment
|
||||
with contextlib.suppress(OSError):
|
||||
if rank == 0:
|
||||
if shm:
|
||||
shm.unlink()
|
||||
else:
|
||||
if shm:
|
||||
# fix to https://stackoverflow.com/q/62748654/9191338
|
||||
resource_tracker.unregister(
|
||||
shm._name, "shared_memory") # type: ignore[attr-defined]
|
||||
if rank == 0 and shm:
|
||||
shm.unlink()
|
||||
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
|
||||
|
||||
return is_in_the_same_node.sum().item() == world_size
|
||||
|
||||
Loading…
Reference in New Issue
Block a user