[misc] use nvml to get consistent device name (#7582)
This commit is contained in:
parent
7c0b7ea214
commit
eed020f673
@ -11,6 +11,7 @@ import triton.language as tl
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -287,7 +288,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
||||
|
||||
|
||||
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
|
||||
device_name = torch.cuda.get_device_name().replace(" ", "_")
|
||||
device_name = current_platform.get_device_name().replace(" ", "_")
|
||||
dtype_selector = "" if not dtype else f",dtype={dtype}"
|
||||
return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
|
||||
|
||||
|
||||
@ -44,6 +44,35 @@ def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
|
||||
return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
@with_nvml_context
|
||||
def get_physical_device_name(device_id: int = 0) -> str:
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
|
||||
return pynvml.nvmlDeviceGetName(handle)
|
||||
|
||||
|
||||
@with_nvml_context
|
||||
def warn_if_different_devices():
|
||||
device_ids: int = pynvml.nvmlDeviceGetCount()
|
||||
if device_ids > 1:
|
||||
device_names = [get_physical_device_name(i) for i in range(device_ids)]
|
||||
if len(set(device_names)) > 1 and os.environ.get(
|
||||
"CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
|
||||
logger.warning(
|
||||
"Detected different devices in the system: \n%s\nPlease"
|
||||
" make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
|
||||
"avoid unexpected behavior.", "\n".join(device_names))
|
||||
|
||||
|
||||
try:
|
||||
from sphinx.ext.autodoc.mock import _MockModule
|
||||
|
||||
if not isinstance(pynvml, _MockModule):
|
||||
warn_if_different_devices()
|
||||
except ModuleNotFoundError:
|
||||
warn_if_different_devices()
|
||||
|
||||
|
||||
def device_id_to_physical_device_id(device_id: int) -> int:
|
||||
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
|
||||
@ -61,6 +90,11 @@ class CudaPlatform(Platform):
|
||||
physical_device_id = device_id_to_physical_device_id(device_id)
|
||||
return get_physical_device_capability(physical_device_id)
|
||||
|
||||
@staticmethod
|
||||
def get_device_name(device_id: int = 0) -> str:
|
||||
physical_device_id = device_id_to_physical_device_id(device_id)
|
||||
return get_physical_device_name(physical_device_id)
|
||||
|
||||
@staticmethod
|
||||
@with_nvml_context
|
||||
def is_full_nvlink(physical_device_ids: List[int]) -> bool:
|
||||
|
||||
@ -27,6 +27,10 @@ class Platform:
|
||||
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_device_name(device_id: int = 0) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def inference_mode():
|
||||
"""A device-specific wrapper of `torch.inference_mode`.
|
||||
|
||||
@ -13,3 +13,8 @@ class RocmPlatform(Platform):
|
||||
@lru_cache(maxsize=8)
|
||||
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
|
||||
return torch.cuda.get_device_capability(device_id)
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=8)
|
||||
def get_device_name(device_id: int = 0) -> str:
|
||||
return torch.cuda.get_device_name(device_id)
|
||||
|
||||
@ -368,7 +368,7 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
if torch_dtype == torch.bfloat16:
|
||||
compute_capability = current_platform.get_device_capability()
|
||||
if compute_capability[0] < 8:
|
||||
gpu_name = torch.cuda.get_device_name()
|
||||
gpu_name = current_platform.get_device_name()
|
||||
raise ValueError(
|
||||
"Bfloat16 is only supported on GPUs with compute capability "
|
||||
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
|
||||
|
||||
Loading…
Reference in New Issue
Block a user