[misc][distributed] fix benign error in is_in_the_same_node (#5512)

youkaichao 2024-06-14 10:59:28 -07:00 committed by GitHub
parent 77490c6f2f
commit d1c3d7d139


@@ -23,8 +23,9 @@ import contextlib
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
-from multiprocessing import resource_tracker, shared_memory
+from multiprocessing import shared_memory
 from typing import Any, Dict, List, Optional, Tuple, Union
+from unittest.mock import patch
 
 import torch
 from torch.distributed import Backend, ProcessGroup
@@ -744,7 +745,12 @@ def is_in_the_same_node(pg: ProcessGroup):
                 src=ranks[0],
                 group=pg)
             name = recv[0]
-            shm = shared_memory.SharedMemory(name=name)
+            # fix to https://stackoverflow.com/q/62748654/9191338
+            # Python incorrectly tracks shared memory even if it is not
+            # created by the process. The following patch is a workaround.
+            with patch("multiprocessing.resource_tracker.register",
+                       lambda *args, **kwargs: None):
+                shm = shared_memory.SharedMemory(name=name)
             if shm.buf[:len(magic_message)] == magic_message:
                 is_in_the_same_node[rank] = 1
     except Exception as e:
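
For context, a minimal standalone sketch (not part of this commit) of the same workaround: when a process only attaches to a shared memory segment created by another process, patching multiprocessing.resource_tracker.register keeps Python's resource tracker from claiming ownership of that segment, which is what produced the benign "leaked shared_memory objects" warning at exit. The function name and segment name below are hypothetical.

from multiprocessing import shared_memory
from unittest.mock import patch

def attach_existing_segment(name: str) -> shared_memory.SharedMemory:
    # Attaching by name should not register the segment with the resource
    # tracker; the tracker would otherwise try to clean up a segment this
    # process does not own and emit a spurious warning at interpreter exit.
    with patch("multiprocessing.resource_tracker.register",
               lambda *args, **kwargs: None):
        return shared_memory.SharedMemory(name=name)

# hypothetical usage: another process already created a segment named "psm_demo"
# shm = attach_existing_segment("psm_demo")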
@@ -757,14 +763,8 @@ def is_in_the_same_node(pg: ProcessGroup):
     # clean up the shared memory segment
     with contextlib.suppress(OSError):
-        if rank == 0:
-            if shm:
-                shm.unlink()
-        else:
-            if shm:
-                # fix to https://stackoverflow.com/q/62748654/9191338
-                resource_tracker.unregister(
-                    shm._name, "shared_memory")  # type: ignore[attr-defined]
+        if rank == 0 and shm:
+            shm.unlink()
     torch.distributed.all_reduce(is_in_the_same_node, group=pg)
     return is_in_the_same_node.sum().item() == world_size
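
A hedged sketch of the resulting cleanup contract, assuming the attach-side patch shown above: ranks that merely attached never registered the segment with the resource tracker, so the old resource_tracker.unregister call is no longer needed, and only the creating rank removes the segment. The helper name and the explicit close() call are illustrative, not taken from the commit.

import contextlib
from multiprocessing import shared_memory

def cleanup_shared_segment(shm: shared_memory.SharedMemory, rank: int) -> None:
    # Every rank releases its own mapping; only the rank that created the
    # segment (rank 0 here) unlinks it from the system. No unregister() is
    # required because attaching ranks never registered it in the first place.
    with contextlib.suppress(OSError):
        shm.close()
        if rank == 0:
            shm.unlink()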