vllm/vllm/platforms/__init__.py
xendo 9dbcce84a7
[Neuron] [Bugfix] Fix neuron startup (#9374)
Co-authored-by: Jerzy Zagorski <jzagorsk@amazon.com>
2024-10-22 12:51:41 +00:00

92 lines
2.1 KiB
Python

from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Platform
# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.
is_tpu = False
try:
# While it's technically possible to install libtpu on a non-TPU machine,
# this is a very uncommon scenario. Therefore, we assume that libtpu is
# installed if and only if the machine has TPUs.
import libtpu # noqa: F401
is_tpu = True
except Exception:
pass
is_cuda = False
try:
import pynvml
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:
is_cuda = True
finally:
pynvml.nvmlShutdown()
except Exception:
pass
is_rocm = False
try:
import amdsmi
amdsmi.amdsmi_init()
try:
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
is_rocm = True
finally:
amdsmi.amdsmi_shut_down()
except Exception:
pass
is_xpu = False
try:
import torch
if hasattr(torch, 'xpu') and torch.xpu.is_available():
is_xpu = True
except Exception:
pass
is_cpu = False
try:
from importlib.metadata import version
is_cpu = "cpu" in version("vllm")
except Exception:
pass
is_neuron = False
try:
import transformers_neuronx # noqa: F401
is_neuron = True
except ImportError:
pass
if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
from .tpu import TpuPlatform
current_platform = TpuPlatform()
elif is_cuda:
from .cuda import CudaPlatform
current_platform = CudaPlatform()
elif is_rocm:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_xpu:
from .xpu import XPUPlatform
current_platform = XPUPlatform()
elif is_cpu:
from .cpu import CpuPlatform
current_platform = CpuPlatform()
elif is_neuron:
from .neuron import NeuronPlatform
current_platform = NeuronPlatform()
else:
current_platform = UnspecifiedPlatform()
__all__ = ['Platform', 'PlatformEnum', 'current_platform']