[misc][distributed] fix benign error in is_in_the_same_node (#5512)

youkaichao 2024-06-14 10:59:28 -07:00 committed by GitHub
parent 77490c6f2f
commit d1c3d7d139

@@ -23,8 +23,9 @@ import contextlib
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
-from multiprocessing import resource_tracker, shared_memory
+from multiprocessing import shared_memory
 from typing import Any, Dict, List, Optional, Tuple, Union
+from unittest.mock import patch
 
 import torch
 from torch.distributed import Backend, ProcessGroup
@@ -744,7 +745,12 @@ def is_in_the_same_node(pg: ProcessGroup):
                                                         src=ranks[0],
                                                         group=pg)
                 name = recv[0]
-                shm = shared_memory.SharedMemory(name=name)
+                # fix to https://stackoverflow.com/q/62748654/9191338
+                # Python incorrectly tracks shared memory even if it is not
+                # created by the process. The following patch is a workaround.
+                with patch("multiprocessing.resource_tracker.register",
+                           lambda *args, **kwargs: None):
+                    shm = shared_memory.SharedMemory(name=name)
                 if shm.buf[:len(magic_message)] == magic_message:
                     is_in_the_same_node[rank] = 1
     except Exception as e:
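
The added comment lines are the heart of the change: CPython's resource tracker registers a shared memory segment on attach as well as on create (the linked StackOverflow question), so a non-creating rank would later try to clean up, and warn about, a segment it does not own. Patching multiprocessing.resource_tracker.register to a no-op for the duration of the attach prevents that registration. A minimal standalone sketch of the same technique, independent of vLLM (attach_without_tracking is a hypothetical helper name):

# standalone sketch, not vLLM code: attach to an existing segment
# without letting this process's resource tracker claim ownership
from multiprocessing import shared_memory
from unittest.mock import patch


def attach_without_tracking(name: str) -> shared_memory.SharedMemory:
    # resource_tracker.register is what marks this process as an owner;
    # replacing it with a no-op leaves cleanup to the creating process
    with patch("multiprocessing.resource_tracker.register",
               lambda *args, **kwargs: None):
        return shared_memory.SharedMemory(name=name)


if __name__ == "__main__":
    creator = shared_memory.SharedMemory(create=True, size=8)
    creator.buf[:4] = b"test"
    attached = attach_without_tracking(creator.name)
    assert bytes(attached.buf[:4]) == b"test"
    attached.close()  # drop the local mapping only
    creator.close()
    creator.unlink()  # the creator alone removes the segment

(Python 3.13 later added a track=False argument to SharedMemory for exactly this purpose; on the Python versions in use here, the patch is the available workaround.)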
@@ -757,14 +763,8 @@ def is_in_the_same_node(pg: ProcessGroup):
 
     # clean up the shared memory segment
     with contextlib.suppress(OSError):
-        if rank == 0:
-            if shm:
-                shm.unlink()
-        else:
-            if shm:
-                # fix to https://stackoverflow.com/q/62748654/9191338
-                resource_tracker.unregister(
-                    shm._name, "shared_memory")  # type: ignore[attr-defined]
+        if rank == 0 and shm:
+            shm.unlink()
     torch.distributed.all_reduce(is_in_the_same_node, group=pg)
 
     return is_in_the_same_node.sum().item() == world_size
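
The cleanup simplifies accordingly: non-creating ranks never register the segment anymore, so the old rank != 0 branch that called resource_tracker.unregister (reaching into the private shm._name attribute) is no longer needed, and rank 0, the creator, remains the only rank that unlinks. A hypothetical usage sketch of the collective check, assuming a process group initialized with a CPU-capable backend and that the function is importable from vllm.distributed.parallel_state (the import path is an assumption):

# hypothetical usage sketch; run under a multi-process launcher such as
# torchrun, with a CPU-capable backend like gloo (the check reads shared
# memory via broadcast_object_list, not GPU collectives)
import torch.distributed as dist

from vllm.distributed.parallel_state import is_in_the_same_node  # assumed path

dist.init_process_group(backend="gloo")
if is_in_the_same_node(dist.group.WORLD):
    # every rank saw rank 0's magic message, so the all_reduce sums
    # to world_size and the group shares one physical node
    print("all ranks are on the same node")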