[CI/Build] fix flaky test (#3602)
parent 42bc386129
commit 837e185142
@@ -1,20 +1,16 @@
-import random
 import pytest
 import torch
 
 from vllm.config import ModelConfig
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT
+from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
 
 
-def get_aligned_size(batch_size: int, alignment: int):
-    return ((batch_size + alignment - 1) // alignment * alignment)
-
-
-def test_prepare_prompt():
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_prompt(batch_size):
     model_runner = ModelRunner(None, None, None, None, None)
     model_runner.set_block_size(16)
 
-    batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     block_tables = {0: [1]}
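Note on the fix: drawing batch_size from random.randint(1, 256) meant each CI run exercised one unpredictable batch size, so a bug affecting only some sizes surfaced intermittently. Parametrizing over range(1, 257) exercises every batch size deterministically on every run. A minimal toy sketch of the difference (the test names here are illustrative, not from this diff):

import random

import pytest


def test_random_style():
    # Hypothetical flaky pattern: one arbitrary size per run, so a
    # size-dependent bug only fails when the RNG happens to hit it.
    batch_size = random.randint(1, 256)
    assert batch_size >= 1


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_parametrized_style(batch_size):
    # One deterministic test case per batch size; a size-dependent bug
    # now fails on every run.
    assert 1 <= batch_size <= 256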
@@ -111,7 +107,8 @@ def test_prepare_prompt():
     torch.testing.assert_close(actual, expected)
 
 
-def test_prepare_decode_cuda_graph():
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_decode_cuda_graph(batch_size):
     model_config = ModelConfig(
         "facebook/opt-125m",
         "facebook/opt-125m",
@@ -127,7 +124,6 @@ def test_prepare_decode_cuda_graph():
     model_runner = ModelRunner(model_config, None, None, None, None)
     model_runner.set_block_size(16)
 
-    batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     for i in range(batch_size):
@@ -147,13 +143,13 @@ def test_prepare_decode_cuda_graph():
     input_tokens, input_positions, input_metadata, _, _, _ = (
         model_runner._prepare_decode(seq_group_metadata_list))
 
+    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert input_metadata.is_prompt is False
     assert input_metadata.prompt_lens is None
     assert input_metadata.num_prompt_tokens == 0
-    assert input_metadata.num_generation_tokens == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT))
+    assert input_metadata.num_generation_tokens == expected_bs
     assert input_metadata.max_seq_len is None
     assert input_metadata.subquery_start_loc is None
     assert input_metadata.seq_start_loc is None
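For context, the removed get_aligned_size helper was a plain round-up-to-multiple, reproduced below as a runnable sketch. The test now treats _get_graph_batch_size from vllm.worker.model_runner as the single source of truth; if the library helper handles small batch sizes differently from a plain round-up (an assumption consistent with this fix), the old hand-rolled expectation would fail only for those sizes, which random sampling hit sporadically.

def get_aligned_size(batch_size: int, alignment: int) -> int:
    # Round batch_size up to the nearest multiple of alignment.
    return (batch_size + alignment - 1) // alignment * alignment


assert get_aligned_size(1, 8) == 8    # small batches always rounded to 8
assert get_aligned_size(9, 8) == 16
assert get_aligned_size(16, 8) == 16  # exact multiples unchanged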
@@ -173,10 +169,8 @@ def test_prepare_decode_cuda_graph():
     assert input_metadata.use_cuda_graph is True
     assert input_metadata.kv_cache_dtype == "auto"
 
-    assert input_tokens.shape == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
-    assert input_positions.shape == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
+    assert input_tokens.shape == (expected_bs, )
+    assert input_positions.shape == (expected_bs, )
     torch.testing.assert_close(input_tokens, input_positions)
 
     # Verify Sampling
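The shape assertions compare against expected_bs rather than len(seq_group_metadata_list) because, with CUDA graphs enabled, decode inputs are padded up to the batch size the graph was captured with. A toy illustration of that padding; the helper below is hypothetical, not vLLM API:

import torch


def pad_to_graph_batch_size(tokens: torch.Tensor,
                            expected_bs: int) -> torch.Tensor:
    # Zero-pad a 1-D token tensor from the true batch size up to the
    # batch size the CUDA graph was captured with.
    padded = torch.zeros(expected_bs, dtype=tokens.dtype)
    padded[:tokens.shape[0]] = tokens
    return padded


tokens = torch.tensor([11, 22, 33])          # true batch size 3
padded = pad_to_graph_batch_size(tokens, 8)  # graph captured at size 8
assert padded.shape == (8, )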