[Platforms] Add device_type in Platform (#10508)

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao 2024-11-21 12:44:20 +08:00 committed by GitHub
parent 6c1208d083
commit 9d827170a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 11 additions and 15 deletions

View File

@ -1193,21 +1193,8 @@ class DeviceConfig:
def __init__(self, device: str = "auto") -> None: def __init__(self, device: str = "auto") -> None:
if device == "auto": if device == "auto":
# Automated device type detection # Automated device type detection
if current_platform.is_cuda_alike(): self.device_type = current_platform.device_type
self.device_type = "cuda" if self.device_type is None:
elif current_platform.is_neuron():
self.device_type = "neuron"
elif current_platform.is_hpu():
self.device_type = "hpu"
elif current_platform.is_openvino():
self.device_type = "openvino"
elif current_platform.is_tpu():
self.device_type = "tpu"
elif current_platform.is_cpu():
self.device_type = "cpu"
elif current_platform.is_xpu():
self.device_type = "xpu"
else:
raise RuntimeError("Failed to infer device type") raise RuntimeError("Failed to infer device type")
else: else:
# Device type is assigned explicitly # Device type is assigned explicitly

View File

@ -19,6 +19,7 @@ logger = init_logger(__name__)
class CpuPlatform(Platform): class CpuPlatform(Platform):
_enum = PlatformEnum.CPU _enum = PlatformEnum.CPU
device_type: str = "cpu"
@classmethod @classmethod
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:

View File

@ -109,6 +109,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
class CudaPlatform(Platform): class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA _enum = PlatformEnum.CUDA
device_type: str = "cuda"
@classmethod @classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:

View File

@ -5,6 +5,7 @@ from .interface import Platform, PlatformEnum, _Backend
class HpuPlatform(Platform): class HpuPlatform(Platform):
_enum = PlatformEnum.HPU _enum = PlatformEnum.HPU
device_type: str = "hpu"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -56,6 +56,7 @@ class DeviceCapability(NamedTuple):
class Platform: class Platform:
_enum: PlatformEnum _enum: PlatformEnum
device_type: str
def is_cuda(self) -> bool: def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA return self._enum == PlatformEnum.CUDA

View File

@ -3,6 +3,7 @@ from .interface import Platform, PlatformEnum
class NeuronPlatform(Platform): class NeuronPlatform(Platform):
_enum = PlatformEnum.NEURON _enum = PlatformEnum.NEURON
device_type: str = "neuron"
@classmethod @classmethod
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:

View File

@ -10,6 +10,7 @@ logger = init_logger(__name__)
class OpenVinoPlatform(Platform): class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO _enum = PlatformEnum.OPENVINO
device_type: str = "openvino"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -29,6 +29,7 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class RocmPlatform(Platform): class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM _enum = PlatformEnum.ROCM
device_type: str = "cuda"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -16,6 +16,7 @@ logger = init_logger(__name__)
class TpuPlatform(Platform): class TpuPlatform(Platform):
_enum = PlatformEnum.TPU _enum = PlatformEnum.TPU
device_type: str = "tpu"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:

View File

@ -16,6 +16,7 @@ logger = init_logger(__name__)
class XPUPlatform(Platform): class XPUPlatform(Platform):
_enum = PlatformEnum.XPU _enum = PlatformEnum.XPU
device_type: str = "xpu"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: