[Hardware][ROCM] using current_platform.is_rocm (#9642)
Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
parent
34a9941620
commit
4e2d95e372
@ -11,7 +11,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.utils import is_hip
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
from ..models.utils import check_outputs_equal
|
||||
@ -51,7 +51,7 @@ def test_models(
|
||||
enforce_eager: bool,
|
||||
) -> None:
|
||||
|
||||
if backend == "FLASHINFER" and is_hip():
|
||||
if backend == "FLASHINFER" and current_platform.is_rocm():
|
||||
pytest.skip("Flashinfer does not support ROCm/HIP.")
|
||||
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = backend
|
||||
|
||||
@ -5,7 +5,7 @@ import torch
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.compilation.levels import CompilationLevel
|
||||
from vllm.utils import is_hip
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
TEST_MODELS = [
|
||||
("facebook/opt-125m", {}),
|
||||
@ -55,7 +55,7 @@ if is_quant_method_supported("marlin"):
|
||||
"quantization": "marlin"
|
||||
}))
|
||||
|
||||
if not is_hip() and is_quant_method_supported("awq"):
|
||||
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
|
||||
"quantization": "AWQ"
|
||||
}))
|
||||
|
||||
@ -2,12 +2,13 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.utils import is_hip
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Using the default value (240.0) from pytorch will cause accuracy
|
||||
# issue on dynamic quantization models. Here use 224.0 for rocm.
|
||||
ROCM_FP8_MAX = 224.0
|
||||
FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
|
||||
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm() \
|
||||
else torch.float8_e4m3fn
|
||||
|
||||
|
||||
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
|
||||
@ -24,8 +25,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
|
||||
|
||||
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
|
||||
else torch.finfo(quant_dtype)
|
||||
qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
|
||||
qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
|
||||
qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
|
||||
else qtype_traits.max
|
||||
qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
|
||||
else qtype_traits.min
|
||||
qtype_max = as_float32_tensor(qtype_traits_max)
|
||||
s_1 = as_float32_tensor(1.0)
|
||||
s_512 = as_float32_tensor(512.0)
|
||||
@ -66,8 +69,10 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
|
||||
-> Tuple[torch.tensor, torch.tensor]:
|
||||
|
||||
fp8_traits = torch.finfo(FP8_DTYPE)
|
||||
fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
|
||||
fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
|
||||
fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
|
||||
else fp8_traits.max
|
||||
fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
|
||||
else fp8_traits.min
|
||||
fp8_max = as_float32_tensor(fp8_traits_max)
|
||||
one = as_float32_tensor(1.0)
|
||||
|
||||
|
||||
@ -6,11 +6,12 @@ import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import get_max_shared_memory_bytes, seed_everything
|
||||
|
||||
from .allclose_default import get_default_atol, get_default_rtol
|
||||
|
||||
if not is_hip():
|
||||
if not current_platform.is_rocm():
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
@ -23,8 +24,9 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
|
||||
NUM_BLOCKS = 4321 # Arbitrary values for testing
|
||||
PARTITION_SIZE = 512
|
||||
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float
|
||||
] if not is_hip() else [torch.half, torch.bfloat16]
|
||||
DTYPES = [
|
||||
torch.half, torch.bfloat16, torch.float
|
||||
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
|
||||
NUM_GEN_SEQS = [7] # Arbitrary values for testing
|
||||
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
|
||||
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
|
||||
@ -114,7 +116,8 @@ def ref_single_query_cached_kv_attention(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
|
||||
"version",
|
||||
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
|
||||
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@ -317,8 +320,8 @@ def test_paged_attention(
|
||||
# NOTE(woosuk): Due to the kernel-level differences in the two
|
||||
# implementations, there is a small numerical difference in the two
|
||||
# outputs. Thus, we use a relaxed tolerance for the test.
|
||||
atol = get_default_atol(output) if is_hip() else 1e-3
|
||||
rtol = get_default_rtol(output) if is_hip() else 1e-5
|
||||
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
|
||||
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
|
||||
|
||||
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
|
||||
# so we use a relaxed tolerance for the test.
|
||||
@ -368,7 +371,7 @@ def ref_multi_query_kv_attention(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(is_hip(),
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
@torch.inference_mode()
|
||||
def test_multi_query_kv_attention(
|
||||
@ -425,6 +428,6 @@ def test_multi_query_kv_attention(
|
||||
scale,
|
||||
dtype,
|
||||
)
|
||||
atol = get_default_atol(output) if is_hip() else 1e-3
|
||||
rtol = get_default_rtol(output) if is_hip() else 1e-5
|
||||
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
|
||||
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
|
||||
|
||||
@ -25,7 +25,8 @@ def test_env(name: str, device: str, monkeypatch):
|
||||
False)
|
||||
assert backend.name == "TORCH_SDPA"
|
||||
elif device == "hip":
|
||||
with patch("vllm.attention.selector.is_hip", return_value=True):
|
||||
with patch("vllm.attention.selector.current_platform.is_rocm",
|
||||
return_value=True):
|
||||
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
assert backend.name == "ROCM_FLASH"
|
||||
|
||||
@ -7,7 +7,8 @@ import torch
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops.blocksparse_attention.interface import (
|
||||
LocalStridedBlockSparseAttn)
|
||||
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import get_max_shared_memory_bytes, seed_everything
|
||||
|
||||
from .allclose_default import get_default_atol, get_default_rtol
|
||||
|
||||
@ -316,8 +317,8 @@ def test_paged_attention(
|
||||
# NOTE(woosuk): Due to the kernel-level differences in the two
|
||||
# implementations, there is a small numerical difference in the two
|
||||
# outputs. Thus, we use a relaxed tolerance for the test.
|
||||
atol = get_default_atol(output) if is_hip() else 1e-3
|
||||
rtol = get_default_rtol(output) if is_hip() else 1e-5
|
||||
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
|
||||
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
|
||||
|
||||
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
|
||||
# so we use a relaxed tolerance for the test.
|
||||
|
||||
@ -18,7 +18,7 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
|
||||
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
|
||||
from vllm.attention.selector import (_Backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.utils import is_hip
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# List of support backends for encoder/decoder models
|
||||
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
|
||||
@ -726,7 +726,8 @@ def _run_encoder_decoder_cross_attention_test(
|
||||
attn_type=attn_type)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
@ -755,7 +756,8 @@ def test_encoder_only(
|
||||
No KV cache is required for encoder-only attention.
|
||||
|
||||
Note on ROCm/HIP: currently encoder/decoder models are not supported on
|
||||
AMD GPUs, therefore this test simply is skipped if is_hip().
|
||||
AMD GPUs, therefore this test simply is skipped if
|
||||
current_platform.is_rocm().
|
||||
|
||||
This test globally forces an override of the usual backend
|
||||
auto-selection process, forcing the specific backend-under-test
|
||||
@ -811,7 +813,8 @@ def test_encoder_only(
|
||||
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
@ -864,7 +867,8 @@ def test_e2e_enc_dec_attn(
|
||||
to be utilized.
|
||||
|
||||
Note on ROCm/HIP: currently encoder/decoder models are not supported on
|
||||
AMD GPUs, therefore this test simply is skipped if is_hip().
|
||||
AMD GPUs, therefore this test simply is skipped if
|
||||
current_platform.is_rocm().
|
||||
|
||||
Note on metadata: there is a single attention metadata structure shared by
|
||||
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
|
||||
|
||||
@ -18,8 +18,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
marlin_quantize)
|
||||
from vllm.model_executor.models.mixtral import MixtralMoE
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import is_hip, seed_everything
|
||||
from vllm.utils import seed_everything
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
|
||||
@ -103,7 +104,7 @@ def test_mixtral_moe(dtype: torch.dtype):
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
||||
@pytest.mark.parametrize("is_k_full", [True, False])
|
||||
@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_fused_marlin_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
@ -256,7 +257,7 @@ def test_fused_marlin_moe(
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
||||
@pytest.mark.parametrize("is_k_full", [True, False])
|
||||
@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_single_marlin_moe_multiply(
|
||||
m: int,
|
||||
n: int,
|
||||
|
||||
@ -4,7 +4,7 @@ import pytest
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.utils import is_hip
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_PATH = "google/gemma-7b"
|
||||
|
||||
@ -31,7 +31,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
|
||||
return generated_texts
|
||||
|
||||
|
||||
@pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm")
|
||||
@pytest.mark.xfail(current_platform.is_rocm(),
|
||||
reason="There can be output mismatch on ROCm")
|
||||
def test_gemma_lora(gemma_lora_files):
|
||||
llm = vllm.LLM(MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
|
||||
@ -8,7 +8,7 @@ import pytest
|
||||
import vllm
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.utils import is_hip
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -19,7 +19,7 @@ class ModelWithQuantization:
|
||||
|
||||
MODELS: List[ModelWithQuantization]
|
||||
#AWQ quantization is currently not supported in ROCm.
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
MODELS = [
|
||||
ModelWithQuantization(
|
||||
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
|
||||
|
||||
@ -6,8 +6,9 @@ from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
|
||||
BatchEncoding)
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_hip
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||
from ...utils import check_logprobs_close
|
||||
@ -24,7 +25,7 @@ models = ["google/paligemma-3b-mix-224"]
|
||||
# ROCm Triton FA can run into compilation issues with these models due to,
|
||||
# excessive use of shared memory. Use other backends in the meantime.
|
||||
# FIXME (mattwong, gshtrasb, hongxiayan)
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
||||
|
||||
|
||||
@ -151,7 +152,7 @@ def run_test(
|
||||
pytest.param(
|
||||
"float",
|
||||
marks=pytest.mark.skipif(
|
||||
is_hip(),
|
||||
current_platform.is_rocm(),
|
||||
reason=
|
||||
"ROCm FA does not yet fully support 32-bit precision on PaliGemma")
|
||||
), "half"
|
||||
|
||||
@ -12,7 +12,6 @@ from vllm.multimodal import MultiModalRegistry
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
|
||||
_ImageAssets)
|
||||
@ -56,7 +55,7 @@ if current_platform.is_cpu():
|
||||
# ROCm Triton FA can run into shared memory issues with these models,
|
||||
# use other backends in the meantime
|
||||
# FIXME (mattwong, gshtrasb, hongxiayan)
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
||||
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ tensor parallelism.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.utils import is_hip
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .conftest import run_equality_correctness_test_tp
|
||||
|
||||
@ -51,7 +51,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int):
|
||||
"""Verify greedy equality when tensor parallelism is used.
|
||||
"""
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip("hip is not well-supported yet")
|
||||
run_equality_correctness_test_tp("JackFram/llama-68m",
|
||||
common_llm_kwargs,
|
||||
|
||||
@ -26,7 +26,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.utils import (FlexibleArgumentParser, GB_bytes,
|
||||
cuda_device_count_stateless, get_open_port, is_hip)
|
||||
cuda_device_count_stateless, get_open_port)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from amdsmi import (amdsmi_get_gpu_vram_usage,
|
||||
@ -487,7 +487,7 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
|
||||
output: Dict[int, str] = {}
|
||||
output_raw: Dict[int, float] = {}
|
||||
for device in devices:
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
dev_handle = amdsmi_get_processor_handles()[device]
|
||||
mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
|
||||
gb_used = mem_info["vram_used"] / 2**10
|
||||
|
||||
@ -674,8 +674,8 @@ def scaled_fp8_quant(
|
||||
assert (input.ndim == 2)
|
||||
shape: Union[Tuple[int, int], torch.Size] = input.shape
|
||||
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
|
||||
out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \
|
||||
else torch.float8_e4m3fn
|
||||
out_dtype: torch.dtype = torch.float8_e4m3fnuz \
|
||||
if current_platform.is_rocm() else torch.float8_e4m3fn
|
||||
if num_token_padding:
|
||||
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
||||
|
||||
@ -3,7 +3,6 @@ import math
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from .utils import (dense_to_crow_col, get_head_sliding_step,
|
||||
get_sparse_attn_mask)
|
||||
@ -32,8 +31,9 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
if use_spda is None:
|
||||
use_spda = is_hip() or current_platform.is_cpu() or not \
|
||||
IS_COMPUTE_8_OR_ABOVE
|
||||
use_spda = current_platform.is_rocm() or \
|
||||
current_platform.is_cpu() or not \
|
||||
IS_COMPUTE_8_OR_ABOVE
|
||||
device = device or (torch.cuda.current_device()
|
||||
if current_platform.is_cuda_alike() else "cpu")
|
||||
device = torch.device(device)
|
||||
|
||||
@ -10,7 +10,7 @@ import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR, is_hip
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -208,7 +208,7 @@ def which_attn_to_use(
|
||||
logger.info("Cannot use %s backend on TPU.", selected_backend)
|
||||
return _Backend.PALLAS
|
||||
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
# AMD GPUs.
|
||||
selected_backend = (_Backend.ROCM_FLASH if selected_backend
|
||||
== _Backend.FLASH_ATTN else selected_backend)
|
||||
|
||||
@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
|
||||
get_hf_image_processor_config,
|
||||
get_hf_text_config)
|
||||
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
|
||||
is_hip, print_warning_once)
|
||||
print_warning_once)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
@ -350,7 +350,7 @@ class ModelConfig:
|
||||
raise ValueError(
|
||||
f"Unknown quantization method: {self.quantization}. Must "
|
||||
f"be one of {supported_quantization}.")
|
||||
if is_hip(
|
||||
if current_platform.is_rocm(
|
||||
) and self.quantization not in rocm_supported_quantization:
|
||||
raise ValueError(
|
||||
f"{self.quantization} quantization is currently not "
|
||||
@ -365,7 +365,7 @@ class ModelConfig:
|
||||
"%s quantization is not fully "
|
||||
"optimized yet. The speed can be slower than "
|
||||
"non-quantized models.", self.quantization)
|
||||
if (self.quantization == "awq" and is_hip()
|
||||
if (self.quantization == "awq" and current_platform.is_rocm()
|
||||
and not envs.VLLM_USE_TRITON_AWQ):
|
||||
logger.warning(
|
||||
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
|
||||
@ -843,7 +843,8 @@ class LoadConfig:
|
||||
self.load_format = LoadFormat(load_format)
|
||||
|
||||
rocm_not_supported_load_format: List[str] = []
|
||||
if is_hip() and load_format in rocm_not_supported_load_format:
|
||||
if current_platform.is_rocm(
|
||||
) and load_format in rocm_not_supported_load_format:
|
||||
rocm_supported_load_format = [
|
||||
f for f in LoadFormat.__members__
|
||||
if (f not in rocm_not_supported_load_format)
|
||||
@ -967,7 +968,7 @@ class ParallelConfig:
|
||||
if self.use_ray:
|
||||
from vllm.executor import ray_utils
|
||||
ray_utils.assert_ray_available()
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
self.disable_custom_all_reduce = True
|
||||
logger.info(
|
||||
"Disabled the custom all-reduce kernel because it is not "
|
||||
|
||||
@ -10,7 +10,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||
from vllm.utils import get_ip, is_hip
|
||||
from vllm.utils import get_ip
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -231,7 +231,7 @@ def initialize_ray_cluster(
|
||||
assert_ray_available()
|
||||
|
||||
# Connect to a ray cluster.
|
||||
if is_hip() or current_platform.is_xpu():
|
||||
if current_platform.is_rocm() or current_platform.is_xpu():
|
||||
ray.init(address=ray_address,
|
||||
ignore_reinit_error=True,
|
||||
num_gpus=parallel_config.world_size)
|
||||
|
||||
@ -7,7 +7,7 @@ import vllm.envs as envs
|
||||
from vllm.compilation.levels import CompilationLevel
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_hip, print_warning_once
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -72,7 +72,7 @@ class CustomOp(nn.Module):
|
||||
if not enabled:
|
||||
return self.forward_native
|
||||
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
return self.forward_hip
|
||||
elif current_platform.is_cpu():
|
||||
return self.forward_cpu
|
||||
|
||||
@ -14,7 +14,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils import is_hip, print_warning_once
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
|
||||
class GPTQMarlinState(Enum):
|
||||
@ -150,7 +151,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_input_scale.max(), requires_grad=False)
|
||||
|
||||
# If rocm, normalize the weights and scales to e4m3fnuz
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
# Normalize the weights and scales
|
||||
w13_weight, w13_weight_scale, w13_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
|
||||
@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.utils import is_hip
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["CompressedTensorsW8A8Fp8"]
|
||||
|
||||
@ -40,7 +40,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=max_w_scale,
|
||||
@ -56,7 +56,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
elif self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight = layer.weight
|
||||
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
weight, weight_scale, input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
|
||||
@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
ModelWeightParameter)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_hip
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -127,7 +126,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
|
||||
weight = layer.weight
|
||||
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
weight, weight_scale, input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
|
||||
@ -26,7 +26,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_hip, print_warning_once
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
@ -123,7 +123,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self.use_marlin = (not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
|
||||
# Disable marlin for rocm
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
|
||||
def create_weights(
|
||||
@ -226,7 +226,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
# If rocm, use float8_e4m3fnuz.
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
weight, weight_scale, input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
@ -372,7 +372,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# If rocm, use float8_e4m3fnuz as dtype
|
||||
fp8_dtype = torch.float8_e4m3fnuz \
|
||||
if is_hip() else torch.float8_e4m3fn
|
||||
if current_platform.is_rocm() else torch.float8_e4m3fn
|
||||
w13_weight = torch.empty_like(layer.w13_weight.data,
|
||||
dtype=fp8_dtype)
|
||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||
@ -420,7 +420,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False)
|
||||
# If rocm, normalize the weights and scales to e4m3fnuz
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
# Normalize the weights and scales
|
||||
w13_weight, w13_weight_scale, w13_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
|
||||
@ -4,16 +4,16 @@ import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_hip
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None
|
||||
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \
|
||||
if current_platform.is_rocm() else None
|
||||
|
||||
|
||||
def cutlass_fp8_supported() -> bool:
|
||||
# cutlass is not supported on Rocm
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
return False
|
||||
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
|
||||
@ -49,9 +49,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.exaone import ExaoneConfig
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
@ -595,7 +595,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
if not isinstance(self.transformer.h[layer_idx], nn.Identity):
|
||||
layer_self_attn = self.transformer.h[layer_idx].attn
|
||||
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
# The scaling factor convention we are assuming is
|
||||
# quantized_value * scaling_factor ~= true_value
|
||||
# which is consistent with the practice of setting
|
||||
|
||||
@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
||||
@ -534,7 +534,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
if not isinstance(self.model.layers[layer_idx], nn.Identity):
|
||||
layer_self_attn = self.model.layers[layer_idx].self_attn
|
||||
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
# The scaling factor convention we are assuming is
|
||||
# quantized_value * scaling_factor ~= true_value
|
||||
# which is consistent with the practice of setting
|
||||
|
||||
@ -50,8 +50,8 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||
@ -423,7 +423,7 @@ class LlamaModel(nn.Module):
|
||||
if not isinstance(self.layers[layer_idx], nn.Identity):
|
||||
layer_self_attn = self.layers[layer_idx].self_attn
|
||||
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
# The scaling factor convention we are assuming is
|
||||
# quantized_value * scaling_factor ~= true_value
|
||||
# which is consistent with the practice of setting
|
||||
|
||||
@ -12,7 +12,7 @@ import cloudpickle
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_hip
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .interfaces import (has_inner_state, is_attention_free,
|
||||
supports_multimodal, supports_pp)
|
||||
@ -247,7 +247,7 @@ def _try_load_model_cls(
|
||||
model_arch: str,
|
||||
model: _BaseRegisteredModel,
|
||||
) -> Optional[Type[nn.Module]]:
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
if model_arch in _ROCM_UNSUPPORTED_MODELS:
|
||||
raise ValueError(f"Model architecture '{model_arch}' is not "
|
||||
"supported by ROCm for now.")
|
||||
|
||||
@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
@ -558,7 +558,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
if not isinstance(self.model.layers[layer_idx], nn.Identity):
|
||||
layer_self_attn = self.model.layers[layer_idx].self_attn
|
||||
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
# The scaling factor convention we are assuming is
|
||||
# quantized_value * scaling_factor ~= true_value
|
||||
# which is consistent with the practice of setting
|
||||
|
||||
@ -314,10 +314,6 @@ class PyObjectCache:
|
||||
self._index = 0
|
||||
|
||||
|
||||
def is_hip() -> bool:
|
||||
return torch.version.hip is not None
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
|
||||
"""Returns the maximum shared memory per thread block in bytes."""
|
||||
@ -1098,7 +1094,7 @@ def _cuda_device_count_stateless(
|
||||
|
||||
if not torch.cuda._is_compiled():
|
||||
return 0
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
# ROCm uses amdsmi instead of nvml for stateless device count
|
||||
# This requires a sufficiently modern version of Torch 2.4.0
|
||||
raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
|
||||
|
||||
@ -41,6 +41,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||
MultiModalInputs, MultiModalRegistry)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.prompt_adapter.layers import PromptAdapterMapping
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.prompt_adapter.worker_manager import (
|
||||
@ -49,7 +50,7 @@ from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
|
||||
flatten_2d_lists, is_hip, is_pin_memory_available,
|
||||
flatten_2d_lists, is_pin_memory_available,
|
||||
supports_dynamo, weak_ref_tensor)
|
||||
from vllm.worker.model_runner_base import (
|
||||
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
|
||||
@ -1103,7 +1104,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
self.prompt_adapter_manager.create_prompt_adapter_manager(
|
||||
self.model))
|
||||
|
||||
if self.kv_cache_dtype == "fp8" and is_hip():
|
||||
if self.kv_cache_dtype == "fp8" and current_platform.is_rocm():
|
||||
# Currently only ROCm accepts kv-cache scaling factors
|
||||
# via quantization_param_path and this will be deprecated
|
||||
# in the future.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user