Fix tests in test_chunked_prefill_scheduler which fail with BlockManager V2 (#8752)

This commit is contained in:
sroy745 2024-09-24 21:26:36 -07:00 committed by GitHub
parent b4522474a3
commit fc3afc20df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -27,16 +27,19 @@ def schedule_and_update_computed_tokens(scheduler):
return metas, out return metas, out
def test_simple(): @pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_simple(use_v2_block_manager: bool):
"""Verify basic scheduling works.""" """Verify basic scheduling works."""
block_size = 4 block_size = 4
num_seq_group = 4 num_seq_group = 4
max_model_len = 16 max_model_len = 16
max_num_batched_tokens = 64 max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(max_num_batched_tokens, scheduler_config = SchedulerConfig(
num_seq_group, max_num_batched_tokens,
max_model_len, num_seq_group,
enable_chunked_prefill=True) max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8 cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8 cache_config.num_gpu_blocks = 8
@ -45,7 +48,9 @@ def test_simple():
# Add seq groups to scheduler. # Add seq groups to scheduler.
for i in range(num_seq_group): for i in range(num_seq_group):
_, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) _, seq_group = create_dummy_prompt(str(i),
prompt_length=block_size,
block_size=block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
running.append(seq_group) running.append(seq_group)
@ -69,30 +74,36 @@ def test_simple():
assert len(seq_group_meta) == num_seq_group assert len(seq_group_meta) == num_seq_group
def test_chunk(): @pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_chunk(use_v2_block_manager: bool):
"""Verify prefills are chunked properly.""" """Verify prefills are chunked properly."""
block_size = 4 block_size = 4
max_seqs = 60 max_seqs = 60
max_model_len = 80 max_model_len = 80
max_num_batched_tokens = 64 max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(max_num_batched_tokens, scheduler_config = SchedulerConfig(
max_seqs, max_num_batched_tokens,
max_model_len, max_seqs,
enable_chunked_prefill=True) max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8 cache_config.num_cpu_blocks = 32
cache_config.num_gpu_blocks = 8 cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None) scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = [] running: List[SequenceGroup] = []
# Add seq groups to scheduler. # Add seq groups to scheduler.
for i in range(2): for i in range(2):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
block_size=block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
running.append(seq_group) running.append(seq_group)
# Verify the second request is chunked. # Verify the second request is chunked.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
print()
assert set(get_sequence_groups(out)) == set(running) assert set(get_sequence_groups(out)) == set(running)
assert seq_group_meta[0].token_chunk_size == 60 assert seq_group_meta[0].token_chunk_size == 60
# Verify it is chunked. # Verify it is chunked.
@ -113,24 +124,29 @@ def test_chunk():
assert out.num_batched_tokens == 57 assert out.num_batched_tokens == 57
def test_complex(): @pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_complex(use_v2_block_manager: bool):
block_size = 4 block_size = 4
max_seqs = 60 max_seqs = 60
max_model_len = 80 max_model_len = 80
max_num_batched_tokens = 64 max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(max_num_batched_tokens, scheduler_config = SchedulerConfig(
max_seqs, max_num_batched_tokens,
max_model_len, max_seqs,
enable_chunked_prefill=True) max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8 cache_config.num_cpu_blocks = 64
cache_config.num_gpu_blocks = 8 cache_config.num_gpu_blocks = 64
scheduler = Scheduler(scheduler_config, cache_config, None) scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = [] running: List[SequenceGroup] = []
# Add seq groups to scheduler. # Add seq groups to scheduler.
for i in range(2): for i in range(2):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
block_size=block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
running.append(seq_group) running.append(seq_group)
assert seq_group.is_prefill() assert seq_group.is_prefill()
@ -151,7 +167,9 @@ def test_complex():
# Add 2 more requests. # Add 2 more requests.
for i in range(2, 4): for i in range(2, 4):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
block_size=block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
running.append(seq_group) running.append(seq_group)
@ -176,16 +194,19 @@ def test_complex():
assert running[2].is_prefill() assert running[2].is_prefill()
def test_maximal_decoding(): @pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_maximal_decoding(use_v2_block_manager: bool):
"""Verify decoding requests are prioritized.""" """Verify decoding requests are prioritized."""
block_size = 4 block_size = 4
max_seqs = 2 max_seqs = 2
max_model_len = 8 max_model_len = 8
max_num_batched_tokens = 2 max_num_batched_tokens = 2
scheduler_config = SchedulerConfig(max_num_batched_tokens, scheduler_config = SchedulerConfig(
max_seqs, max_num_batched_tokens,
max_model_len, max_seqs,
enable_chunked_prefill=True) max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8 cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8 cache_config.num_gpu_blocks = 8
@ -194,7 +215,9 @@ def test_maximal_decoding():
# Add seq groups to scheduler. # Add seq groups to scheduler.
for i in range(2): for i in range(2):
_, seq_group = create_dummy_prompt(str(i), prompt_length=2) _, seq_group = create_dummy_prompt(str(i),
prompt_length=2,
block_size=block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
running.append(seq_group) running.append(seq_group)
assert seq_group.is_prefill() assert seq_group.is_prefill()
@ -211,7 +234,9 @@ def test_maximal_decoding():
append_new_token(running[0], 1) append_new_token(running[0], 1)
# Create one more seq_group. # Create one more seq_group.
_, seq_group = create_dummy_prompt("3", prompt_length=2) _, seq_group = create_dummy_prompt("3",
prompt_length=2,
block_size=block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
running.append(seq_group) running.append(seq_group)
assert seq_group.is_prefill() assert seq_group.is_prefill()
@ -263,23 +288,28 @@ def test_maximal_decoding():
assert out.num_batched_tokens == 2 assert out.num_batched_tokens == 2
def test_prompt_limit(): @pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_prompt_limit(use_v2_block_manager: bool):
"""Verify max_num_batched_tokens < max_model_len is possible.""" """Verify max_num_batched_tokens < max_model_len is possible."""
block_size = 4 block_size = 4
max_seqs = 32 max_seqs = 32
max_model_len = 64 max_model_len = 64
max_num_batched_tokens = 32 max_num_batched_tokens = 32
scheduler_config = SchedulerConfig(max_num_batched_tokens, scheduler_config = SchedulerConfig(
max_seqs, max_num_batched_tokens,
max_model_len, max_seqs,
enable_chunked_prefill=True) max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8 cache_config.num_cpu_blocks = 16
cache_config.num_gpu_blocks = 8 cache_config.num_gpu_blocks = 16
scheduler = Scheduler(scheduler_config, cache_config, None) scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = [] running: List[SequenceGroup] = []
_, seq_group = create_dummy_prompt("1", prompt_length=48) _, seq_group = create_dummy_prompt("1",
prompt_length=48,
block_size=block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
running.append(seq_group) running.append(seq_group)
assert seq_group.is_prefill() assert seq_group.is_prefill()
@ -293,7 +323,8 @@ def test_prompt_limit():
assert out.num_batched_tokens == 32 assert out.num_batched_tokens == 32
def test_prompt_limit_exceed(): @pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_prompt_limit_exceed(use_v2_block_manager: bool):
block_size = 4 block_size = 4
max_seqs = 64 max_seqs = 64
max_model_len = 32 max_model_len = 32
@ -303,12 +334,13 @@ def test_prompt_limit_exceed():
max_model_len, max_model_len,
enable_chunked_prefill=True) enable_chunked_prefill=True)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8 cache_config.num_cpu_blocks = 16
cache_config.num_gpu_blocks = 8 cache_config.num_gpu_blocks = 16
scheduler = Scheduler(scheduler_config, cache_config, None) scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = [] running: List[SequenceGroup] = []
_, seq_group = create_dummy_prompt("2",
_, seq_group = create_dummy_prompt("2", prompt_length=48) prompt_length=48,
block_size=block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
running.append(seq_group) running.append(seq_group)
assert seq_group.is_prefill() assert seq_group.is_prefill()
@ -317,22 +349,28 @@ def test_prompt_limit_exceed():
assert out.ignored_seq_groups[0] == seq_group assert out.ignored_seq_groups[0] == seq_group
def test_swap(): @pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_swap(use_v2_block_manager: bool):
"""Verify swapping works with chunked prefill requests""" """Verify swapping works with chunked prefill requests"""
block_size = 4 block_size = 4
max_seqs = 30 max_seqs = 30
max_model_len = 200 max_model_len = 200
max_num_batched_tokens = 30 max_num_batched_tokens = 30
scheduler_config = SchedulerConfig(max_num_batched_tokens, scheduler_config = SchedulerConfig(
max_seqs, max_num_batched_tokens,
max_model_len, max_seqs,
enable_chunked_prefill=True) max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8 cache_config.num_cpu_blocks = 16
cache_config.num_gpu_blocks = 8 cache_config.num_gpu_blocks = 16
scheduler = Scheduler(scheduler_config, cache_config, None) scheduler = Scheduler(scheduler_config, cache_config, None)
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) _, seq_group = create_dummy_prompt("1",
prompt_length=60,
best_of=2,
block_size=block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
_, out = schedule_and_update_computed_tokens(scheduler) _, out = schedule_and_update_computed_tokens(scheduler)
# The request is chunked. # The request is chunked.
@ -369,21 +407,27 @@ def test_swap():
assert out.blocks_to_swap_out == [] assert out.blocks_to_swap_out == []
def test_running_prefill_prioritized_over_swap(): @pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_running_prefill_prioritized_over_swap(use_v2_block_manager: bool):
block_size = 4 block_size = 4
max_seqs = 30 max_seqs = 30
max_model_len = 200 max_model_len = 200
max_num_batched_tokens = 30 max_num_batched_tokens = 30
scheduler_config = SchedulerConfig(max_num_batched_tokens, scheduler_config = SchedulerConfig(
max_seqs, max_num_batched_tokens,
max_model_len, max_seqs,
enable_chunked_prefill=True) max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8 cache_config.num_cpu_blocks = 32
cache_config.num_gpu_blocks = 8 cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None) scheduler = Scheduler(scheduler_config, cache_config, None)
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) _, seq_group = create_dummy_prompt("1",
prompt_length=60,
best_of=2,
block_size=block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
_, out = schedule_and_update_computed_tokens(scheduler) _, out = schedule_and_update_computed_tokens(scheduler)
# The request is chunked. # The request is chunked.
@ -413,7 +457,9 @@ def test_running_prefill_prioritized_over_swap():
scheduler.block_manager.can_swap_in = MagicMock() scheduler.block_manager.can_swap_in = MagicMock()
scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
_, seq_group2 = create_dummy_prompt("2", prompt_length=60) _, seq_group2 = create_dummy_prompt("2",
prompt_length=60,
block_size=block_size)
scheduler.add_seq_group(seq_group2) scheduler.add_seq_group(seq_group2)
_, out = schedule_and_update_computed_tokens(scheduler) _, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1 assert len(out.scheduled_seq_groups) == 1
@ -455,22 +501,27 @@ def test_running_prefill_prioritized_over_swap():
assert out.blocks_to_swap_out == [] assert out.blocks_to_swap_out == []
def test_chunked_prefill_preempt(): @pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_chunked_prefill_preempt(use_v2_block_manager: bool):
"""Verify preempt works with chunked prefill requests""" """Verify preempt works with chunked prefill requests"""
block_size = 4 block_size = 4
max_seqs = 30 max_seqs = 30
max_model_len = 200 max_model_len = 200
max_num_batched_tokens = 30 max_num_batched_tokens = 30
scheduler_config = SchedulerConfig(max_num_batched_tokens, scheduler_config = SchedulerConfig(
max_seqs, max_num_batched_tokens,
max_model_len, max_seqs,
enable_chunked_prefill=True) max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8 cache_config.num_cpu_blocks = 16
cache_config.num_gpu_blocks = 8 cache_config.num_gpu_blocks = 16
scheduler = Scheduler(scheduler_config, cache_config, None) scheduler = Scheduler(scheduler_config, cache_config, None)
_, seq_group = create_dummy_prompt("1", prompt_length=60) _, seq_group = create_dummy_prompt("1",
prompt_length=60,
block_size=block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
_, out = schedule_and_update_computed_tokens(scheduler) _, out = schedule_and_update_computed_tokens(scheduler)
# The request is chunked. # The request is chunked.
@ -517,22 +568,27 @@ def test_chunked_prefill_preempt():
assert out.num_batched_tokens == max_num_batched_tokens assert out.num_batched_tokens == max_num_batched_tokens
def test_chunked_prefill_max_seqs(): @pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_chunked_prefill_max_seqs(use_v2_block_manager: bool):
block_size = 4 block_size = 4
max_seqs = 2 max_seqs = 2
max_model_len = 80 max_model_len = 80
max_num_batched_tokens = 64 max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(max_num_batched_tokens, scheduler_config = SchedulerConfig(
max_seqs, max_num_batched_tokens,
max_model_len, max_seqs,
enable_chunked_prefill=True) max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8 cache_config.num_cpu_blocks = 128
cache_config.num_gpu_blocks = 8 cache_config.num_gpu_blocks = 128
scheduler = Scheduler(scheduler_config, cache_config, None) scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = [] running: List[SequenceGroup] = []
_, seq_group = create_dummy_prompt("1", prompt_length=65) _, seq_group = create_dummy_prompt("1",
prompt_length=65,
block_size=block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
running.append(seq_group) running.append(seq_group)
# The first prefill is chunked. # The first prefill is chunked.
@ -542,7 +598,9 @@ def test_chunked_prefill_max_seqs():
# Add new requests. # Add new requests.
for i in range(4): for i in range(4):
_, seq_group = create_dummy_prompt(str(i), prompt_length=65) _, seq_group = create_dummy_prompt(str(i),
prompt_length=65,
block_size=block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
running.append(seq_group) running.append(seq_group)
@ -564,16 +622,19 @@ def test_chunked_prefill_max_seqs():
assert not running[1].is_prefill() assert not running[1].is_prefill()
def test_perfix_caching(): @pytest.mark.parametrize('use_v2_block_manager', [True, False])
def test_perfix_caching(use_v2_block_manager: bool):
"""Verify allocating full blocks when prefix caching is enabled.""" """Verify allocating full blocks when prefix caching is enabled."""
block_size = 4 block_size = 4
max_seqs = 10 max_seqs = 10
max_model_len = 80 max_model_len = 80
max_num_batched_tokens = 64 max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(max_num_batched_tokens, scheduler_config = SchedulerConfig(
max_seqs, max_num_batched_tokens,
max_model_len, max_seqs,
enable_chunked_prefill=True) max_model_len,
enable_chunked_prefill=True,
use_v2_block_manager=use_v2_block_manager)
cache_config = CacheConfig(block_size, cache_config = CacheConfig(block_size,
1.0, 1.0,
1, 1,