[Hardware][ROCM] using current_platform.is_rocm (#9642)

Signed-off-by: wangshuai09 <391746016@qq.com>
wangshuai09 2024-10-28 12:07:00 +08:00 committed by GitHub
parent 34a9941620
commit 4e2d95e372
32 changed files with 165 additions and 151 deletions
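
The change is mechanical: every call site that previously imported is_hip from vllm.utils now imports the current_platform object from vllm.platforms and calls its is_rocm() method. A minimal sketch of the pattern, assuming vLLM and a PyTorch build that exposes the FP8 dtypes are installed:

import torch
from vllm.platforms import current_platform

# Old style, removed by this commit:
#     from vllm.utils import is_hip
#     fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn

# New style, used throughout the diff below:
fp8_dtype = torch.float8_e4m3fnuz if current_platform.is_rocm() \
    else torch.float8_e4m3fn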

@ -11,7 +11,7 @@ from unittest.mock import patch
import pytest
from vllm import LLM
from vllm.utils import is_hip
from vllm.platforms import current_platform
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from ..models.utils import check_outputs_equal
@ -51,7 +51,7 @@ def test_models(
enforce_eager: bool,
) -> None:
if backend == "FLASHINFER" and is_hip():
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")
os.environ["VLLM_ATTENTION_BACKEND"] = backend

@ -5,7 +5,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.compilation.levels import CompilationLevel
from vllm.utils import is_hip
from vllm.platforms import current_platform
TEST_MODELS = [
("facebook/opt-125m", {}),
@ -55,7 +55,7 @@ if is_quant_method_supported("marlin"):
"quantization": "marlin"
}))
if not is_hip() and is_quant_method_supported("awq"):
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
"quantization": "AWQ"
}))

@ -2,12 +2,13 @@ from typing import Optional, Tuple, Union
import torch
from vllm.utils import is_hip
from vllm.platforms import current_platform
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX = 224.0
FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm() \
else torch.float8_e4m3fn
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
@ -24,8 +25,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
else qtype_traits.max
qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
else qtype_traits.min
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)
@ -66,8 +69,10 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> Tuple[torch.tensor, torch.tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
else fp8_traits.max
fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
else fp8_traits.min
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)
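
For context on the ROCM_FP8_MAX constant above: the OCP float8_e4m3fn format used on CUDA and the float8_e4m3fnuz format used on ROCm have different representable ranges, so the reference quantization picks a platform-dependent maximum. A small illustrative snippet, assuming a PyTorch build that ships both FP8 dtypes:

import torch

# The two FP8 formats differ in their largest finite value.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0 (CUDA)
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0 (ROCm)
# The test clamps to 224.0 on ROCm rather than the full 240.0 to avoid the
# accuracy issue on dynamically quantized models noted in the comment above.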

@ -6,11 +6,12 @@ import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes, seed_everything
from .allclose_default import get_default_atol, get_default_rtol
if not is_hip():
if not current_platform.is_rocm():
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
@ -23,8 +24,9 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
] if not is_hip() else [torch.half, torch.bfloat16]
DTYPES = [
torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
@ -114,7 +116,8 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize(
"version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
"version",
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@ -317,8 +320,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
@ -368,7 +371,7 @@ def ref_multi_query_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(is_hip(),
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
@ -425,6 +428,6 @@ def test_multi_query_kv_attention(
scale,
dtype,
)
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)

@ -25,7 +25,8 @@ def test_env(name: str, device: str, monkeypatch):
False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.is_hip", return_value=True):
with patch("vllm.attention.selector.current_platform.is_rocm",
return_value=True):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"
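
The patch target above changes because the ROCm check is no longer a bare function imported into vllm.attention.selector but a method on the module-level current_platform object, so the mock is applied through the dotted path of the importing module. A hypothetical, self-contained illustration of that patching style (the _Platform class and which_backend function are stand-ins, not vLLM code):

from unittest.mock import patch


class _Platform:
    """Stand-in for vllm.platforms.current_platform."""

    def is_rocm(self) -> bool:
        return False


current_platform = _Platform()


def which_backend() -> str:
    # Mirrors the selector logic exercised by test_env.
    return "ROCM_FLASH" if current_platform.is_rocm() else "FLASH_ATTN"


# Patch the method through the module that holds the singleton, analogous to
# patching "vllm.attention.selector.current_platform.is_rocm" in the test.
with patch(f"{__name__}.current_platform.is_rocm", return_value=True):
    assert which_backend() == "ROCM_FLASH"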

@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn)
from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes, seed_everything
from .allclose_default import get_default_atol, get_default_rtol
@ -316,8 +317,8 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.

@ -18,7 +18,7 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager)
from vllm.utils import is_hip
from vllm.platforms import current_platform
# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
@ -726,7 +726,8 @@ def _run_encoder_decoder_cross_attention_test(
attn_type=attn_type)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.skipif(current_platform.is_rocm(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@ -755,7 +756,8 @@ def test_encoder_only(
No KV cache is required for encoder-only attention.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
AMD GPUs, therefore this test simply is skipped if
current_platform.is_rocm().
This test globally forces an override of the usual backend
auto-selection process, forcing the specific backend-under-test
@ -811,7 +813,8 @@ def test_encoder_only(
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.skipif(current_platform.is_rocm(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@ -864,7 +867,8 @@ def test_e2e_enc_dec_attn(
to be utilized.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
AMD GPUs, therefore this test simply is skipped if
current_platform.is_rocm().
Note on metadata: there is a single attention metadata structure shared by
all prefill-phase attention operations (encoder, decoder, enc/dec cross),

@ -18,8 +18,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import is_hip, seed_everything
from vllm.utils import seed_everything
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@ -103,7 +104,7 @@ def test_mixtral_moe(dtype: torch.dtype):
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
m: int,
n: int,
@ -256,7 +257,7 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(is_hip(), reason="Skip for rocm")
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_single_marlin_moe_multiply(
m: int,
n: int,

@ -4,7 +4,7 @@ import pytest
import vllm
from vllm.lora.request import LoRARequest
from vllm.utils import is_hip
from vllm.platforms import current_platform
MODEL_PATH = "google/gemma-7b"
@ -31,7 +31,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts
@pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm")
@pytest.mark.xfail(current_platform.is_rocm(),
reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,

@ -8,7 +8,7 @@ import pytest
import vllm
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from vllm.utils import is_hip
from vllm.platforms import current_platform
@dataclass
@ -19,7 +19,7 @@ class ModelWithQuantization:
MODELS: List[ModelWithQuantization]
#AWQ quantization is currently not supported in ROCm.
if is_hip():
if current_platform.is_rocm():
MODELS = [
ModelWithQuantization(
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",

@ -6,8 +6,9 @@ from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)
from vllm.multimodal.utils import rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_hip
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ...utils import check_logprobs_close
@ -24,7 +25,7 @@ models = ["google/paligemma-3b-mix-224"]
# ROCm Triton FA can run into compilation issues with these models due to,
# excessive use of shared memory. Use other backends in the meantime.
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
@ -151,7 +152,7 @@ def run_test(
pytest.param(
"float",
marks=pytest.mark.skipif(
is_hip(),
current_platform.is_rocm(),
reason=
"ROCm FA does not yet fully support 32-bit precision on PaliGemma")
), "half"

@ -12,7 +12,6 @@ from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from vllm.utils import is_hip
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
@ -56,7 +55,7 @@ if current_platform.is_cpu():
# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"

@ -5,7 +5,7 @@ tensor parallelism.
import pytest
import torch
from vllm.utils import is_hip
from vllm.platforms import current_platform
from .conftest import run_equality_correctness_test_tp
@ -51,7 +51,7 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify greedy equality when tensor parallelism is used.
"""
if is_hip():
if current_platform.is_rocm():
pytest.skip("hip is not well-supported yet")
run_equality_correctness_test_tp("JackFram/llama-68m",
common_llm_kwargs,

@ -26,7 +26,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import (FlexibleArgumentParser, GB_bytes,
cuda_device_count_stateless, get_open_port, is_hip)
cuda_device_count_stateless, get_open_port)
if current_platform.is_rocm():
from amdsmi import (amdsmi_get_gpu_vram_usage,
@ -487,7 +487,7 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
output: Dict[int, str] = {}
output_raw: Dict[int, float] = {}
for device in devices:
if is_hip():
if current_platform.is_rocm():
dev_handle = amdsmi_get_processor_handles()[device]
mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
gb_used = mem_info["vram_used"] / 2**10

@ -674,8 +674,8 @@ def scaled_fp8_quant(
assert (input.ndim == 2)
shape: Union[Tuple[int, int], torch.Size] = input.shape
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \
else torch.float8_e4m3fn
out_dtype: torch.dtype = torch.float8_e4m3fnuz \
if current_platform.is_rocm() else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)

@ -3,7 +3,6 @@ import math
import torch
from vllm.platforms import current_platform
from vllm.utils import is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask)
@ -32,8 +31,9 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
):
super().__init__()
if use_spda is None:
use_spda = is_hip() or current_platform.is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE
use_spda = current_platform.is_rocm() or \
current_platform.is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE
device = device or (torch.cuda.current_device()
if current_platform.is_cuda_alike() else "cpu")
device = torch.device(device)

@ -10,7 +10,7 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, is_hip
from vllm.utils import STR_BACKEND_ENV_VAR
logger = init_logger(__name__)
@ -208,7 +208,7 @@ def which_attn_to_use(
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS
if is_hip():
if current_platform.is_rocm():
# AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)

@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_hip, print_warning_once)
print_warning_once)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@ -350,7 +350,7 @@ class ModelConfig:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
if is_hip(
if current_platform.is_rocm(
) and self.quantization not in rocm_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
@ -365,7 +365,7 @@ class ModelConfig:
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.", self.quantization)
if (self.quantization == "awq" and is_hip()
if (self.quantization == "awq" and current_platform.is_rocm()
and not envs.VLLM_USE_TRITON_AWQ):
logger.warning(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
@ -843,7 +843,8 @@ class LoadConfig:
self.load_format = LoadFormat(load_format)
rocm_not_supported_load_format: List[str] = []
if is_hip() and load_format in rocm_not_supported_load_format:
if current_platform.is_rocm(
) and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [
f for f in LoadFormat.__members__
if (f not in rocm_not_supported_load_format)
@ -967,7 +968,7 @@ class ParallelConfig:
if self.use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
if is_hip():
if current_platform.is_rocm():
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "

@ -10,7 +10,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip, is_hip
from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@ -231,7 +231,7 @@ def initialize_ray_cluster(
assert_ray_available()
# Connect to a ray cluster.
if is_hip() or current_platform.is_xpu():
if current_platform.is_rocm() or current_platform.is_xpu():
ray.init(address=ray_address,
ignore_reinit_error=True,
num_gpus=parallel_config.world_size)

@ -7,7 +7,7 @@ import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once
from vllm.utils import print_warning_once
logger = init_logger(__name__)
@ -72,7 +72,7 @@ class CustomOp(nn.Module):
if not enabled:
return self.forward_native
if is_hip():
if current_platform.is_rocm():
return self.forward_hip
elif current_platform.is_cpu():
return self.forward_cpu

@ -14,7 +14,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip, print_warning_once
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
class GPTQMarlinState(Enum):
@ -150,7 +151,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz
if is_hip():
if current_platform.is_rocm():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(

@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.utils import is_hip
from vllm.platforms import current_platform
__all__ = ["CompressedTensorsW8A8Fp8"]
@ -40,7 +40,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
logical_widths=layer.logical_widths,
)
if is_hip():
if current_platform.is_rocm():
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=max_w_scale,
@ -56,7 +56,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight
if is_hip():
if current_platform.is_rocm():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,

@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.platforms import current_platform
from vllm.utils import is_hip
logger = init_logger(__name__)
@ -127,7 +126,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
weight = layer.weight
if is_hip():
if current_platform.is_rocm():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,

@ -26,7 +26,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once
from vllm.utils import print_warning_once
ACTIVATION_SCHEMES = ["static", "dynamic"]
@ -123,7 +123,7 @@ class Fp8LinearMethod(LinearMethodBase):
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if is_hip():
if current_platform.is_rocm():
self.use_marlin = False
def create_weights(
@ -226,7 +226,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_scale = layer.weight_scale
# If rocm, use float8_e4m3fnuz.
if is_hip():
if current_platform.is_rocm():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
@ -372,7 +372,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz \
if is_hip() else torch.float8_e4m3fn
if current_platform.is_rocm() else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@ -420,7 +420,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz
if is_hip():
if current_platform.is_rocm():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(

@ -4,16 +4,16 @@ import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import is_hip
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \
if current_platform.is_rocm() else None
def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm
if is_hip():
if current_platform.is_rocm():
return False
capability_tuple = current_platform.get_device_capability()

@ -49,9 +49,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
@ -595,7 +595,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if not isinstance(self.transformer.h[layer_idx], nn.Identity):
layer_self_attn = self.transformer.h[layer_idx].attn
if is_hip():
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting

@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
@ -534,7 +534,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip():
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting

@ -50,8 +50,8 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
@ -423,7 +423,7 @@ class LlamaModel(nn.Module):
if not isinstance(self.layers[layer_idx], nn.Identity):
layer_self_attn = self.layers[layer_idx].self_attn
if is_hip():
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting

@ -12,7 +12,7 @@ import cloudpickle
import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_hip
from vllm.platforms import current_platform
from .interfaces import (has_inner_state, is_attention_free,
supports_multimodal, supports_pp)
@ -247,7 +247,7 @@ def _try_load_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
if is_hip():
if current_platform.is_rocm():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(f"Model architecture '{model_arch}' is not "
"supported by ROCm for now.")

@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
@ -558,7 +558,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip():
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting

@ -314,10 +314,6 @@ class PyObjectCache:
self._index = 0
def is_hip() -> bool:
return torch.version.hip is not None
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
@ -1098,7 +1094,7 @@ def _cuda_device_count_stateless(
if not torch.cuda._is_compiled():
return 0
if is_hip():
if current_platform.is_rocm():
# ROCm uses amdsmi instead of nvml for stateless device count
# This requires a sufficiently modern version of Torch 2.4.0
raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
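
The helper deleted above simply checked torch.version.hip; after this change the same question is answered by the current_platform singleton from vllm.platforms. A hypothetical sketch of the idea, not the actual vllm.platforms implementation:

import torch


class _RocmAwarePlatform:
    """Illustrative stand-in; vllm.platforms constructs a concrete platform."""

    def is_rocm(self) -> bool:
        # Same signal the removed vllm.utils.is_hip() relied on.
        return torch.version.hip is not None


current_platform = _RocmAwarePlatform()
print(current_platform.is_rocm())  # True on a ROCm/HIP build of PyTorch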

@ -41,6 +41,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalRegistry)
from vllm.platforms import current_platform
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
@ -49,7 +50,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
flatten_2d_lists, is_hip, is_pin_memory_available,
flatten_2d_lists, is_pin_memory_available,
supports_dynamo, weak_ref_tensor)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@ -1103,7 +1104,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model))
if self.kv_cache_dtype == "fp8" and is_hip():
if self.kv_cache_dtype == "fp8" and current_platform.is_rocm():
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.