[CI/Build] Avoid CUDA initialization (#8534)

Cyrus Leung 2024-09-18 18:38:11 +08:00 committed by GitHub
parent e351572900
commit 6ffa3f314c
GPG Key ID: B5690EEEBB952194
55 changed files with 256 additions and 256 deletions
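This commit replaces the scattered per-call-site seeding (random.seed / torch.manual_seed / torch.cuda.manual_seed) with a single seed_everything helper from vllm.utils, and replaces get_device_capability() index arithmetic with DeviceCapability, has_device_capability and is_cuda_alike on current_platform, so that importing or seeding these modules no longer initializes CUDA eagerly. The helper itself lives in vllm/utils.py, which is not part of the excerpt below; a minimal sketch of the shape it presumably takes (an assumption, not the verbatim implementation):

import random

import numpy as np
import torch

from vllm.platforms import current_platform


def seed_everything(seed: int) -> None:
    """Seed every RNG the code paths may touch, without forcing CUDA init."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Only touch the CUDA RNG when the platform is CUDA/ROCm-like, so that
    # seeding on a CPU-only host does not create a CUDA context (assumed gating).
    if current_platform.is_cuda_alike():
        torch.cuda.manual_seed_all(seed)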

View File

@@ -1,10 +1,10 @@
-import random
 import time
 import torch
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
+                        seed_everything)
 @torch.inference_mode()
@@ -16,10 +16,7 @@ def main(num_tokens: int,
          do_profile: bool = False,
          num_warmup_iters: int = 5,
          num_iters: int = 100) -> None:
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device("cuda")
     layer = RMSNorm(hidden_size).to(dtype=dtype)

View File

@@ -10,7 +10,7 @@ from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import *
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, seed_everything
 class BenchmarkConfig(TypedDict):
@@ -166,7 +166,7 @@ class BenchmarkWorker:
     def __init__(self, seed: int) -> None:
         torch.set_default_device("cuda")
-        torch.cuda.manual_seed_all(seed)
+        seed_everything(seed)
         self.seed = seed
     def benchmark(
@@ -180,7 +180,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
    ) -> Tuple[Dict[str, int], float]:
-        torch.cuda.manual_seed_all(self.seed)
+        seed_everything(self.seed)
         dtype_str = get_config_dtype_str(dtype,
                                          use_int8_w8a16=use_int8_w8a16,
                                          use_fp8_w8a8=use_fp8_w8a8)

View File

@@ -6,7 +6,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
-                        create_kv_caches_with_random)
+                        create_kv_caches_with_random, seed_everything)
 NUM_BLOCKS = 1024
 PARTITION_SIZE = 512
@@ -28,10 +28,7 @@ def main(
     device: str = "cuda",
     kv_cache_dtype: Optional[str] = None,
 ) -> None:
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     scale = float(1.0 / (head_size**0.5))
     query = torch.empty(num_seqs,

View File

@@ -1,10 +1,10 @@
-import random
 import time
 import torch
 from vllm import _custom_ops as ops
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
+                        seed_everything)
 @torch.inference_mode()
@@ -17,10 +17,7 @@ def main(num_tokens: int,
          do_profile: bool = False,
          num_warmup_iters: int = 5,
          num_iters: int = 100) -> None:
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device("cuda")
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)

View File

@@ -6,7 +6,7 @@ import torch
 from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
                                                          get_rope)
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, seed_everything
 def benchmark_rope_kernels_multi_lora(
@@ -22,9 +22,7 @@ def benchmark_rope_kernels_multi_lora(
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     if rotary_dim is None:
         rotary_dim = head_size

View File

@@ -7,6 +7,7 @@ from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
                                                    NewGELU, QuickGELU,
                                                    SiluAndMul)
+from vllm.utils import seed_everything
 from .allclose_default import get_default_atol, get_default_rtol
@@ -34,9 +35,7 @@ def test_act_and_mul(
     seed: int,
     device: str,
 ) -> None:
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
     if activation == "silu":
@@ -77,9 +76,7 @@ def test_activation(
     seed: int,
     device: str,
 ) -> None:
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, d, dtype=dtype)
     layer = activation[0]()

View File

@@ -6,7 +6,7 @@ import torch
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
-from vllm.utils import get_max_shared_memory_bytes, is_hip
+from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
 from .allclose_default import get_default_atol, get_default_rtol
@@ -139,10 +139,8 @@ def test_paged_attention(
 ) -> None:
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
@@ -354,10 +352,7 @@ def test_paged_attention_rocm(
     seed: int,
     device: str,
 ) -> None:
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
@@ -506,10 +501,7 @@ def test_multi_query_kv_attention(
     seed: int,
     device: str,
 ) -> None:
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
     # As the xformers library is already tested with its own tests, we can use

View File

@@ -45,7 +45,7 @@ def test_flash_attn(monkeypatch):
     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
     # Unsupported CUDA arch
-    with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
+    with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
         backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
         assert backend.name != STR_FLASH_ATTN_VAL

View File

@@ -7,6 +7,7 @@ import torch
 from vllm.model_executor.layers.quantization.awq_triton import (
     AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
+from vllm.utils import seed_everything
 device = "cuda"
@@ -79,7 +80,7 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
     zeros_cols = qweight_cols
     zeros_dtype = torch.int32
-    torch.manual_seed(0)
+    seed_everything(0)
     qweight = torch.randint(0,
                             torch.iinfo(torch.int32).max,
@@ -133,7 +134,7 @@ def test_gemm(N, K, M, splitK, group_size):
     qzeros_rows = scales_rows
     qzeros_cols = qweight_cols
-    torch.manual_seed(0)
+    seed_everything(0)
     input = torch.rand((input_rows, input_cols),
                        dtype=input_dtype,

View File

@@ -7,7 +7,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.attention.ops.blocksparse_attention.interface import (
     LocalStridedBlockSparseAttn)
-from vllm.utils import get_max_shared_memory_bytes, is_hip
+from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
 from .allclose_default import get_default_atol, get_default_rtol
@@ -172,10 +172,7 @@ def test_paged_attention(
     blocksparse_block_size: int,
     blocksparse_head_sliding_step: int,
 ) -> None:
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
@@ -386,10 +383,7 @@ def test_varlen_blocksparse_attention_prefill(
     seed: int,
     device: str,
 ) -> None:
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
     # As the xformers library is already tested with its own tests, we can use

View File

@@ -6,6 +6,7 @@ import torch
 from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
 from vllm import _custom_ops as ops
+from vllm.utils import seed_everything
 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -55,10 +56,7 @@ def test_copy_blocks(
 ) -> None:
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     # Generate random block mappings where each source block is mapped to two
     # destination blocks.
@@ -134,10 +132,7 @@ def test_reshape_and_cache(
 ) -> None:
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
@@ -229,9 +224,7 @@ def test_reshape_and_cache_flash(
     device: str,
     kv_cache_dtype: str,
 ) -> None:
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     # Create a random slot mapping.
@@ -345,10 +338,8 @@ def test_swap_blocks(
         pytest.skip()
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     src_device = device if direction[0] == "cuda" else 'cpu'
     dst_device = device if direction[1] == "cuda" else 'cpu'
@@ -417,9 +408,7 @@ def test_fp8_e4m3_conversion(
     seed: int,
     device: str,
 ) -> None:
-    random.seed(seed)
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     low = -224.0
     high = 224.0

View File

@@ -7,6 +7,7 @@ from einops import rearrange
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
+from vllm.utils import seed_everything
 def causal_conv1d_ref(
@@ -104,7 +105,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
     if itype == torch.bfloat16:
         rtol, atol = 1e-2, 5e-2
     # set seed
-    torch.random.manual_seed(0)
+    seed_everything(0)
     if not channel_last:
         x = torch.randn(batch,
                         4096 + dim + 64,
@@ -175,7 +176,7 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
     if itype == torch.bfloat16:
         rtol, atol = 1e-2, 5e-2
     # set seed
-    torch.random.manual_seed(0)
+    seed_everything(0)
     batch = 2
     x = torch.randn(batch, dim, device=device, dtype=itype)
     conv_state = torch.randn(batch, dim, width, device=device, dtype=itype)

View File

@@ -15,9 +15,6 @@ CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
-capability = current_platform.get_device_capability()
-capability = capability[0] * 10 + capability[1]
 def to_fp8(tensor: torch.Tensor):
     finfo = torch.finfo(torch.float8_e4m3fn)
@@ -119,7 +116,7 @@ def cutlass_int8_gemm_helper(m: int,
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("use_bias", [True, False])
-@pytest.mark.skipif(capability < 89,
+@pytest.mark.skipif(not current_platform.has_device_capability(89),
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
                           per_out_ch: bool, use_bias: bool):
@@ -157,7 +154,7 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("use_bias", [True, False])
-@pytest.mark.skipif(capability < 89,
+@pytest.mark.skipif(not current_platform.has_device_capability(89),
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                        out_dtype: Type[torch.dtype],
@@ -175,7 +172,7 @@ def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(capability < 89,
+@pytest.mark.skipif(not current_platform.has_device_capability(89),
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                                   use_bias: bool, device: str):
@@ -207,7 +204,7 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("use_bias", [True, False])
-@pytest.mark.skipif(capability < 89,
+@pytest.mark.skipif(not current_platform.has_device_capability(89),
                     reason="FP8 is not supported on this GPU type.")
 def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                   use_bias: bool):

View File

@@ -4,6 +4,7 @@ import pytest
 import torch
 import vllm.attention.backends.flash_attn  # noqa: F401
+from vllm.utils import seed_everything
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
@@ -87,7 +88,7 @@ def test_flash_attn_with_paged_kv(
     num_blocks: int,
 ) -> None:
     torch.set_default_device("cuda")
-    torch.cuda.manual_seed_all(0)
+    seed_everything(0)
     num_seqs = len(kv_lens)
     num_query_heads = num_heads[0]
     num_kv_heads = num_heads[1]
@@ -174,7 +175,7 @@ def test_varlen_with_paged_kv(
     num_blocks: int,
 ) -> None:
     torch.set_default_device("cuda")
-    torch.cuda.manual_seed_all(0)
+    seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
     kv_lens = [x[1] for x in seq_lens]

View File

@@ -4,6 +4,8 @@ import flashinfer
 import pytest
 import torch
+from vllm.utils import seed_everything
 NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
 HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]
@@ -82,7 +84,7 @@ def test_flashinfer_decode_with_paged_kv(
     soft_cap: Optional[float],
 ) -> None:
     torch.set_default_device("cuda")
-    torch.cuda.manual_seed_all(0)
+    seed_everything(0)
     num_seqs = len(kv_lens)
     num_query_heads = num_heads[0]
     num_kv_heads = num_heads[1]
@@ -168,7 +170,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
                                          block_size: int,
                                          soft_cap: Optional[float]) -> None:
     torch.set_default_device("cuda")
-    torch.cuda.manual_seed_all(0)
+    seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
     kv_lens = [x[1] for x in seq_lens]
@@ -266,7 +268,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
         head_size: int, dtype: torch.dtype, block_size: int,
         soft_cap: Optional[float]) -> None:
     torch.set_default_device("cuda")
-    torch.cuda.manual_seed_all(0)
+    seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
     kv_lens = [x[1] for x in seq_lens]
@@ -379,7 +381,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
 ) -> None:
     # test doesn't work for num_heads = (16,16)
     torch.set_default_device("cuda")
-    torch.cuda.manual_seed_all(0)
+    seed_everything(0)
     num_seqs = len(kv_lens)
     num_query_heads = num_heads[0]
     num_kv_heads = num_heads[1]

View File

@@ -5,6 +5,7 @@ import vllm._custom_ops as ops
 from tests.kernels.quant_utils import (FP8_DTYPE,
                                        ref_dynamic_per_tensor_fp8_quant,
                                        ref_dynamic_per_token_quant)
+from vllm.utils import seed_everything
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
@@ -24,8 +25,7 @@ SEEDS = [0]
 def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
                                      dtype: torch.dtype, scale_ub: bool,
                                      seed: int) -> None:
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     x = torch.rand(num_tokens, hidden_size, dtype=dtype,
                    device="cuda") + 1e-6  # avoid nans
@@ -49,8 +49,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
 @torch.inference_mode()
 def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
                                       dtype: torch.dtype, seed: int) -> None:
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
@@ -67,8 +66,7 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
 @torch.inference_mode()
 @pytest.mark.parametrize("seed", SEEDS)
 def test_fp8_quant_large(seed: int) -> None:
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     num_tokens = 1024000  # Mistral-Nemo's max_position_embeddings
     hidden_size = 1152  # Smallest hidden_size to reproduce the error

View File

@@ -7,6 +7,7 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
 from huggingface_hub import snapshot_download
 import vllm._custom_ops as ops
+from vllm.utils import seed_everything
 GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
@@ -74,7 +75,7 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype,
 @torch.inference_mode()
 def test_mmvq(hidden_size: int, dtype: torch.dtype,
               quant_type: GGMLQuantizationType):
-    torch.cuda.manual_seed_all(0)
+    seed_everything(0)
     tensors = get_gguf_sample_tensors(hidden_size, quant_type)
     x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
@@ -110,7 +111,7 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype,
 @torch.inference_mode()
 def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
              quant_type: GGMLQuantizationType):
-    torch.cuda.manual_seed_all(0)
+    seed_everything(0)
     tensors = get_gguf_sample_tensors(hidden_size, quant_type)
     x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")

View File

@@ -4,6 +4,7 @@ import torch
 from tests.kernels.quant_utils import ref_dynamic_per_token_quant
 from tests.kernels.utils import opcheck
 from vllm._custom_ops import scaled_int8_quant
+from vllm.utils import seed_everything
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
@@ -44,8 +45,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True):
 @torch.inference_mode()
 def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
                                    dtype: torch.dtype, seed: int) -> None:
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
@@ -68,8 +68,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
 @torch.inference_mode()
 def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
                                        dtype: torch.dtype, seed: int) -> None:
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     int8_traits = torch.iinfo(torch.int8)
     x = torch.rand(num_tokens, hidden_size, dtype=dtype,
@@ -113,8 +112,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
 def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
                                   dtype: torch.dtype, seed: int,
                                   scale: float) -> None:
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     int8_traits = torch.iinfo(torch.int8)
     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
@@ -140,8 +138,7 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
 def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
                                       dtype: torch.dtype, seed: int,
                                       scale: float, azp: int) -> None:
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     int8_traits = torch.iinfo(torch.int8)
     x = torch.rand(num_tokens, hidden_size, dtype=dtype,

View File

@@ -3,6 +3,7 @@ import torch
 from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.utils import seed_everything
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
@@ -30,9 +31,7 @@ def test_rms_norm(
     seed: int,
     device: str,
 ) -> None:
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     layer = RMSNorm(hidden_size).to(dtype=dtype)
     layer.weight.data.normal_(mean=1.0, std=0.1)

View File

@@ -48,7 +48,7 @@ WTYPE_ZEROPOINTS = [
 # `is_quant_method_supported` conflates kernels with quantization methods
 # an assumption which is breaking down as quantizations methods can have
 # have kernels and some kernels support multiple quantization methods.
-IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
+IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
 def rand_data(shape, dtype=torch.float16):

View File

@@ -5,6 +5,7 @@ from einops import rearrange, repeat
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
+from vllm.utils import seed_everything
 def selective_state_update_ref(state,
@@ -186,7 +187,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
     rtolw = max(rtolw, rtol)
     atolw = max(atolw, atol)
     # set seed
-    torch.random.manual_seed(0)
+    seed_everything(0)
     batch_size = 2
     dim = 4
     dstate = 8
@@ -287,7 +288,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
     if torch.version.hip:
         atol *= 2
     # set seed
-    torch.random.manual_seed(0)
+    seed_everything(0)
     batch_size = 1
     state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
     x = torch.randn(batch_size, dim, device=device, dtype=itype)

View File

@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     marlin_quantize)
 from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.scalar_type import scalar_types
+from vllm.utils import seed_everything
 def torch_moe(a, w1, w2, score, topk):
@@ -151,7 +152,7 @@ def test_fused_marlin_moe(
     act_order: bool,
     num_bits: int,
 ):
-    torch.manual_seed(7)
+    seed_everything(7)
     if topk > e:
         return

View File

@@ -5,6 +5,7 @@ import pytest
 import torch
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.utils import seed_everything
 from .allclose_default import get_default_atol, get_default_rtol
@@ -46,9 +47,8 @@ def test_rotary_embedding(
 ) -> None:
     if rotary_dim is None:
         rotary_dim = head_size
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     if rotary_dim is None:
         rotary_dim = head_size
@@ -100,9 +100,7 @@ def test_batched_rotary_embedding(
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     if rotary_dim is None:
         rotary_dim = head_size
@@ -162,9 +160,7 @@ def test_batched_rotary_embedding_multi_lora(
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     if rotary_dim is None:
         rotary_dim = head_size

View File

@@ -9,7 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
 from vllm.attention.backends.xformers import _make_alibi_bias
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything
 NUM_HEADS = [64]
 NUM_QUERIES_PER_KV = [1, 8, 64]
@@ -39,10 +39,7 @@ def test_contexted_kv_attention(
     kv_cache_dtype: str,
     device: str,
 ) -> None:
-    random.seed(0)
-    torch.manual_seed(0)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(0)
+    seed_everything(0)
     torch.set_default_device(device)
     # Need this, otherwise when we capture the graph the process
@@ -237,10 +234,7 @@ def test_contexted_kv_attention_alibi(
     kv_cache_dtype: str,
     device: str,
 ) -> None:
-    random.seed(0)
-    torch.manual_seed(0)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(0)
+    seed_everything(0)
     torch.set_default_device(device)
     # Need this, otherwise when we capture the graph the process

View File

@@ -39,6 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
 from vllm.model_executor.utils import set_random_seed
+from vllm.utils import seed_everything
 from .utils import DummyLoRAManager
@@ -922,9 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                        seq_len) -> None:
     dtype = torch.float16
     seed = 0
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     torch.set_default_device(device)
     punica_wrapper = PunicaWrapper(8192, 256, device)
     max_loras = 8

View File

@@ -4,7 +4,6 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
 whether the corresponding Triton kernel can run normally when tensor parallelism
 is set to [1, 2, 4, 8, 16, 32, 64].
 """
-import random
 from unittest.mock import patch
 import pytest
@@ -17,6 +16,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
 from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
 from vllm.lora.ops.sgmv_shrink import sgmv_shrink
 from vllm.triton_utils.libentry import LibEntry
+from vllm.utils import seed_everything
 from .utils import (generate_data, generate_data_for_expand_nslices,
                     ref_torch_groupgemm)
@@ -145,11 +145,8 @@ def test_punica_sgmv(
     seed: int,
     device: str,
 ):
-    random.seed(seed)
     torch.set_default_device(device)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     seq_length = 128
     (
@@ -238,11 +235,8 @@ def test_punica_bgmv(
     from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
     from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
-    random.seed(seed)
     torch.set_default_device(device)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     seq_length = 1
     (
@@ -329,11 +323,9 @@ def test_punica_expand_nslices(
 ):
     from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
-    random.seed(seed)
     torch.set_default_device(device)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     seq_length = 128 if op_type == "sgmv" else 1
     (
         inputs_tensor,

View File

@@ -3,7 +3,6 @@ This script is mainly used to test whether trtion kernels can run normally
 under different conditions, including various batches, numbers of LoRA , and
 maximum ranks.
 """
-import random
 from unittest.mock import patch
 import pytest
@@ -16,6 +15,7 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
 from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
 from vllm.lora.ops.sgmv_shrink import sgmv_shrink
 from vllm.triton_utils.libentry import LibEntry
+from vllm.utils import seed_everything
 from .utils import (generate_data, generate_data_for_expand_nslices,
                     ref_torch_groupgemm)
@@ -60,11 +60,8 @@ def test_punica_sgmv(
     seed: int,
     device: str,
 ):
-    random.seed(seed)
     torch.set_default_device(device)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     seq_length = 128
     (
@@ -153,11 +150,8 @@ def test_punica_bgmv(
     from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
     from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
-    random.seed(seed)
     torch.set_default_device(device)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     seq_length = 1
     (
@@ -244,11 +238,9 @@ def test_punica_expand_nslices(
 ):
     from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
-    random.seed(seed)
     torch.set_default_device(device)
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
+    seed_everything(seed)
     seq_length = 128 if op_type == "sgmv" else 1
     (
         inputs_tensor,

View File

@@ -2,23 +2,18 @@
 Run `pytest tests/models/test_granite.py`.
 """
-import importlib.metadata
 import pytest
+import transformers
 from ...utils import check_logprobs_close
-TRANSFORMERS_VERSION = tuple(
-    map(int,
-        importlib.metadata.version("transformers").split(".")))
 MODELS = [
     "ibm/PowerLM-3b",
 ]
 # GraniteForCausalLM will be in transformers >= 4.45
-@pytest.mark.skipif(TRANSFORMERS_VERSION < (4, 45),
+@pytest.mark.skipif(transformers.__version__ < "4.45",
                     reason="granite model test requires transformers >= 4.45")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])

View File

@@ -86,9 +86,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
         assert attn._k_scale == 1.0
         assert attn._v_scale == 1.0
-        capability = current_platform.get_device_capability()
-        capability = capability[0] * 10 + capability[1]
-        if capability >= 89 and not force_marlin:
+        if current_platform.has_device_capability(89) and not force_marlin:
             # For GPUs with hardware support, we keep weights in fp8
             assert fc1.weight.dtype == torch.float8_e4m3fn
         else:

View File

@@ -8,6 +8,8 @@ def is_quant_method_supported(quant_method: str) -> bool:
         return False
     capability = current_platform.get_device_capability()
-    capability = capability[0] * 10 + capability[1]
-    return (capability >=
-            QUANTIZATION_METHODS[quant_method].get_min_capability())
+    assert capability is not None
+    min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability()
+    return capability.to_int() >= min_capability

View File

@@ -13,6 +13,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 logger = init_logger(__name__)
@@ -299,7 +300,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         else:
             # if not using triton, navi3x/navi21/navi10 do not use flash-attn
             # either
-            if torch.cuda.get_device_capability()[0] != 9:
+            if not current_platform.has_device_capability(90):
                 self.use_naive_attn = True
             else:
                 try:

View File

@@ -8,8 +8,7 @@ from vllm.utils import is_cpu, is_hip
 from .utils import (dense_to_crow_col, get_head_sliding_step,
                     get_sparse_attn_mask)
-IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
-                         and current_platform.get_device_capability()[0] >= 8)
+IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)
 if IS_COMPUTE_8_OR_ABOVE:
     from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
@@ -36,7 +35,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
         use_spda = is_hip() or is_cpu() or not \
             IS_COMPUTE_8_OR_ABOVE
         device = device or (torch.cuda.current_device()
-                            if torch.cuda.is_available() else "cpu")
+                            if current_platform.is_cuda_alike() else "cpu")
         device = torch.device(device)
         # NOTE: vllm CPU backend support BF16 instead of FP16.
         dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE

View File

@@ -709,8 +709,7 @@ if triton.__version__ >= "2.1.0":
                               alibi_slopes=None,
                               sliding_window=None):
-        cap = current_platform.get_device_capability()
-        BLOCK = 128 if cap[0] >= 8 else 64
+        BLOCK = 128 if current_platform.has_device_capability(80) else 64
         NUM_WARPS = 8
         # need to reduce num. blocks when using fp32

View File

@@ -203,7 +203,7 @@ def which_attn_to_use(
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
         if selected_backend == _Backend.ROCM_FLASH:
-            if current_platform.get_device_capability()[0] != 9:
+            if not current_platform.has_device_capability(90):
                 # not Instinct series GPUs.
                 logger.info("flash_attn is not supported on NAVI GPUs.")
         else:
@@ -212,7 +212,7 @@ def which_attn_to_use(
     # FlashAttn in NVIDIA GPUs.
     if selected_backend == _Backend.FLASH_ATTN:
-        if current_platform.get_device_capability()[0] < 8:
+        if not current_platform.has_device_capability(80):
             # Volta and Turing NVIDIA GPUs.
             logger.info(
                 "Cannot use FlashAttention-2 backend for Volta and Turing "

View File

@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
                                             get_hf_text_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        is_cpu, is_hip, is_neuron, is_openvino, is_xpu,
+                        is_hip, is_neuron, is_openvino, is_xpu,
                         print_warning_once)
 if TYPE_CHECKING:
@@ -1035,20 +1035,20 @@ class DeviceConfig:
     def __init__(self, device: str = "auto") -> None:
         if device == "auto":
            # Automated device type detection
-            if is_neuron():
+            if current_platform.is_cuda_alike():
+                self.device_type = "cuda"
+            elif is_neuron():
                 self.device_type = "neuron"
             elif is_openvino():
                 self.device_type = "openvino"
             elif current_platform.is_tpu():
                 self.device_type = "tpu"
-            elif is_cpu():
+            elif current_platform.is_cpu():
                 self.device_type = "cpu"
             elif is_xpu():
                 self.device_type = "xpu"
             else:
-                # We don't call torch.cuda.is_available() here to
-                # avoid initializing CUDA before workers are forked
-                self.device_type = "cuda"
+                raise RuntimeError("Failed to infer device type")
         else:
             # Device type is assigned explicitly
             self.device_type = device

View File

@@ -35,6 +35,7 @@ from torch.distributed import Backend, ProcessGroup
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 @dataclass
@@ -191,7 +192,7 @@ class GroupCoordinator:
         assert self.cpu_group is not None
         assert self.device_group is not None
-        if torch.cuda.is_available():
+        if current_platform.is_cuda_alike():
             self.device = torch.device(f"cuda:{local_rank}")
         else:
             self.device = torch.device("cpu")

View File

@@ -60,6 +60,7 @@ if TYPE_CHECKING:
     VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
     VLLM_PLUGINS: Optional[List[str]] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
+    VLLM_USE_TRITON_AWQ: bool = False
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
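The added line is only the TYPE_CHECKING stub for the flag. In vllm/envs.py such stubs are paired with a lazy reader in an environment_variables mapping that is resolved through a module-level __getattr__; a small sketch of that pattern is below (assumed shape, not part of this diff), which is why reading the flag does not touch CUDA or the environment until it is actually accessed.

import os

# Sketch of the envs.py pattern the typed stub above annotates (assumption).
environment_variables = {
    "VLLM_USE_TRITON_AWQ":
    lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
}


def __getattr__(name: str):
    # Evaluate the lambda only when the flag is read.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module does not have attribute {name!r}")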

View File

@@ -116,10 +116,10 @@ class CompressedTensorsConfig(QuantizationConfig):
     def _check_scheme_supported(self,
                                 min_capability: int,
                                 error: bool = True) -> bool:
-        capability = current_platform.get_device_capability()  # type: ignore
-        if capability is not None:
-            capability = capability[0] * 10 + capability[1]
+        capability_tuple = current_platform.get_device_capability()
+        if capability_tuple is not None:
+            capability = capability_tuple.to_int()
             supported = capability >= min_capability
             if error and not supported:
                 raise RuntimeError(

View File

@@ -32,9 +32,7 @@ class FBGEMMFp8Config(QuantizationConfig):
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
-        capability = current_platform.get_device_capability()
-        capability = capability[0] * 10 + capability[1]
-        self.use_marlin = capability < 89
+        self.use_marlin = not current_platform.has_device_capability(89)
     @classmethod
     def get_name(cls) -> str:

View File

@@ -120,9 +120,8 @@ class Fp8LinearMethod(LinearMethodBase):
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
-        capability = current_platform.get_device_capability()
-        capability = capability[0] * 10 + capability[1]
-        self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
+        self.use_marlin = (not current_platform.has_device_capability(89)
+                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
         # Disable marlin for rocm
         if is_hip():
             self.use_marlin = False

View File

@@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool,
                                        device_capability: Optional[int] = None
                                        ):
     if device_capability is None:
-        major, minor = current_platform.get_device_capability()
-        device_capability = major * 10 + minor
+        capability_tuple = current_platform.get_device_capability()
+        device_capability = (-1 if capability_tuple is None else
+                             capability_tuple.to_int())
     if device_capability < 80:
         return []
@@ -52,8 +53,9 @@ def _check_marlin_supported(
         device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
     if device_capability is None:
-        major, minor = current_platform.get_device_capability()
-        device_capability = major * 10 + minor
+        capability_tuple = current_platform.get_device_capability()
+        device_capability = (-1 if capability_tuple is None else
+                             capability_tuple.to_int())
     supported_types = query_marlin_supported_quant_types(
         has_zp, device_capability)

View File

@@ -10,8 +10,7 @@ from .marlin_utils import marlin_make_workspace, marlin_permute_scales
 def is_fp8_marlin_supported():
-    capability = current_platform.get_device_capability()
-    return capability[0] >= 8
+    return current_platform.has_device_capability(80)
 def apply_fp8_marlin_linear(

View File

@@ -17,8 +17,9 @@ def cutlass_fp8_supported() -> bool:
     # cutlass is not supported on Rocm
     if is_hip():
         return False
-    capability = current_platform.get_device_capability()
-    capability = capability[0] * 10 + capability[1]
+    capability_tuple = current_platform.get_device_capability()
+    capability = -1 if capability_tuple is None else capability_tuple.to_int()
     return ops.cutlass_scaled_mm_supports_fp8(capability)

View File

@@ -97,10 +97,10 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        capability = current_platform.get_device_capability()  # type: ignore
-        if capability is not None:
-            capability = capability[0] * 10 + capability[1]
+        capability_tuple = current_platform.get_device_capability()
+        if capability_tuple is not None:
+            capability = capability_tuple.to_int()
             if capability < quant_config.get_min_capability():
                 raise ValueError(
                     f"The quantization method {model_config.quantization} "

View File

@@ -207,7 +207,7 @@ class Qwen2VisionAttention(nn.Module):
             selected_backend = backend_name_to_enum(backend_by_env_var)
         if selected_backend is None:
             # For Volta and Turing GPUs, use xformers instead.
-            device_available = current_platform.get_device_capability()[0] >= 8
+            device_available = current_platform.has_device_capability(80)
             if device_available:
                 from transformers.utils import is_flash_attn_2_available

View File

@@ -1,17 +1,13 @@
 """Utils for model executor."""
-import random
 from typing import Any, Dict, Optional
-import numpy as np
 import torch
+from vllm.utils import seed_everything
 def set_random_seed(seed: int) -> None:
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
+    seed_everything(seed)
 def set_weight_attrs(

View File

@ -6,10 +6,10 @@ from .interface import Platform, PlatformEnum
class CpuPlatform(Platform): class CpuPlatform(Platform):
_enum = PlatformEnum.CPU _enum = PlatformEnum.CPU
@staticmethod @classmethod
def get_device_name(device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
return "cpu" return "cpu"
@staticmethod @classmethod
def inference_mode(): def inference_mode(cls):
return torch.no_grad() return torch.no_grad()
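Switching the platform hooks from `@staticmethod` to `@classmethod` is what allows shared helpers on the base `Platform` class (such as `has_device_capability` below) to dispatch through `cls` to whichever backend subclass is active. A toy illustration of the pattern, not vLLM code:

class Base:
    @classmethod
    def name(cls) -> str:
        raise NotImplementedError

    @classmethod
    def describe(cls) -> str:
        # Implemented once on the base class, yet picks up each subclass's name().
        return f"running on {cls.name()}"

class Cpu(Base):
    @classmethod
    def name(cls) -> str:
        return "cpu"

assert Cpu.describe() == "running on cpu"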

View File

@ -11,7 +11,7 @@ from typing_extensions import ParamSpec
from vllm.logger import init_logger from vllm.logger import init_logger
from .interface import Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
logger = init_logger(__name__) logger = init_logger(__name__)
@ -96,19 +96,20 @@ def device_id_to_physical_device_id(device_id: int) -> int:
class CudaPlatform(Platform): class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA _enum = PlatformEnum.CUDA
@staticmethod @classmethod
def get_device_capability(device_id: int = 0) -> Tuple[int, int]: def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
physical_device_id = device_id_to_physical_device_id(device_id) physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_capability(physical_device_id) major, minor = get_physical_device_capability(physical_device_id)
return DeviceCapability(major=major, minor=minor)
@staticmethod @classmethod
def get_device_name(device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = device_id_to_physical_device_id(device_id) physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_name(physical_device_id) return get_physical_device_name(physical_device_id)
@staticmethod @classmethod
@with_nvml_context @with_nvml_context
def is_full_nvlink(physical_device_ids: List[int]) -> bool: def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
""" """
Query whether the set of GPUs is fully connected by NVLink (1 hop). Query whether the set of GPUs is fully connected by NVLink (1 hop).
""" """

View File

@ -1,5 +1,5 @@
import enum import enum
from typing import Optional, Tuple from typing import NamedTuple, Optional, Tuple, Union
import torch import torch
@ -12,6 +12,23 @@ class PlatformEnum(enum.Enum):
UNSPECIFIED = enum.auto() UNSPECIFIED = enum.auto()
class DeviceCapability(NamedTuple):
major: int
minor: int
def as_version_str(self) -> str:
return f"{self.major}.{self.minor}"
def to_int(self) -> int:
"""
Express device capability as an integer ``<major><minor>``.
It is assumed that the minor version is always a single digit.
"""
assert 0 <= self.minor < 10
return self.major * 10 + self.minor
class Platform: class Platform:
_enum: PlatformEnum _enum: PlatformEnum
@ -27,16 +44,47 @@ class Platform:
def is_cpu(self) -> bool: def is_cpu(self) -> bool:
return self._enum == PlatformEnum.CPU return self._enum == PlatformEnum.CPU
@staticmethod def is_cuda_alike(self) -> bool:
def get_device_capability(device_id: int = 0) -> Optional[Tuple[int, int]]: """Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
@classmethod
def get_device_capability(
cls,
device_id: int = 0,
) -> Optional[DeviceCapability]:
"""Stateless version of :func:`torch.cuda.get_device_capability`."""
return None return None
@staticmethod @classmethod
def get_device_name(device_id: int = 0) -> str: def has_device_capability(
cls,
capability: Union[Tuple[int, int], int],
device_id: int = 0,
) -> bool:
"""
Test whether this platform is compatible with a device capability.
The ``capability`` argument can either be:
- A tuple ``(major, minor)``.
- An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
"""
current_capability = cls.get_device_capability(device_id=device_id)
if current_capability is None:
return False
if isinstance(capability, tuple):
return current_capability >= capability
return current_capability.to_int() >= capability
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError raise NotImplementedError
@staticmethod @classmethod
def inference_mode(): def inference_mode(cls):
"""A device-specific wrapper of `torch.inference_mode`. """A device-specific wrapper of `torch.inference_mode`.
This wrapper is recommended because some hardware backends such as TPU This wrapper is recommended because some hardware backends such as TPU
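Putting the new interface together: `DeviceCapability` packs the `(major, minor)` pair with explicit conversions, and `Platform.has_device_capability` accepts either representation, degrading to `False` when `get_device_capability()` returns `None`. A usage sketch (the import path follows the file layout shown above):

from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability

cap = DeviceCapability(major=8, minor=6)
assert cap.as_version_str() == "8.6"
assert cap.to_int() == 86
assert cap >= (8, 0)    # NamedTuple keeps ordinary tuple ordering

# Both spellings ask the same question; on a CPU-only host both return False
# because get_device_capability() yields None.
current_platform.has_device_capability(80)
current_platform.has_device_capability((8, 0))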

View File

@ -1,12 +1,11 @@
import os import os
from functools import lru_cache from functools import lru_cache
from typing import Tuple
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from .interface import Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
logger = init_logger(__name__) logger = init_logger(__name__)
@ -20,12 +19,13 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class RocmPlatform(Platform): class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM _enum = PlatformEnum.ROCM
@staticmethod @classmethod
@lru_cache(maxsize=8) @lru_cache(maxsize=8)
def get_device_capability(device_id: int = 0) -> Tuple[int, int]: def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
return torch.cuda.get_device_capability(device_id) major, minor = torch.cuda.get_device_capability(device_id)
return DeviceCapability(major=major, minor=minor)
@staticmethod @classmethod
@lru_cache(maxsize=8) @lru_cache(maxsize=8)
def get_device_name(device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
return torch.cuda.get_device_name(device_id) return torch.cuda.get_device_name(device_id)

View File

@ -6,6 +6,10 @@ from .interface import Platform, PlatformEnum
class TpuPlatform(Platform): class TpuPlatform(Platform):
_enum = PlatformEnum.TPU _enum = PlatformEnum.TPU
@staticmethod @classmethod
def inference_mode(): def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError
@classmethod
def inference_mode(cls):
return torch.no_grad() return torch.no_grad()

View File

@ -8,13 +8,15 @@ from huggingface_hub import file_exists, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError from huggingface_hub.utils import EntryNotFoundError
from safetensors.torch import load_file as safe_load_file from safetensors.torch import load_file as safe_load_file
from vllm.platforms import current_platform
WEIGHTS_NAME = "adapter_model.bin" WEIGHTS_NAME = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
# Get current device name based on available devices # Get current device name based on available devices
def infer_device() -> str: def infer_device() -> str:
if torch.cuda.is_available(): if current_platform.is_cuda_alike():
return "cuda" return "cuda"
return "cpu" return "cpu"

View File

@ -17,6 +17,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.connections import global_http_connection from vllm.connections import global_http_connection
from vllm.platforms import current_platform
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
_config_home = envs.VLLM_CONFIG_ROOT _config_home = envs.VLLM_CONFIG_ROOT
@ -151,7 +152,7 @@ class UsageMessage:
usage_context: UsageContext, usage_context: UsageContext,
extra_kvs: Dict[str, Any]) -> None: extra_kvs: Dict[str, Any]) -> None:
# Platform information # Platform information
if torch.cuda.is_available(): if current_platform.is_cuda_alike():
device_property = torch.cuda.get_device_properties(0) device_property = torch.cuda.get_device_properties(0)
self.gpu_count = torch.cuda.device_count() self.gpu_count = torch.cuda.device_count()
self.gpu_type = device_property.name self.gpu_type = device_property.name

View File

@ -5,6 +5,7 @@ import datetime
import enum import enum
import gc import gc
import os import os
import random
import socket import socket
import subprocess import subprocess
import sys import sys
@ -32,6 +33,7 @@ from typing_extensions import ParamSpec, TypeIs, assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import enable_trace_function_call, init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
@ -373,6 +375,22 @@ def get_cpu_memory() -> int:
return psutil.virtual_memory().total return psutil.virtual_memory().total
def seed_everything(seed: int) -> None:
"""
Set the seed of each random module.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if current_platform.is_cuda_alike():
torch.cuda.manual_seed_all(seed)
if is_xpu():
torch.xpu.manual_seed_all(seed)
def random_uuid() -> str: def random_uuid() -> str:
return str(uuid.uuid4().hex) return str(uuid.uuid4().hex)
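`seed_everything` becomes the single entry point used by the benchmarks and by `set_random_seed` above: one call seeds Python's `random`, NumPy, and the torch generators, touching the CUDA/XPU generators only on the matching platform. A quick reproducibility check (seed value arbitrary):

import random
import numpy as np
from vllm.utils import seed_everything

seed_everything(42)
first = (random.random(), float(np.random.rand()))

seed_everything(42)
assert (random.random(), float(np.random.rand())) == first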
@ -634,9 +652,7 @@ def create_kv_caches_with_random_flash(
seed: int = 0, seed: int = 0,
device: Optional[str] = "cuda", device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed) seed_everything(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
@ -678,9 +694,7 @@ def create_kv_caches_with_random(
f"Does not support key cache of type fp8 with head_size {head_size}" f"Does not support key cache of type fp8 with head_size {head_size}"
) )
torch.random.manual_seed(seed) seed_everything(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
@ -750,7 +764,7 @@ class CudaMemoryProfiler:
def current_memory_usage(self) -> float: def current_memory_usage(self) -> float:
# Return the memory usage in bytes. # Return the memory usage in bytes.
if torch.cuda.is_available(): if current_platform.is_cuda_alike():
torch.cuda.reset_peak_memory_stats(self.device) torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device) mem = torch.cuda.max_memory_allocated(self.device)
elif is_xpu(): elif is_xpu():

View File

@ -454,14 +454,20 @@ def init_worker_distributed_environment(
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype. # Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: if torch_dtype == torch.bfloat16: # noqa: SIM102
compute_capability = current_platform.get_device_capability() if not current_platform.has_device_capability(80):
if compute_capability[0] < 8: capability = current_platform.get_device_capability()
gpu_name = current_platform.get_device_name() gpu_name = current_platform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError( raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability " "Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU has compute capability " f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
f"{compute_capability[0]}.{compute_capability[1]}. "
"You can use float16 instead by explicitly setting the" "You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.") "`dtype` flag in CLI, for example: --dtype=half.")
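The reworked bfloat16 guard now distinguishes "capability too low" from "capability unknown" when building its error message. A sketch of that message construction with a stand-in class and a made-up Turing-class device (`describe_capability` is an illustrative name):

from typing import NamedTuple, Optional

class DeviceCapability(NamedTuple):   # stand-in for vllm.platforms.interface.DeviceCapability
    major: int
    minor: int

    def as_version_str(self) -> str:
        return f"{self.major}.{self.minor}"

def describe_capability(cap: Optional[DeviceCapability]) -> str:
    if cap is None:
        return "does not have a compute capability"
    return f"has compute capability {cap.as_version_str()}"

assert describe_capability(None) == "does not have a compute capability"
assert describe_capability(DeviceCapability(7, 5)) == "has compute capability 7.5"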