[Hardware][ROCM] using current_platform.is_rocm (#9642)

Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
wangshuai09 2024-10-28 12:07:00 +08:00 committed by GitHub
parent 34a9941620
commit 4e2d95e372
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 165 additions and 151 deletions

View File

@ -11,7 +11,7 @@ from unittest.mock import patch
import pytest
from vllm import LLM
from vllm.utils import is_hip
from vllm.platforms import current_platform
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from ..models.utils import check_outputs_equal
@ -51,7 +51,7 @@ def test_models(
enforce_eager: bool,
) -> None:
if backend == "FLASHINFER" and is_hip():
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")
os.environ["VLLM_ATTENTION_BACKEND"] = backend

View File

@ -5,7 +5,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.compilation.levels import CompilationLevel
from vllm.utils import is_hip
from vllm.platforms import current_platform
TEST_MODELS = [
("facebook/opt-125m", {}),
@ -55,7 +55,7 @@ if is_quant_method_supported("marlin"):
"quantization": "marlin"
}))
if not is_hip() and is_quant_method_supported("awq"):
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
"quantization": "AWQ"
}))

View File

@ -2,12 +2,13 @@ from typing import Optional, Tuple, Union
import torch
from vllm.utils import is_hip
from vllm.platforms import current_platform
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX = 224.0
FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm() \
else torch.float8_e4m3fn
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
@ -24,8 +25,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
else qtype_traits.min
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)
@ -66,8 +69,10 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> Tuple[torch.tensor, torch.tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
else fp8_traits.min
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)

View File

