[Bugfix][fast] Fix the get_num_blocks_touched logic (#6849)
This commit is contained in:
parent
21b9c49aa3
commit
782e53ab59
@ -311,6 +311,68 @@ def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots,
|
||||
assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [8])
|
||||
@pytest.mark.parametrize("num_gpu_blocks", [4])
|
||||
@pytest.mark.parametrize("num_lookahead_slots", [3, 8, 10])
|
||||
@pytest.mark.parametrize("enable_caching", [True, False])
|
||||
def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots,
|
||||
enable_caching):
|
||||
""" Verify the block manager can correctly determine if a sequence group
|
||||
can be swapped in/out.
|
||||
"""
|
||||
num_cpu_blocks = num_gpu_blocks
|
||||
block_manager = BlockSpaceManagerV2(block_size,
|
||||
num_cpu_blocks,
|
||||
num_gpu_blocks,
|
||||
watermark=0,
|
||||
enable_caching=enable_caching)
|
||||
prompt, seq_group = create_dummy_prompt(
|
||||
"1", prompt_length=(num_gpu_blocks - 1) * block_size - 1)
|
||||
prompt.status = SequenceStatus.WAITING
|
||||
block_manager.allocate(seq_group)
|
||||
prompt.status = SequenceStatus.RUNNING
|
||||
|
||||
# Swap seq group from GPU -> CPU.
|
||||
gpu_blocks = block_manager.get_block_table(prompt)
|
||||
assert block_manager.can_swap_out(seq_group)
|
||||
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
mapping = block_manager.swap_out(seq_group)
|
||||
mapping_keys = [key for key, _ in mapping]
|
||||
assert mapping_keys == gpu_blocks
|
||||
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
|
||||
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
|
||||
assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
|
||||
assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
|
||||
prompt.status = SequenceStatus.SWAPPED
|
||||
|
||||
# At this moment, we still have enough free blocks to swap in the seq group.
|
||||
if num_lookahead_slots <= block_size:
|
||||
assert block_manager.can_swap_in(seq_group,
|
||||
num_lookahead_slots) == AllocStatus.OK
|
||||
else:
|
||||
assert block_manager.can_swap_in(
|
||||
seq_group, num_lookahead_slots) == AllocStatus.NEVER
|
||||
|
||||
# During Swapped out, 2 cached blocks were evicted from the GPU,
|
||||
# so the prompt1 can't be swapped in
|
||||
prompt2_len = 2 * block_size - 1
|
||||
prompt2, seq_group2 = create_dummy_prompt(
|
||||
"2",
|
||||
prompt_length=prompt2_len,
|
||||
prompt_tokens=[10000 + i for i in range(prompt2_len)])
|
||||
prompt2.status = SequenceStatus.WAITING
|
||||
block_manager.allocate(seq_group2)
|
||||
|
||||
# Swap seq group from CPU -> GPU.
|
||||
if num_lookahead_slots <= block_size:
|
||||
assert block_manager.can_swap_in(
|
||||
seq_group, num_lookahead_slots) == AllocStatus.LATER
|
||||
else:
|
||||
assert block_manager.can_swap_in(
|
||||
seq_group, num_lookahead_slots) == AllocStatus.NEVER
|
||||
|
||||
|
||||
# TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level.
|
||||
|
||||
|
||||
|
||||
@ -100,3 +100,45 @@ class TestNaiveBlockAllocator:
|
||||
for i, block in enumerate(blocks):
|
||||
assert allocator.get_num_free_blocks() == i
|
||||
allocator.free(block)
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [4])
|
||||
@pytest.mark.parametrize("block_size", [8])
|
||||
def test_naive_block_get_num_blocks_touched(num_blocks, block_size):
|
||||
""" Verify the allocator can correctly return the number of
|
||||
blocks touched, with different lookahead slots.
|
||||
"""
|
||||
allocator_src = NaiveBlockAllocator(create_block=NaiveBlock,
|
||||
num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
allocator_dst = NaiveBlockAllocator(create_block=NaiveBlock,
|
||||
num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
|
||||
# Create a chain of cacheable blocks in the dst
|
||||
allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
|
||||
"immutable",
|
||||
allocator_src,
|
||||
prev_block=None,
|
||||
token_ids=list(range(block_size)))
|
||||
src_blocks = [allocate_block() for _ in range(num_blocks - 1)]
|
||||
|
||||
# All blocks are cached
|
||||
assert allocator_dst.get_num_blocks_touched(
|
||||
src_blocks) == num_blocks - 1
|
||||
|
||||
# Insert one non-full block in the src
|
||||
allocate_non_full_block = \
|
||||
TestNaiveBlockAllocator.create_allocate_lambda(
|
||||
"mutable", allocator_src,
|
||||
prev_block=src_blocks[-1],token_ids=[]
|
||||
)
|
||||
src_blocks.append(allocate_non_full_block())
|
||||
src_blocks[-1].append_token_ids([0])
|
||||
|
||||
assert allocator_dst.get_num_blocks_touched(
|
||||
src_blocks, num_lookahead_slots=1) == num_blocks
|
||||
assert allocator_dst.get_num_blocks_touched(
|
||||
src_blocks, num_lookahead_slots=block_size - 1) == num_blocks
|
||||
assert allocator_dst.get_num_blocks_touched(
|
||||
src_blocks, num_lookahead_slots=block_size) == (num_blocks + 1)
|
||||
|
||||
@ -315,6 +315,60 @@ class TestPrefixCachingBlockAllocator:
|
||||
i)
|
||||
allocator.free(block)
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [4])
|
||||
@pytest.mark.parametrize("block_size", [8])
|
||||
def test_prefix_caching_block_get_num_blocks_touched(
|
||||
num_blocks, block_size):
|
||||
""" Verify the allocator can correctly return the number of
|
||||
blocks touched, when there are cached prefixes and different
|
||||
lookahead slots.
|
||||
"""
|
||||
allocator_src = PrefixCachingBlockAllocator(num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
allocator_dst = PrefixCachingBlockAllocator(num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
|
||||
# Create token ids that will exhaust all blocks except the last
|
||||
token_ids = list(range((num_blocks - 1) * block_size))
|
||||
|
||||
# Create a chain of cacheable blocks in the dst
|
||||
cached_blocks = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator_dst,
|
||||
)
|
||||
|
||||
# Create a chain of the same blocks in the src
|
||||
blocks_to_swap_in = \
|
||||
TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator_src,
|
||||
)
|
||||
|
||||
# All blocks are cached
|
||||
assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 0
|
||||
|
||||
# Free the first block in the dst
|
||||
allocator_dst.free(cached_blocks[0])
|
||||
|
||||
# Now the first block becomes dangling, the swapped blocks need
|
||||
# to reclaim the first block in the dst
|
||||
assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 1
|
||||
|
||||
# Insert one non-full block in the src
|
||||
non_full_block = allocator_src.allocate_mutable_block(
|
||||
blocks_to_swap_in[-1])
|
||||
non_full_block.append_token_ids([0])
|
||||
blocks_to_swap_in.append(non_full_block)
|
||||
assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in,
|
||||
num_lookahead_slots=1) == 2
|
||||
assert allocator_dst.get_num_blocks_touched(
|
||||
blocks_to_swap_in, num_lookahead_slots=block_size - 1) == 2
|
||||
assert allocator_dst.get_num_blocks_touched(
|
||||
blocks_to_swap_in, num_lookahead_slots=block_size) == 3
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [1024])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
|
||||
@ -15,13 +15,15 @@ def create_dummy_prompt(
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
use_beam_search: bool = False,
|
||||
best_of: int = 1,
|
||||
prompt_tokens: Optional[List[int]] = None,
|
||||
) -> Tuple[Sequence, SequenceGroup]:
|
||||
if not block_size:
|
||||
block_size = prompt_length
|
||||
|
||||
# Create dummy prompt sequence with tokens 0...block_size-1
|
||||
# and prompt "0 ... block_size".
|
||||
prompt_tokens = list(range(prompt_length))
|
||||
if prompt_tokens is None:
|
||||
# Create dummy prompt sequence with tokens 0...block_size-1
|
||||
# and prompt "0 ... block_size".
|
||||
prompt_tokens = list(range(prompt_length))
|
||||
prompt_str = " ".join([str(t) for t in prompt_tokens])
|
||||
prompt = Sequence(int(request_id),
|
||||
inputs={
|
||||
|
||||
@ -307,9 +307,8 @@ class NaiveBlockAllocator(BlockAllocator):
|
||||
# TODO(cade): make sure the logic is correct and clean it up.
|
||||
for block in blocks:
|
||||
if not block.is_full and num_lookahead_slots != 0:
|
||||
if block.num_empty_slots >= num_lookahead_slots:
|
||||
new_block_count += 1
|
||||
else:
|
||||
new_block_count += 1
|
||||
if num_lookahead_slots > block.num_empty_slots:
|
||||
new_block_count += cdiv(
|
||||
num_lookahead_slots - block.num_empty_slots,
|
||||
self._block_size)
|
||||
|
||||
@ -579,14 +579,17 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
num_touched_blocks = 0
|
||||
for block in blocks:
|
||||
if not block.is_full:
|
||||
if block.num_empty_slots >= num_lookahead_slots:
|
||||
num_touched_blocks += 1
|
||||
else:
|
||||
num_touched_blocks += 1
|
||||
if num_lookahead_slots > block.num_empty_slots:
|
||||
num_touched_blocks += cdiv(
|
||||
num_lookahead_slots - block.num_empty_slots,
|
||||
self._block_size)
|
||||
else:
|
||||
if not self.is_block_cached(block):
|
||||
# If the block has a match in the cache and the cached block
|
||||
# is not referenced, then we still count it as a touched block
|
||||
if not self.is_block_cached(block) or \
|
||||
(block.content_hash is not None and \
|
||||
self._cached_blocks[block.content_hash] in self.evictor):
|
||||
num_touched_blocks += 1
|
||||
return num_touched_blocks
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user