[CI/Build] fix flaky test (#3602)

Author: youkaichao (committed by GitHub)
Date:   2024-03-24 17:43:05 -07:00
Commit: 837e185142
Parent: 42bc386129

@@ -1,20 +1,16 @@
-import random
+import pytest
 import torch

 from vllm.config import ModelConfig
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT
+from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size


-def get_aligned_size(batch_size: int, alignment: int):
-    return ((batch_size + alignment - 1) // alignment * alignment)
-
-
-def test_prepare_prompt():
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_prompt(batch_size):
     model_runner = ModelRunner(None, None, None, None, None)
     model_runner.set_block_size(16)

-    batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     block_tables = {0: [1]}
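The parametrization introduced in this hunk is what removes the flakiness: instead of drawing one random batch size per CI run with random.randint(1, 256), every size from 1 to 256 is exercised deterministically, and a failing size is pinned in the test id instead of depending on the draw. A minimal, self-contained sketch of the pattern (not part of the commit; the test name and assertion are illustrative only):

import pytest


# Exhaustive, deterministic coverage of batch sizes 1..256; each value becomes
# its own test case, e.g. test_batch_size_is_in_range[3].
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_batch_size_is_in_range(batch_size):
    # Hypothetical stand-in assertion: the point is that any failure
    # reproduces on every run rather than only when a random draw hits it.
    assert 1 <= batch_size <= 256
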
@@ -111,7 +107,8 @@ def test_prepare_prompt():
     torch.testing.assert_close(actual, expected)


-def test_prepare_decode_cuda_graph():
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_decode_cuda_graph(batch_size):
     model_config = ModelConfig(
         "facebook/opt-125m",
         "facebook/opt-125m",
@@ -127,7 +124,6 @@ def test_prepare_decode_cuda_graph():
     model_runner = ModelRunner(model_config, None, None, None, None)
     model_runner.set_block_size(16)

-    batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     for i in range(batch_size):
@@ -147,13 +143,13 @@ def test_prepare_decode_cuda_graph():
     input_tokens, input_positions, input_metadata, _, _, _ = (
         model_runner._prepare_decode(seq_group_metadata_list))
+    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))

     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert input_metadata.is_prompt is False
     assert input_metadata.prompt_lens is None
     assert input_metadata.num_prompt_tokens == 0
-    assert input_metadata.num_generation_tokens == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT))
+    assert input_metadata.num_generation_tokens == expected_bs
     assert input_metadata.max_seq_len is None
     assert input_metadata.subquery_start_loc is None
     assert input_metadata.seq_start_loc is None
@@ -173,10 +169,8 @@ def test_prepare_decode_cuda_graph():
     assert input_metadata.use_cuda_graph is True
     assert input_metadata.kv_cache_dtype == "auto"

-    assert input_tokens.shape == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
-    assert input_positions.shape == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
+    assert input_tokens.shape == (expected_bs, )
+    assert input_positions.shape == (expected_bs, )
     torch.testing.assert_close(input_tokens, input_positions)

     # Verify Sampling
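One plausible reading of the flakiness, under the assumptions sketched below: the removed local helper get_aligned_size always rounded up to a multiple of _BATCH_SIZE_ALIGNMENT, whereas the runner's own _get_graph_batch_size (which the test now calls directly) is assumed to map very small batches to the nearest captured CUDA graph size instead of padding them, so the old assertions would only fail when random.randint(1, 256) happened to draw a small batch. The constant's value and the function body below are assumptions for illustration, not taken from the commit:

_BATCH_SIZE_ALIGNMENT = 8  # assumed value; the real constant lives in vllm.worker.model_runner


def get_aligned_size(batch_size: int, alignment: int) -> int:
    # The helper removed by this commit: always rounds up to a multiple of `alignment`.
    return (batch_size + alignment - 1) // alignment * alignment


def graph_batch_size(batch_size: int) -> int:
    # Hypothetical stand-in for vllm.worker.model_runner._get_graph_batch_size:
    # small batches snap to a captured graph size (1, 2, 4) instead of being padded.
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return (batch_size + _BATCH_SIZE_ALIGNMENT - 1) // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT


# With these assumptions the two disagree for small batches, e.g. batch_size = 3:
assert get_aligned_size(3, _BATCH_SIZE_ALIGNMENT) == 8
assert graph_batch_size(3) == 4

Whatever the exact rounding rule, asserting against _get_graph_batch_size(len(seq_group_metadata_list)) keeps the test aligned with the rounding the runner actually applies instead of duplicating that rule inside the test.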