[Core][Distributed] fix _is_full_nvlink detection (#4233)

youkaichao 2024-04-21 23:04:16 -07:00 committed by GitHub
parent 95e5b087cf
commit 747b1a7147


@@ -1,5 +1,6 @@
+import os
 from contextlib import contextmanager
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.distributed as dist
@@ -53,14 +54,20 @@ def init_custom_ar() -> None:
         return False
     # test nvlink first, this will filter out most of the cases
     # where custom allreduce is not supported
-    full_nvlink = _is_full_nvlink(rank, world_size)
+    if "CUDA_VISIBLE_DEVICES" in os.environ:
+        device_ids = list(
+            map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
+    else:
+        device_ids = list(range(num_dev))
+    # this checks hardware and driver support for NVLink
+    full_nvlink = _is_full_nvlink(device_ids)
     if world_size > 2 and not full_nvlink:
         logger.warn(
             "Custom allreduce is disabled because it's not supported on more"
             " than two PCIe-only GPUs. To silence this warning, specify"
             " disable_custom_all_reduce=True explicitly.")
         return
-    # test P2P capability
+    # test P2P capability, this checks software/cudaruntime support
     # this is expensive to compute at the first time
     # then we cache the result
     if not _can_p2p(rank, world_size):
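
Why this hunk maps to physical ids: CUDA ranks index into the CUDA_VISIBLE_DEVICES subset, renumbered from 0, while NVML (pynvml) always enumerates physical GPUs. A minimal standalone sketch of the mismatch, not vLLM code, assuming a multi-GPU machine and run with e.g. CUDA_VISIBLE_DEVICES="2,3":

import os

import pynvml
import torch

pynvml.nvmlInit()
# torch sees only the visible devices, renumbered from 0
print("torch sees:", torch.cuda.device_count())        # 2
# NVML ignores CUDA_VISIBLE_DEVICES and counts all physical GPUs
print("pynvml sees:", pynvml.nvmlDeviceGetCount())
# so visible ids must be mapped back to physical ids before NVML queries,
# exactly as the hunk above does with num_dev as the fallback
visible = os.environ.get("CUDA_VISIBLE_DEVICES")
device_ids = (list(map(int, visible.split(",")))
              if visible else list(range(torch.cuda.device_count())))
print("physical ids for NVML:", device_ids)            # e.g. [2, 3]
pynvml.nvmlShutdown()
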
@@ -138,23 +145,28 @@ def _nvml():
         pynvml.nvmlShutdown()
 
 
-# query if the set of gpus are fully connected by nvlink (1 hop)
 @_nvml()
-def _is_full_nvlink(rank, world_size):
-    handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
-    for i in range(world_size):
-        if i != rank:
-            try:
-                peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i)
-                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                    handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-                    return False
-            except pynvml.NVMLError as error:
-                logger.info(
-                    f"NVLink detection failed with message \"{str(error)}\". "
-                    "This is normal if your machine has no NVLink equipped")
-                return False
+def _is_full_nvlink(device_ids: List[int]) -> bool:
+    """
+    query if the set of gpus are fully connected by nvlink (1 hop)
+    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
+    so it works on real physical device ids.
+    """
+    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
+    for i, handle in enumerate(handles):
+        for j, peer_handle in enumerate(handles):
+            if i < j:
+                try:
+                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                        return False
+                except pynvml.NVMLError as error:
+                    logger.error(
+                        "NVLink detection failed. This is normal if your"
+                        " machine has no NVLink equipped.",
+                        exc_info=error)
+                    return False
     return True
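
The old code only checked links from the current rank's GPU outward; the new code checks every unordered pair, which is what "fully connected" requires. For reference, the pairwise check can be exercised outside vLLM with a short sketch like the one below, assuming only pynvml is installed; the helper name check_full_nvlink and the __main__ harness are illustrative, not part of the vLLM API:

from typing import List

import pynvml


def check_full_nvlink(device_ids: List[int]) -> bool:
    pynvml.nvmlInit()
    try:
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
        # every unordered pair (i < j) must report NVLink P2P support
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    status = pynvml.nvmlDeviceGetP2PStatus(
                        handle, peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                    if status != pynvml.NVML_P2P_STATUS_OK:
                        return False
        return True
    except pynvml.NVMLError:
        # normal on machines without NVLink
        return False
    finally:
        pynvml.nvmlShutdown()


if __name__ == "__main__":
    # pass physical device ids, e.g. GPUs 0 and 1
    print(check_full_nvlink([0, 1]))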