@ -6,11 +6,12 @@ import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes, seed_everything
from .allclose_default import get_default_atol, get_default_rtol
if not is_hip():
if not current_platform.is_rocm():
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
@ -23,8 +24,9 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
] if not is_hip() else [torch.half, torch.bfloat16]
DTYPES = [
torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
@ -114,7 +116,8 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize(
"version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
"version",
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@ -317,8 +320,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
@ -368,7 +371,7 @@ def ref_multi_query_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(is_hip(),
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
@ -425,6 +428,6 @@ def test_multi_query_kv_attention(
scale,
dtype,
)
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)

View File

@ -25,7 +25,8 @@ def test_env(name: str, device: str, monkeypatch):
False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.is_hip", return_value=True):
with patch("vllm.attention.selector.current_platform.is_rocm",
return_value=True):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"

View File

@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn)
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes, seed_everything
from .allclose_default import get_default_atol, get_default_rtol
@ -316,8 +317,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.

View File

@ -18,7 +18,7 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager)
from vllm.utils import is_hip
from vllm.platforms import current_platform
# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
@ -82,7 +82,7 @@ class TestResources(NamedTuple):
will leverage attn_backend for the purpose of
constructing backend-compatible attention
metadata instances
Attributes:
* scale: 1/sqrt(d) scale factor for attn
@ -105,10 +105,10 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
Build key components for performing encoder/decoder attention test.
Note that
(1) The Attention instance constructed here, automatically selects
(1) The Attention instance constructed here, automatically selects
an attention backend class based on platform info & a set of canned
heuristics, so
(2) The attention backend instance constructed here is thus *not
(2) The attention backend instance constructed here is thus *not
the same backend instance* used by attn, but rather it is
intended to be a *different instance* of the *same backend class*;
therefore,
@ -156,7 +156,7 @@ def _encoder_attn_setup(
'''
Set up test vectors & data structures for encoder attention test.
A triplet of synthetic query/key/value tensors are constructed.
A triplet of synthetic query/key/value tensors are constructed.
Given this is an encoder attention test, the key & value
sequences will have the same length as the corresponding queries.
@ -169,14 +169,14 @@ def _encoder_attn_setup(
Arguments:
* test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size,
following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the
scale field
Returns:
* PhaseTestParameters data structure comprising (1) packed query/key/value
tensors, (2) the ideal output of attention computed using a naive
implementation, and (3) KVCache field set to None
@ -265,7 +265,7 @@ def _decoder_attn_setup(
Arguments:
* test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size,
following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the
scale field
@ -275,14 +275,14 @@ def _decoder_attn_setup(
* qkv: Unpacked (batch_size x padded_seq_len x num_heads x
head_size) query/key/value tensors
* Prefill-phase decoder self-attention PhaseTestParameters data structure,
including (1) packed (number_of_tokens x num_heads x head_size)
including (1) packed (number_of_tokens x num_heads x head_size)
query/key/value tensors along with (2) ideal attention output
computed using a naive implementation, and (3) memory-mapping data
computed using a naive implementation, and (3) memory-mapping data
structures appropriate for prefill phase.
* Decode-phase decoder self-attention PhaseTestParameters data structure,
including (1) packed (number_of_tokens x num_heads x head_size)
query/key/value tensors along with (2) ideal attention output
computed using a naive implementation, and (3) memory-mapping data
* Decode-phase decoder self-attention PhaseTestParameters data structure,
including (1) packed (number_of_tokens x num_heads x head_size)
query/key/value tensors along with (2) ideal attention output
computed using a naive implementation, and (3) memory-mapping data
structures appropriate for decode phase.
* max_block_idx: max physical address in decoder self-attention block-table
(intended to be used as the base address for the encoder/
@ -436,12 +436,12 @@ def _enc_dec_cross_attn_setup_reuses_query(
This function also constructs the cross-attention KV cache memory mapping
(slot mapping and block table), ensuring that the block table starts at
block_base_addr.
block_base_addr.
Arguments:
* decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x
num_heads x head_size) decoder self-attention inputs;
num_heads x head_size) decoder self-attention inputs;
this function relies on the query and q_seq_lens
fields
* encoder_test_params: PhaseTestParameters data structure which was
@ -452,7 +452,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
self-attention; all fields
including KV cache required
* test_pt: TestPoint data structure; this function relies on the
following fields: batch_size, num_heads, head_size,
following fields: batch_size, num_heads, head_size,
block_size, max_q_seq_len
* test_rsrcs: TestResources data structure; this function relies on the
scale field
@ -460,16 +460,16 @@ def _enc_dec_cross_attn_setup_reuses_query(
Returns:
* Prefill-phase encoder/decoder cross-attention PhaseTestParameters data
structure, including (1) packed
(number_of_tokens x num_heads x head_size) query/key/value tensors
along with (2) ideal attention output computed using a
naive implementation, and (3) memory-mapping data structures appropriate
for prefill phase.
* Decode-phase encoder/decoder cross-attention PhaseTestParameters data
* Prefill-phase encoder/decoder cross-attention PhaseTestParameters data
structure, including (1) packed
(number_of_tokens x num_heads x head_size) query/key/value tensors
along with (2) ideal attention output computed using a
along with (2) ideal attention output computed using a
naive implementation, and (3) memory-mapping data structures appropriate
for prefill phase.
* Decode-phase encoder/decoder cross-attention PhaseTestParameters data
structure, including (1) packed
(number_of_tokens x num_heads x head_size) query/key/value tensors
along with (2) ideal attention output computed using a
naive implementation, and (3) memory-mapping data structures appropriate
for decode phase.
'''
@ -596,7 +596,7 @@ def _run_encoder_attention_test(
'''
Run encoder attention.
attn.forward() is passed attn_type=AttentionType.ENCODER in order
attn.forward() is passed attn_type=AttentionType.ENCODER in order
to configure the kernel invocation for encoder attention
Requires attn_metadata.num_decode_tokens == 0
@ -607,7 +607,7 @@ def _run_encoder_attention_test(
* attn: Attention wrapper instance
* encoder_test_params: encoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
(number_of_tokens x num_heads x head_size)
query/key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
@ -646,7 +646,7 @@ def _run_decoder_self_attention_test(
and attn (Attention wrapper instance) fields
* decoder_test_params: decoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
(number_of_tokens x num_heads x head_size)
query/key/value fields
* attn_metadata: attention metadata for decoder-self attention
(contains KV cache memory-mapping)
@ -694,11 +694,11 @@ def _run_encoder_decoder_cross_attention_test(
and attn (Attention wrapper instance) fields
* decoder_test_params: decoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
(number_of_tokens x num_heads x head_size)
query field
* cross_test_params: encoder/decoder PhaseTestParameters data structure;
this function relies on the packed
(number_of_tokens x num_heads x head_size)
(number_of_tokens x num_heads x head_size)
key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
@ -726,7 +726,8 @@ def _run_encoder_decoder_cross_attention_test(
attn_type=attn_type)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.skipif(current_platform.is_rocm(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@ -755,7 +756,8 @@ def test_encoder_only(
No KV cache is required for encoder-only attention.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
AMD GPUs, therefore this test simply is skipped if
current_platform.is_rocm().
This test globally forces an override of the usual backend
auto-selection process, forcing the specific backend-under-test
@ -811,7 +813,8 @@ def test_encoder_only(
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.skipif(current_platform.is_rocm(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@ -837,14 +840,14 @@ def test_e2e_enc_dec_attn(
attributes for prefill-phase, and (2) an analogous attention metadata
structure but for decode-phase
* Test attention steps in the following order
* Encoder attention
* Prefill self-attention
* Prefill cross-attention
* Decode self-attention
* Decode cross-attention
* Besides being reflective of realistic use-cases, this order would
exacerbate any accidental overlap in the self-/cross-attention
* Besides being reflective of realistic use-cases, this order would
exacerbate any accidental overlap in the self-/cross-attention
block tables, which one hopes to avoid
@ -864,10 +867,11 @@ def test_e2e_enc_dec_attn(
to be utilized.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
AMD GPUs, therefore this test simply is skipped if
current_platform.is_rocm().
Note on metadata: there is a single attention metadata structure shared by
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
and a single one shared by all decode-phase attention operations
(decoder & enc/dec cross.) This is intended to reflect the behavior
of EncoderDecoderModelRunner, which constructs a single attention metadata

View File

@ -18,8 +18,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import is_hip, seed_everything
from vllm.utils import seed_everything
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@ -103,7 +104,7 @@ def test_mixtral_moe(dtype: torch.dtype):
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
m: int,
n: int,
@ -256,7 +257,7 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_single_marlin_moe_multiply(
m: int,
n: int,

View File

@ -4,7 +4,7 @@ import pytest
import vllm
from vllm.lora.request import LoRARequest
from vllm.utils import is_hip
from vllm.platforms import current_platform
MODEL_PATH = "google/gemma-7b"
@ -31,7 +31,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts
@pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm")
@pytest.mark.xfail(current_platform.is_rocm(),
reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,

View File

@ -8,7 +8,7 @@ import pytest
import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from vllm.utils import is_hip
from vllm.platforms import current_platform
@dataclass
@ -19,7 +19,7 @@ class ModelWithQuantization:
MODELS: List[ModelWithQuantization]
#AWQ quantization is currently not supported in ROCm.
if is_hip():
if current_platform.is_rocm():
MODELS = [
ModelWithQuantization(
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",

View File

@ -6,8 +6,9 @@ from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)
from vllm.multimodal.utils import rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_hip
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ...utils import check_logprobs_close
@ -24,7 +25,7 @@ models = ["google/paligemma-3b-mix-224"]
# ROCm Triton FA can run into compilation issues with these models due to,
# excessive use of shared memory. Use other backends in the meantime.
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
@ -70,7 +71,7 @@ def run_test(
All the image fixtures for the test are from IMAGE_ASSETS.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
@ -151,7 +152,7 @@ def run_test(
pytest.param(
"float",
marks=pytest.mark.skipif(
is_hip(),
current_platform.is_rocm(),
reason=
"ROCm FA does not yet fully support 32-bit precision on PaliGemma")
), "half"

View File

@ -12,7 +12,6 @@ from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from vllm.utils import is_hip
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
@ -56,7 +55,7 @@ if current_platform.is_cpu():
# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"

View File

@ -5,7 +5,7 @@ tensor parallelism.
import pytest
import torch
from vllm.utils import is_hip
from vllm.platforms import current_platform
from .conftest import run_equality_correctness_test_tp
@ -51,7 +51,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify greedy equality when tensor parallelism is used.
"""
if is_hip():
if current_platform.is_rocm():
pytest.skip("hip is not well-supported yet")
run_equality_correctness_test_tp("JackFram/llama-68m",
common_llm_kwargs,

View File

@ -26,7 +26,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import (FlexibleArgumentParser, GB_bytes,
cuda_device_count_stateless, get_open_port, is_hip)
cuda_device_count_stateless, get_open_port)
if current_platform.is_rocm():
from amdsmi import (amdsmi_get_gpu_vram_usage,
@ -487,7 +487,7 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
output: Dict[int, str] = {}
output_raw: Dict[int, float] = {}
for device in devices:
if is_hip():
if current_platform.is_rocm():
dev_handle = amdsmi_get_processor_handles()[device]
mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
gb_used = mem_info["vram_used"] / 2**10

View File

@ -659,11 +659,11 @@ def scaled_fp8_quant(
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
@ -674,8 +674,8 @@ def scaled_fp8_quant(
assert (input.ndim == 2)
shape: Union[Tuple[int, int], torch.Size] = input.shape
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \
else torch.float8_e4m3fn
out_dtype: torch.dtype = torch.float8_e4m3fnuz \
if current_platform.is_rocm() else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)

View File

@ -3,7 +3,6 @@ import math
import torch
from vllm.platforms import current_platform
from vllm.utils import is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask)
@ -32,8 +31,9 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
):
super().__init__()
if use_spda is None:
use_spda = is_hip() or current_platform.is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE
use_spda = current_platform.is_rocm() or \
current_platform.is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE
device = device or (torch.cuda.current_device()
if current_platform.is_cuda_alike() else "cpu")
device = torch.device(device)

View File

@ -10,7 +10,7 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, is_hip
from vllm.utils import STR_BACKEND_ENV_VAR
logger = init_logger(__name__)
@ -208,7 +208,7 @@ def which_attn_to_use(
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS
if is_hip():
if current_platform.is_rocm():
# AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)

View File

@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_hip, print_warning_once)
print_warning_once)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@ -43,7 +43,7 @@ class ModelConfig:
Args:
model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics
It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified.
task: The task to use the model for. Each vLLM instance only supports
one task, even if the same model can be used for multiple tasks.
@ -99,15 +99,15 @@ class ModelConfig:
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
served_model_name: The model name used in metrics tag `model_name`,
matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified,
matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified,
the model name will be the same as `model`.
limit_mm_per_prompt: Maximum number of data instances per modality
limit_mm_per_prompt: Maximum number of data instances per modality
per prompt. Only applicable for multimodal models.
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments.
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
@ -350,7 +350,7 @@ class ModelConfig:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
if is_hip(
if current_platform.is_rocm(
) and self.quantization not in rocm_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
@ -365,7 +365,7 @@ class ModelConfig:
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.", self.quantization)
if (self.quantization == "awq" and is_hip()
if (self.quantization == "awq" and current_platform.is_rocm()
and not envs.VLLM_USE_TRITON_AWQ):
logger.warning(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
@ -385,7 +385,7 @@ class ModelConfig:
def _verify_bnb_config(self) -> None:
"""
The current version of bitsandbytes (0.44.0) with 8-bit models does not
The current version of bitsandbytes (0.44.0) with 8-bit models does not
yet support CUDA graph.
"""
is_bitsandbytes = self.quantization == "bitsandbytes"
@ -810,7 +810,7 @@ class LoadConfig:
fast weight loading.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
@ -843,7 +843,8 @@ class LoadConfig:
self.load_format = LoadFormat(load_format)
rocm_not_supported_load_format: List[str] = []
if is_hip() and load_format in rocm_not_supported_load_format:
if current_platform.is_rocm(
) and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [
f for f in LoadFormat.__members__
if (f not in rocm_not_supported_load_format)
@ -967,7 +968,7 @@ class ParallelConfig:
if self.use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
if is_hip():
if current_platform.is_rocm():
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
@ -996,7 +997,7 @@ class SchedulerConfig:
prompt latency) before scheduling next prompt.
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
preemption_mode: Whether to perform preemption by swapping or
preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
swapping. However, when the sequence group has multiple sequences
@ -1215,7 +1216,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
@ -1225,7 +1226,7 @@ class SpeculativeConfig:
If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams.
If not specified, it defaults to True.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None.
@ -1470,13 +1471,13 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
disable_logprobs: If set to True, token log probabilities will not
be returned even if requested by sampling parameters. This
be returned even if requested by sampling parameters. This
reduces latency by skipping logprob calculation in proposal
sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be
@ -1843,10 +1844,10 @@ def get_min_sliding_window(
def get_served_model_name(model: str,
served_model_name: Optional[Union[str, List[str]]]):
"""
If the input is a non-empty list, the first model_name in
`served_model_name` is taken.
If the input is a non-empty string, it is used directly.
For cases where the input is either an empty string or an
If the input is a non-empty list, the first model_name in
`served_model_name` is taken.
If the input is a non-empty string, it is used directly.
For cases where the input is either an empty string or an
empty list, the fallback is to use `self.model`.
"""
if not served_model_name:

View File

@ -10,7 +10,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip, is_hip
from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@ -231,7 +231,7 @@ def initialize_ray_cluster(
assert_ray_available()
# Connect to a ray cluster.
if is_hip() or current_platform.is_xpu():
if current_platform.is_rocm() or current_platform.is_xpu():
ray.init(address=ray_address,
ignore_reinit_error=True,
num_gpus=parallel_config.world_size)

View File

@ -7,7 +7,7 @@ import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once
from vllm.utils import print_warning_once
logger = init_logger(__name__)
@ -72,7 +72,7 @@ class CustomOp(nn.Module):
if not enabled:
return self.forward_native
if is_hip():
if current_platform.is_rocm():
return self.forward_hip
elif current_platform.is_cpu():
return self.forward_cpu

View File

@ -14,7 +14,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip, print_warning_once
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
class GPTQMarlinState(Enum):
@ -150,7 +151,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz
if is_hip():
if current_platform.is_rocm():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(

View File

@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.utils import is_hip
from vllm.platforms import current_platform
__all__ = ["CompressedTensorsW8A8Fp8"]
@ -40,7 +40,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
logical_widths=layer.logical_widths,
)
if is_hip():
if current_platform.is_rocm():
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=max_w_scale,
@ -56,7 +56,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight
if is_hip():
if current_platform.is_rocm():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,

View File

@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.platforms import current_platform
from vllm.utils import is_hip
logger = init_logger(__name__)
@ -127,7 +126,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
weight = layer.weight
if is_hip():
if current_platform.is_rocm():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,

View File

@ -26,7 +26,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once
from vllm.utils import print_warning_once
ACTIVATION_SCHEMES = ["static", "dynamic"]
@ -123,7 +123,7 @@ class Fp8LinearMethod(LinearMethodBase):
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if is_hip():
if current_platform.is_rocm():
self.use_marlin = False
def create_weights(
@ -226,7 +226,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_scale = layer.weight_scale
# If rocm, use float8_e4m3fnuz.
if is_hip():
if current_platform.is_rocm():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
@ -372,7 +372,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz \
if is_hip() else torch.float8_e4m3fn
if current_platform.is_rocm() else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@ -420,7 +420,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz
if is_hip():
if current_platform.is_rocm():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(

View File

@ -4,16 +4,16 @@ import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import is_hip
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \
if current_platform.is_rocm() else None
def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm
if is_hip():
if current_platform.is_rocm():
return False
capability_tuple = current_platform.get_device_capability()

View File

@ -49,9 +49,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
@ -595,7 +595,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if not isinstance(self.transformer.h[layer_idx], nn.Identity):
layer_self_attn = self.transformer.h[layer_idx].attn
if is_hip():
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting

View File

@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
@ -534,7 +534,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip():
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting

View File

@ -50,8 +50,8 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
@ -423,7 +423,7 @@ class LlamaModel(nn.Module):
if not isinstance(self.layers[layer_idx], nn.Identity):
layer_self_attn = self.layers[layer_idx].self_attn
if is_hip():
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting

View File

@ -12,7 +12,7 @@ import cloudpickle
import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_hip
from vllm.platforms import current_platform
from .interfaces import (has_inner_state, is_attention_free,
supports_multimodal, supports_pp)
@ -247,7 +247,7 @@ def _try_load_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
if is_hip():
if current_platform.is_rocm():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(f"Model architecture '{model_arch}' is not "
"supported by ROCm for now.")

View File

@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
@ -558,7 +558,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip():
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting

View File

@ -314,10 +314,6 @@ class PyObjectCache:
self._index = 0
def is_hip() -> bool:
return torch.version.hip is not None
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
@ -1098,7 +1094,7 @@ def _cuda_device_count_stateless(
if not torch.cuda._is_compiled():
return 0
if is_hip():
if current_platform.is_rocm():
# ROCm uses amdsmi instead of nvml for stateless device count
# This requires a sufficiently modern version of Torch 2.4.0
raw_count = torch.cuda._device_count_amdsmi() if (hasattr(

View File

@ -41,6 +41,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalRegistry)
from vllm.platforms import current_platform
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
@ -49,7 +50,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
flatten_2d_lists, is_hip, is_pin_memory_available,
flatten_2d_lists, is_pin_memory_available,
supports_dynamo, weak_ref_tensor)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@ -737,13 +738,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
family of functions.
Args:
num_seqs (int): Number of sequences scheduled to run.
num_seqs (int): Number of sequences scheduled to run.
max_decode_seq_len (int): Greatest of all the decode sequence
lengths. Used only in checking the viablility of using
CUDA graphs.
max_encoder_seq_len (int, optional): Greatest of all the encode
sequence lengths. Defaults to 0. Used only in checking the
viability of using CUDA graphs.
viability of using CUDA graphs.
Returns:
int: Returns the determined number of padding sequences. If
CUDA graphs is not viable, returns -1.
@ -1103,7 +1104,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model))
if self.kv_cache_dtype == "fp8" and is_hip():
if self.kv_cache_dtype == "fp8" and current_platform.is_rocm():
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.