[CI/Build] fix flaky test (#3602)

This commit is contained in:
youkaichao 2024-03-24 17:43:05 -07:00 committed by GitHub
parent 42bc386129
commit 837e185142
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,20 +1,16 @@
import random
import pytest
import torch
from vllm.config import ModelConfig
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
def get_aligned_size(batch_size: int, alignment: int):
return ((batch_size + alignment - 1) // alignment * alignment)
def test_prepare_prompt():
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
model_runner = ModelRunner(None, None, None, None, None)
model_runner.set_block_size(16)
batch_size = random.randint(1, 256)
prompt_lens = []
seq_group_metadata_list = []
block_tables = {0: [1]}
@ -111,7 +107,8 @@ def test_prepare_prompt():
torch.testing.assert_close(actual, expected)
def test_prepare_decode_cuda_graph():
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
@ -127,7 +124,6 @@ def test_prepare_decode_cuda_graph():
model_runner = ModelRunner(model_config, None, None, None, None)
model_runner.set_block_size(16)
batch_size = random.randint(1, 256)
prompt_lens = []
seq_group_metadata_list = []
for i in range(batch_size):
@ -147,13 +143,13 @@ def test_prepare_decode_cuda_graph():
input_tokens, input_positions, input_metadata, _, _, _ = (
model_runner._prepare_decode(seq_group_metadata_list))
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
device = model_runner.device
assert input_metadata.is_prompt is False
assert input_metadata.prompt_lens is None
assert input_metadata.num_prompt_tokens == 0
assert input_metadata.num_generation_tokens == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT))
assert input_metadata.num_generation_tokens == expected_bs
assert input_metadata.max_seq_len is None
assert input_metadata.subquery_start_loc is None
assert input_metadata.seq_start_loc is None
@ -173,10 +169,8 @@ def test_prepare_decode_cuda_graph():
assert input_metadata.use_cuda_graph is True
assert input_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
assert input_positions.shape == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
assert input_tokens.shape == (expected_bs, )
assert input_positions.shape == (expected_bs, )
torch.testing.assert_close(input_tokens, input_positions)
# Verify Sampling