[CI/Build] fix flaky test (#3602)
parent 42bc386129
commit 837e185142
@@ -1,20 +1,16 @@
-import random
+import pytest
 import torch
 
 from vllm.config import ModelConfig
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT
+from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
 
 
-def get_aligned_size(batch_size: int, alignment: int):
-    return ((batch_size + alignment - 1) // alignment * alignment)
-
-
-def test_prepare_prompt():
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_prompt(batch_size):
     model_runner = ModelRunner(None, None, None, None, None)
     model_runner.set_block_size(16)
 
-    batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     block_tables = {0: [1]}
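
The flakiness being fixed: each run drew one random batch size, so a size-dependent failure surfaced only on unlucky runs and could not be reproduced on a rerun. The fix enumerates every batch size from 1 to 256 as its own deterministic pytest case. A minimal sketch of that pattern, not from this repo (round_up reproduces the arithmetic of the deleted get_aligned_size helper; the alignment value 8 is illustrative):

import random

import pytest


def round_up(n: int, alignment: int) -> int:
    # Same arithmetic as the deleted get_aligned_size helper.
    return (n + alignment - 1) // alignment * alignment


# Before: one arbitrary size per run; a bad size fails only occasionally.
def test_round_up_flaky():
    n = random.randint(1, 256)
    assert round_up(n, 8) % 8 == 0


# After: every size in [1, 256] is a separate, reproducible test case.
@pytest.mark.parametrize("n", list(range(1, 257)))
def test_round_up(n):
    assert round_up(n, 8) % 8 == 0
    assert n <= round_up(n, 8) < n + 8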
@@ -111,7 +107,8 @@ def test_prepare_prompt():
     torch.testing.assert_close(actual, expected)
 
 
-def test_prepare_decode_cuda_graph():
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_decode_cuda_graph(batch_size):
     model_config = ModelConfig(
         "facebook/opt-125m",
         "facebook/opt-125m",
@@ -127,7 +124,6 @@ def test_prepare_decode_cuda_graph():
     model_runner = ModelRunner(model_config, None, None, None, None)
     model_runner.set_block_size(16)
 
-    batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     for i in range(batch_size):
@@ -147,13 +143,13 @@ def test_prepare_decode_cuda_graph():
     input_tokens, input_positions, input_metadata, _, _, _ = (
         model_runner._prepare_decode(seq_group_metadata_list))
 
+    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert input_metadata.is_prompt is False
     assert input_metadata.prompt_lens is None
     assert input_metadata.num_prompt_tokens == 0
-    assert input_metadata.num_generation_tokens == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT))
+    assert input_metadata.num_generation_tokens == expected_bs
     assert input_metadata.max_seq_len is None
     assert input_metadata.subquery_start_loc is None
     assert input_metadata.seq_start_loc is None
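
The expected batch size now comes from vllm's own _get_graph_batch_size rather than a local reimplementation, so the test can no longer drift out of sync with the library. This diff confirms only the helper's name and import; its body is not shown. A hedged sketch of what it plausibly computes, judging from the rounding arithmetic it replaces: pad the batch up to the nearest size for which a CUDA graph was captured, which for very small batches need not be a plain round-up.

# Assumed body, for illustration only; the real definition lives in
# vllm.worker.model_runner and is not part of this diff.
_BATCH_SIZE_ALIGNMENT = 8  # assumed value of the vllm constant


def get_graph_batch_size_sketch(batch_size: int) -> int:
    # Small batches may be captured at (or near) their exact size...
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    # ...while larger ones round up to a multiple of the alignment.
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

If the helper does special-case small batches like this, a locally rounded expectation (e.g. 8 for a batch of 3, versus a padded size of 4) would disagree for a handful of sizes, which is exactly the kind of mismatch a randomly drawn batch size exposes only intermittently.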
@@ -173,10 +169,8 @@ def test_prepare_decode_cuda_graph():
     assert input_metadata.use_cuda_graph is True
     assert input_metadata.kv_cache_dtype == "auto"
 
-    assert input_tokens.shape == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
-    assert input_positions.shape == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
+    assert input_tokens.shape == (expected_bs, )
+    assert input_positions.shape == (expected_bs, )
     torch.testing.assert_close(input_tokens, input_positions)
 
     # Verify Sampling