[misc] use nvml to get consistent device name (#7582)

youkaichao 2024-08-16 21:15:13 -07:00, committed by GitHub
parent 7c0b7ea214
commit eed020f673

5 changed files with 46 additions and 2 deletions

vllm/model_executor/layers/fused_moe/fused_moe.py

@@ -11,6 +11,7 @@ import triton.language as tl
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -287,7 +288,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
 
 def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
-    device_name = torch.cuda.get_device_name().replace(" ", "_")
+    device_name = current_platform.get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
     return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"

vllm/platforms/cuda.py

@@ -44,6 +44,35 @@ def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
     return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
 
 
+@lru_cache(maxsize=8)
+@with_nvml_context
+def get_physical_device_name(device_id: int = 0) -> str:
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+    return pynvml.nvmlDeviceGetName(handle)
+
+
+@with_nvml_context
+def warn_if_different_devices():
+    device_ids: int = pynvml.nvmlDeviceGetCount()
+    if device_ids > 1:
+        device_names = [get_physical_device_name(i) for i in range(device_ids)]
+        if len(set(device_names)) > 1 and os.environ.get(
+                "CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
+            logger.warning(
+                "Detected different devices in the system: \n%s\nPlease"
+                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
+                "avoid unexpected behavior.", "\n".join(device_names))
+
+
+try:
+    from sphinx.ext.autodoc.mock import _MockModule
+
+    if not isinstance(pynvml, _MockModule):
+        warn_if_different_devices()
+except ModuleNotFoundError:
+    warn_if_different_devices()
+
+
 def device_id_to_physical_device_id(device_id: int) -> int:
     if "CUDA_VISIBLE_DEVICES" in os.environ:
         device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
@@ -61,6 +90,11 @@ class CudaPlatform(Platform):
         physical_device_id = device_id_to_physical_device_id(device_id)
         return get_physical_device_capability(physical_device_id)
 
+    @staticmethod
+    def get_device_name(device_id: int = 0) -> str:
+        physical_device_id = device_id_to_physical_device_id(device_id)
+        return get_physical_device_name(physical_device_id)
+
     @staticmethod
     @with_nvml_context
     def is_full_nvlink(physical_device_ids: List[int]) -> bool:
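
To see what the new helpers do outside of vLLM, here is a minimal standalone sketch of the same NVML enumeration. It assumes `pynvml` is installed; older pynvml releases return `bytes` from `nvmlDeviceGetName`, so the sketch normalizes. The warning exists because CUDA's default device ordering is `FASTEST_FIRST`, which can disagree with NVML's ordering on mixed-GPU hosts:

```python
import os

import pynvml

pynvml.nvmlInit()
try:
    names = []
    for i in range(pynvml.nvmlDeviceGetCount()):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        name = pynvml.nvmlDeviceGetName(handle)
        # Older pynvml returns bytes; newer returns str.
        names.append(name.decode() if isinstance(name, bytes) else name)
    if len(set(names)) > 1 and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
        print("Heterogeneous GPUs detected; set CUDA_DEVICE_ORDER=PCI_BUS_ID:")
        print("\n".join(names))
finally:
    pynvml.nvmlShutdown()
```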

vllm/platforms/interface.py

@@ -27,6 +27,10 @@ class Platform:
     def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         raise NotImplementedError
 
+    @staticmethod
+    def get_device_name(device_id: int = 0) -> str:
+        raise NotImplementedError
+
     @staticmethod
     def inference_mode():
         """A device-specific wrapper of `torch.inference_mode`.

vllm/platforms/rocm.py

@@ -13,3 +13,8 @@ class RocmPlatform(Platform):
     @lru_cache(maxsize=8)
     def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         return torch.cuda.get_device_capability(device_id)
+
+    @staticmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(device_id: int = 0) -> str:
+        return torch.cuda.get_device_name(device_id)
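
One detail worth noting in both platform classes is the decorator order: `@lru_cache` sits below `@staticmethod`, so the cache wraps the plain function and `device_id` is the only cache key. A small self-contained sketch of that behavior (the class and names are made up):

```python
from functools import lru_cache

class Demo:
    @staticmethod
    @lru_cache(maxsize=8)
    def device_name(device_id: int = 0) -> str:
        print(f"querying device {device_id}")  # runs once per distinct id
        return f"device-{device_id}"

Demo.device_name(0)  # prints "querying device 0"
Demo.device_name(0)  # cache hit: no print, same string returned
Demo.device_name(1)  # new id, queries again
```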

vllm/worker/worker.py

@@ -368,7 +368,7 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     if torch_dtype == torch.bfloat16:
         compute_capability = current_platform.get_device_capability()
         if compute_capability[0] < 8:
-            gpu_name = torch.cuda.get_device_name()
+            gpu_name = current_platform.get_device_name()
             raise ValueError(
                 "Bfloat16 is only supported on GPUs with compute capability "
                 f"of at least 8.0. Your {gpu_name} GPU has compute capability "