[misc] use nvml to get consistent device name (#7582)
parent 7c0b7ea214 · commit eed020f673
@@ -11,6 +11,7 @@ import triton.language as tl
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
@@ -287,7 +288,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
 
 
 def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
-    device_name = torch.cuda.get_device_name().replace(" ", "_")
+    device_name = current_platform.get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
     return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
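With this change, the fused-MoE config file name is keyed on the NVML device name instead of the torch one. A runnable illustration of the resulting file name; the device name below is a hypothetical stand-in for what current_platform.get_device_name() would return on a given machine:

from typing import Optional

def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
    # Stand-in for current_platform.get_device_name(); real names depend on the GPU.
    device_name = "NVIDIA H100 80GB HBM3".replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"

print(get_config_file_name(8, 14336, "fp8_w8a8"))
# -> E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json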
@@ -44,6 +44,35 @@ def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
     return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
 
 
+@lru_cache(maxsize=8)
+@with_nvml_context
+def get_physical_device_name(device_id: int = 0) -> str:
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+    return pynvml.nvmlDeviceGetName(handle)
+
+
+@with_nvml_context
+def warn_if_different_devices():
+    device_ids: int = pynvml.nvmlDeviceGetCount()
+    if device_ids > 1:
+        device_names = [get_physical_device_name(i) for i in range(device_ids)]
+        if len(set(device_names)) > 1 and os.environ.get(
+                "CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
+            logger.warning(
+                "Detected different devices in the system: \n%s\nPlease"
+                " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
+                "avoid unexpected behavior.", "\n".join(device_names))
+
+
+try:
+    from sphinx.ext.autodoc.mock import _MockModule
+
+    if not isinstance(pynvml, _MockModule):
+        warn_if_different_devices()
+except ModuleNotFoundError:
+    warn_if_different_devices()
+
+
 def device_id_to_physical_device_id(device_id: int) -> int:
     if "CUDA_VISIBLE_DEVICES" in os.environ:
         device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
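Both new helpers assume the module's existing with_nvml_context decorator, which brackets each call in an NVML session. A minimal sketch of such a decorator, assuming the standard pynvml API (the actual vLLM definition may differ):

import functools
import pynvml

def with_nvml_context(fn):
    # Initialize NVML before the wrapped call and always shut it down after,
    # so the decorated helper can assume a live NVML session.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()
    return wrapper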
@@ -61,6 +90,11 @@ class CudaPlatform(Platform):
         physical_device_id = device_id_to_physical_device_id(device_id)
         return get_physical_device_capability(physical_device_id)
 
+    @staticmethod
+    def get_device_name(device_id: int = 0) -> str:
+        physical_device_id = device_id_to_physical_device_id(device_id)
+        return get_physical_device_name(physical_device_id)
+
     @staticmethod
     @with_nvml_context
     def is_full_nvlink(physical_device_ids: List[int]) -> bool:
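CudaPlatform.get_device_name first translates the logical device index through device_id_to_physical_device_id, so NVML is always queried by physical index. A toy illustration of that mapping; this mirrors the behavior shown above but is not the vLLM source:

import os

def device_id_to_physical_device_id(device_id: int) -> int:
    # With CUDA_VISIBLE_DEVICES set, logical index i names the i-th entry
    # of that list; otherwise logical and physical indices coincide.
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        return int(device_ids[device_id])
    return device_id

os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
print(device_id_to_physical_device_id(0))  # 2: NVML is asked about physical GPU 2
print(device_id_to_physical_device_id(1))  # 3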
@@ -27,6 +27,10 @@ class Platform:
     def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         raise NotImplementedError
 
+    @staticmethod
+    def get_device_name(device_id: int = 0) -> str:
+        raise NotImplementedError
+
     @staticmethod
     def inference_mode():
         """A device-specific wrapper of `torch.inference_mode`.
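Platform.get_device_name is the abstract hook that callers reach through the current_platform singleton. A hedged sketch of how such a singleton could be chosen at import time; the real vllm/platforms/__init__.py may differ:

import torch

if torch.version.hip is not None:
    from vllm.platforms.rocm import RocmPlatform
    current_platform = RocmPlatform()
elif torch.version.cuda is not None:
    from vllm.platforms.cuda import CudaPlatform
    current_platform = CudaPlatform()
else:
    # This sketch only covers the two backends touched by this commit.
    raise RuntimeError("unsupported platform")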
@@ -13,3 +13,8 @@ class RocmPlatform(Platform):
     @lru_cache(maxsize=8)
     def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         return torch.cuda.get_device_capability(device_id)
+
+    @staticmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(device_id: int = 0) -> str:
+        return torch.cuda.get_device_name(device_id)
@@ -368,7 +368,7 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     if torch_dtype == torch.bfloat16:
         compute_capability = current_platform.get_device_capability()
         if compute_capability[0] < 8:
-            gpu_name = torch.cuda.get_device_name()
+            gpu_name = current_platform.get_device_name()
             raise ValueError(
                 "Bfloat16 is only supported on GPUs with compute capability "
                 f"of at least 8.0. Your {gpu_name} GPU has compute capability "
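The same substitution applies in the bfloat16 capability check. The guard's behavior, stubbed with hypothetical values so no real hardware is queried:

compute_capability = (7, 5)  # hypothetical pre-Ampere GPU
gpu_name = "NVIDIA T4"       # stand-in for current_platform.get_device_name()
if compute_capability[0] < 8:
    raise ValueError(
        "Bfloat16 is only supported on GPUs with compute capability "
        f"of at least 8.0. Your {gpu_name} GPU has compute capability "
        f"{compute_capability[0]}.{compute_capability[1]}.")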