[hardware] unify usage of is_tpu to current_platform.is_tpu() (#7102)
parent 7025b11d94
commit 4d2dc5072b
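
In short: every direct call to the `is_tpu()` helper in `vllm.utils` is routed through the shared `current_platform` object from `vllm.platforms`, and the helper itself is deleted. A minimal sketch of the before/after pattern at a call site (illustrative only, not part of the diff):

from vllm.platforms import current_platform

# Old pattern, removed by this commit:
#     from vllm.utils import is_tpu
#     if is_tpu():
#         ...
# New pattern: query the platform object resolved once at import time.
if current_platform.is_tpu():
    ...  # TPU-only path, e.g. picking the PALLAS attention backend
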
@@ -10,8 +10,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino,
-                        is_tpu, is_xpu)
+from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu
 
 logger = init_logger(__name__)
 
@@ -194,7 +193,7 @@ def which_attn_to_use(
             logger.info("Cannot use %s backend on XPU.", selected_backend)
         return _Backend.IPEX
 
-    if is_tpu():
+    if current_platform.is_tpu():
         if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
         return _Backend.PALLAS

@@ -10,11 +10,12 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
+from vllm.platforms import current_platform
 from vllm.tracing import is_otel_installed
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
                         cuda_device_count_stateless, get_cpu_memory, is_cpu,
-                        is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
+                        is_hip, is_neuron, is_openvino, is_xpu,
                         print_warning_once)
 
 if TYPE_CHECKING:
@@ -282,7 +283,7 @@ class ModelConfig:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
                     f"supported in ROCm.")
-            if is_tpu(
+            if current_platform.is_tpu(
             ) and self.quantization not in tpu_supported_quantization:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
@@ -910,7 +911,7 @@ class DeviceConfig:
                 self.device_type = "neuron"
             elif is_openvino():
                 self.device_type = "openvino"
-            elif is_tpu():
+            elif current_platform.is_tpu():
                 self.device_type = "tpu"
             elif is_cpu():
                 self.device_type = "cpu"

@@ -2,8 +2,9 @@ from typing import List, Optional, Tuple, Union
 
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
-from vllm.utils import get_ip, is_hip, is_tpu, is_xpu
+from vllm.utils import get_ip, is_hip, is_xpu
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -111,7 +112,7 @@ def initialize_ray_cluster(
         # Placement group is already set.
         return
 
-    device_str = "GPU" if not is_tpu() else "TPU"
+    device_str = "GPU" if not current_platform.is_tpu() else "TPU"
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:

@@ -1,6 +1,7 @@
 import torch.nn as nn
 
-from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu
+from vllm.platforms import current_platform
+from vllm.utils import is_cpu, is_hip, is_xpu
 
 
 class CustomOp(nn.Module):
@@ -54,7 +55,7 @@ class CustomOp(nn.Module):
             return self.forward_hip
         elif is_cpu():
             return self.forward_cpu
-        elif is_tpu():
+        elif current_platform.is_tpu():
             return self.forward_tpu
         elif is_xpu():
             return self.forward_xpu

@@ -28,7 +28,7 @@ import torch
 import torch.nn as nn
 
 from vllm.model_executor.custom_op import CustomOp
-from vllm.utils import is_tpu
+from vllm.platforms import current_platform
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -78,7 +78,7 @@ class RotaryEmbedding(CustomOp):
         self.dtype = dtype
 
         cache = self._compute_cos_sin_cache()
-        self.use_native2 = is_tpu() and is_neox_style
+        self.use_native2 = current_platform.is_tpu() and is_neox_style
         if not self.use_native2:
             cache = cache.to(dtype)
         self.register_buffer("cos_sin_cache", cache, persistent=False)

@@ -41,7 +41,7 @@ from vllm.model_executor.models.interfaces import (has_inner_state,
                                                     supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import is_pin_memory_available, is_tpu
+from vllm.utils import is_pin_memory_available
 
 
 @contextmanager
@@ -94,7 +94,7 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        if not is_tpu():
+        if not current_platform.is_tpu():
             capability = current_platform.get_device_capability()
             capability = capability[0] * 10 + capability[1]
             if capability < quant_config.get_min_capability():
@@ -320,7 +320,7 @@ class DefaultModelLoader(BaseModelLoader):
         else:
             weights_iterator = pt_weights_iterator(hf_weights_files)
 
-        if is_tpu():
+        if current_platform.is_tpu():
             # In PyTorch XLA, we should call `xm.mark_step` frequently so that
             # not too many ops are accumulated in the XLA program.
             import torch_xla.core.xla_model as xm

@@ -1,22 +1,25 @@
-from typing import Optional
-
 import torch
 
-from vllm.utils import is_tpu
-
 from .interface import Platform, PlatformEnum, UnspecifiedPlatform
 
-current_platform: Optional[Platform]
+current_platform: Platform
 
-if torch.version.cuda is not None:
+try:
+    import libtpu
+except ImportError:
+    libtpu = None
+
+if libtpu is not None:
+    # people might install pytorch built with cuda but run on tpu
+    # so we need to check tpu first
+    from .tpu import TpuPlatform
+    current_platform = TpuPlatform()
+elif torch.version.cuda is not None:
     from .cuda import CudaPlatform
     current_platform = CudaPlatform()
 elif torch.version.hip is not None:
     from .rocm import RocmPlatform
     current_platform = RocmPlatform()
-elif is_tpu():
-    from .tpu import TpuPlatform
-    current_platform = TpuPlatform()
 else:
     current_platform = UnspecifiedPlatform()
 

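A note on the platform-detection hunk above (the vllm.platforms package initializer): a PyTorch build with CUDA support can still be running on a TPU host, which is why the libtpu probe now comes before the torch.version.cuda check, and current_platform is annotated as a concrete Platform instead of Optional[Platform]. A small usage sketch, assuming is_tpu() is exposed on the Platform interface (which the call sites in this commit rely on):

# Illustrative only; not part of the commit.
from vllm.platforms import current_platform

# current_platform is always a Platform instance now, so callers need no None check.
print(type(current_platform).__name__)  # e.g. CudaPlatform, RocmPlatform, TpuPlatform
print(current_platform.is_tpu())        # presumably True only when libtpu is importable
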
@@ -333,15 +333,6 @@ def is_neuron() -> bool:
     return transformers_neuronx is not None
 
 
-@lru_cache(maxsize=None)
-def is_tpu() -> bool:
-    try:
-        import libtpu
-    except ImportError:
-        libtpu = None
-    return libtpu is not None
-
-
 @lru_cache(maxsize=None)
 def is_xpu() -> bool:
     from importlib.metadata import PackageNotFoundError, version