[hardware] unify usage of is_tpu to current_platform.is_tpu() (#7102)

Author: youkaichao
Date: 2024-08-13 00:16:42 -07:00 (committed by GitHub)
Parent: 7025b11d94
Commit: 4d2dc5072b
GPG Key ID: B5690EEEBB952194 (no known key found for this signature)
8 changed files with 29 additions and 33 deletions
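This commit replaces the free function `is_tpu()` from `vllm.utils` with a method on the global `current_platform` object exported by `vllm.platforms`, so TPU detection has a single owner. For orientation, here is a minimal sketch of the enum-based `Platform` interface this relies on; the class shapes are inferred from the imports in the diffs below (`Platform`, `PlatformEnum`, `UnspecifiedPlatform`, `TpuPlatform`) and are not copied from the repository:

# Minimal sketch, assuming an enum-based Platform interface; the real
# definitions live in vllm/platforms/interface.py.
import enum


class PlatformEnum(enum.Enum):
    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    UNSPECIFIED = enum.auto()


class Platform:
    _enum: PlatformEnum

    def is_tpu(self) -> bool:
        return self._enum == PlatformEnum.TPU


class TpuPlatform(Platform):
    _enum = PlatformEnum.TPU


class UnspecifiedPlatform(Platform):
    _enum = PlatformEnum.UNSPECIFIED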

vllm/attention/selector.py

@@ -10,8 +10,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino,
-                        is_tpu, is_xpu)
+from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu
 
 logger = init_logger(__name__)
 
@@ -194,7 +193,7 @@ def which_attn_to_use(
             logger.info("Cannot use %s backend on XPU.", selected_backend)
         return _Backend.IPEX
 
-    if is_tpu():
+    if current_platform.is_tpu():
         if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
         return _Backend.PALLAS
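Backend selection can also be forced through the environment variable named by `STR_BACKEND_ENV_VAR`; a hedged usage sketch, assuming that constant resolves to "VLLM_ATTENTION_BACKEND". The platform overrides above still apply, so requesting another backend on a TPU host is overridden back to PALLAS:

import os

# Assumption: STR_BACKEND_ENV_VAR names this variable. A forced backend
# is still subject to the platform checks in which_attn_to_use.
os.environ["VLLM_ATTENTION_BACKEND"] = "PALLAS"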

vllm/config.py

@@ -10,11 +10,12 @@ import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
+from vllm.platforms import current_platform
 from vllm.tracing import is_otel_installed
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
                         cuda_device_count_stateless, get_cpu_memory, is_cpu,
-                        is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
+                        is_hip, is_neuron, is_openvino, is_xpu,
                         print_warning_once)
 
 if TYPE_CHECKING:
@@ -282,7 +283,7 @@ class ModelConfig:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
                     f"supported in ROCm.")
-            if is_tpu(
+            if current_platform.is_tpu(
             ) and self.quantization not in tpu_supported_quantization:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
@@ -910,7 +911,7 @@ class DeviceConfig:
                 self.device_type = "neuron"
             elif is_openvino():
                 self.device_type = "openvino"
-            elif is_tpu():
+            elif current_platform.is_tpu():
                 self.device_type = "tpu"
            elif is_cpu():
                 self.device_type = "cpu"

vllm/executor/ray_utils.py

@@ -2,8 +2,9 @@ from typing import List, Optional, Tuple, Union
 
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
-from vllm.utils import get_ip, is_hip, is_tpu, is_xpu
+from vllm.utils import get_ip, is_hip, is_xpu
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -111,7 +112,7 @@ def initialize_ray_cluster(
         # Placement group is already set.
         return
 
-    device_str = "GPU" if not is_tpu() else "TPU"
+    device_str = "GPU" if not current_platform.is_tpu() else "TPU"
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
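`device_str` feeds the Ray placement-group bundles that reserve one accelerator per worker. A hedged sketch of that call; the bundle count and shape are illustrative, not copied from vllm:

import ray

# The same placement-group call serves GPU and TPU clusters by swapping
# the resource key that device_str now derives from current_platform.
device_str = "GPU"  # "TPU" when current_platform.is_tpu()
pg = ray.util.placement_group([{device_str: 1.0}] * 4, strategy="PACK")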

vllm/model_executor/custom_op.py

@@ -1,6 +1,7 @@
 import torch.nn as nn
 
-from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu
+from vllm.platforms import current_platform
+from vllm.utils import is_cpu, is_hip, is_xpu
 
 
 class CustomOp(nn.Module):
@@ -54,7 +55,7 @@ class CustomOp(nn.Module):
             return self.forward_hip
         elif is_cpu():
             return self.forward_cpu
-        elif is_tpu():
+        elif current_platform.is_tpu():
             return self.forward_tpu
         elif is_xpu():
             return self.forward_xpu
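The hunk above is from the method that resolves which `forward_*` implementation a `CustomOp` should use. A self-contained toy of that dispatch-once pattern; `ToyOp` and its implementations are illustrative, not vllm code:

import torch
import torch.nn as nn


class ToyOp(nn.Module):
    """Sketch of CustomOp dispatch: resolve the platform-specific
    forward once, at construction time, not on every call."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_method(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0  # reference implementation

    def forward_tpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)

    def dispatch_forward(self):
        # The real method branches on is_hip()/is_cpu()/
        # current_platform.is_tpu()/is_xpu(); this toy always
        # falls back to the native path.
        return self.forward_native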

vllm/model_executor/layers/rotary_embedding.py

@@ -28,7 +28,7 @@ import torch
 import torch.nn as nn
 
 from vllm.model_executor.custom_op import CustomOp
-from vllm.utils import is_tpu
+from vllm.platforms import current_platform
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -78,7 +78,7 @@ class RotaryEmbedding(CustomOp):
         self.dtype = dtype
 
         cache = self._compute_cos_sin_cache()
-        self.use_native2 = is_tpu() and is_neox_style
+        self.use_native2 = current_platform.is_tpu() and is_neox_style
         if not self.use_native2:
             cache = cache.to(dtype)
         self.register_buffer("cos_sin_cache", cache, persistent=False)

vllm/model_executor/model_loader/loader.py

@@ -41,7 +41,7 @@ from vllm.model_executor.models.interfaces import (has_inner_state,
                                                    supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import is_pin_memory_available, is_tpu
+from vllm.utils import is_pin_memory_available
 
 
 @contextmanager
@@ -94,7 +94,7 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        if not is_tpu():
+        if not current_platform.is_tpu():
             capability = current_platform.get_device_capability()
             capability = capability[0] * 10 + capability[1]
             if capability < quant_config.get_min_capability():
@@ -320,7 +320,7 @@ class DefaultModelLoader(BaseModelLoader):
         else:
             weights_iterator = pt_weights_iterator(hf_weights_files)
 
-        if is_tpu():
+        if current_platform.is_tpu():
             # In PyTorch XLA, we should call `xm.mark_step` frequently so that
             # not too many ops are accumulated in the XLA program.
             import torch_xla.core.xla_model as xm
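The comment in this hunk describes the pattern: flush XLA's lazily traced ops after each weight so the accumulated program stays small. A sketch of a wrapper iterator doing exactly that; the function name is illustrative, and it requires a host with torch_xla installed:

import torch_xla.core.xla_model as xm


def xla_weights_iterator(weights_iterator):
    # Yield each weight, then flush the lazily-traced ops so the XLA
    # program never grows past one weight's worth of work.
    for weights in weights_iterator:
        yield weights
        xm.mark_step()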

vllm/platforms/__init__.py

@@ -1,22 +1,25 @@
-from typing import Optional
-
 import torch
 
-from vllm.utils import is_tpu
-
 from .interface import Platform, PlatformEnum, UnspecifiedPlatform
 
-current_platform: Optional[Platform]
+current_platform: Platform
 
-if torch.version.cuda is not None:
+try:
+    import libtpu
+except ImportError:
+    libtpu = None
+
+if libtpu is not None:
+    # people might install pytorch built with cuda but run on tpu
+    # so we need to check tpu first
+    from .tpu import TpuPlatform
+    current_platform = TpuPlatform()
+elif torch.version.cuda is not None:
     from .cuda import CudaPlatform
     current_platform = CudaPlatform()
 elif torch.version.hip is not None:
     from .rocm import RocmPlatform
     current_platform = RocmPlatform()
-elif is_tpu():
-    from .tpu import TpuPlatform
-    current_platform = TpuPlatform()
 else:
     current_platform = UnspecifiedPlatform()
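The ordering here is the point of the rewrite: a CUDA build of PyTorch can be installed on a TPU VM, so `libtpu` is probed before `torch.version.cuda`. Downstream code now queries the singleton, which is what every other hunk in this commit migrates to:

from vllm.platforms import current_platform

if current_platform.is_tpu():
    print("running on a TPU host")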

vllm/utils.py

@@ -333,15 +333,6 @@ def is_neuron() -> bool:
     return transformers_neuronx is not None
 
 
-@lru_cache(maxsize=None)
-def is_tpu() -> bool:
-    try:
-        import libtpu
-    except ImportError:
-        libtpu = None
-    return libtpu is not None
-
-
 @lru_cache(maxsize=None)
 def is_xpu() -> bool:
     from importlib.metadata import PackageNotFoundError, version
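With this deletion the `libtpu` probe runs once, at `vllm.platforms` import time (previous file), instead of behind an `lru_cache`. The helpers that remain in `vllm/utils.py` keep the memoized import-probe idiom; a generic sketch, where `_importable` is an illustrative name rather than a vllm function:

from functools import lru_cache


@lru_cache(maxsize=None)
def _importable(module_name: str) -> bool:
    # Probe once per module name; lru_cache memoizes the result.
    try:
        __import__(module_name)
    except ImportError:
        return False
    return True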