[Bugfix][fast] Fix the get_num_blocks_touched logic (#6849)

This commit is contained in:
Zach Zheng 2024-08-08 10:43:30 -07:00 committed by GitHub
parent 21b9c49aa3
commit 782e53ab59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 172 additions and 10 deletions

View File

@@ -311,6 +311,68 @@ def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots,
assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
@pytest.mark.parametrize("block_size", [8])
@pytest.mark.parametrize("num_gpu_blocks", [4])
@pytest.mark.parametrize("num_lookahead_slots", [3, 8, 10])
@pytest.mark.parametrize("enable_caching", [True, False])
def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots,
                  enable_caching):
    """Verify the block manager correctly reports whether a sequence group
    can be swapped in/out of GPU memory under varying lookahead sizes.
    """
    num_cpu_blocks = num_gpu_blocks
    manager = BlockSpaceManagerV2(block_size,
                                  num_cpu_blocks,
                                  num_gpu_blocks,
                                  watermark=0,
                                  enable_caching=enable_caching)
    prompt, seq_group = create_dummy_prompt(
        "1", prompt_length=(num_gpu_blocks - 1) * block_size - 1)
    prompt.status = SequenceStatus.WAITING
    manager.allocate(seq_group)
    prompt.status = SequenceStatus.RUNNING

    # Swap the sequence group from GPU -> CPU and check the block counters.
    gpu_block_ids = manager.get_block_table(prompt)
    assert manager.can_swap_out(seq_group)
    free_cpu_before = manager.get_num_free_cpu_blocks()
    free_gpu_before = manager.get_num_free_gpu_blocks()
    mapping = manager.swap_out(seq_group)
    assert [src for src, _ in mapping] == gpu_block_ids
    free_cpu_after = manager.get_num_free_cpu_blocks()
    free_gpu_after = manager.get_num_free_gpu_blocks()
    assert free_cpu_before == free_cpu_after + len(gpu_block_ids)
    assert free_gpu_before + len(gpu_block_ids) == free_gpu_after
    prompt.status = SequenceStatus.SWAPPED

    # At this point there are still enough free GPU blocks to swap the
    # group back in -- unless the lookahead alone exceeds one block, in
    # which case swap-in can never succeed.
    can_now = manager.can_swap_in(seq_group, num_lookahead_slots)
    if num_lookahead_slots <= block_size:
        assert can_now == AllocStatus.OK
    else:
        assert can_now == AllocStatus.NEVER

    # Allocating a second prompt consumes GPU blocks (evicting cached
    # ones when caching is enabled), so the first group can no longer be
    # swapped in right away.
    prompt2_len = 2 * block_size - 1
    prompt2, seq_group2 = create_dummy_prompt(
        "2",
        prompt_length=prompt2_len,
        prompt_tokens=list(range(10000, 10000 + prompt2_len)))
    prompt2.status = SequenceStatus.WAITING
    manager.allocate(seq_group2)

    # Swap seq group from CPU -> GPU: now only possible later (or never).
    can_later = manager.can_swap_in(seq_group, num_lookahead_slots)
    if num_lookahead_slots <= block_size:
        assert can_later == AllocStatus.LATER
    else:
        assert can_later == AllocStatus.NEVER
# TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level.

View File

@@ -100,3 +100,45 @@ class TestNaiveBlockAllocator:
for i, block in enumerate(blocks):
assert allocator.get_num_free_blocks() == i
allocator.free(block)
@staticmethod
@pytest.mark.parametrize("num_blocks", [4])
@pytest.mark.parametrize("block_size", [8])
def test_naive_block_get_num_blocks_touched(num_blocks, block_size):
    """Verify the naive allocator reports the number of blocks touched
    correctly for different numbers of lookahead slots.
    """
    src_alloc = NaiveBlockAllocator(create_block=NaiveBlock,
                                    num_blocks=num_blocks,
                                    block_size=block_size)
    dst_alloc = NaiveBlockAllocator(create_block=NaiveBlock,
                                    num_blocks=num_blocks,
                                    block_size=block_size)

    # Build a chain of full blocks in the src allocator.
    make_full_block = TestNaiveBlockAllocator.create_allocate_lambda(
        "immutable",
        src_alloc,
        prev_block=None,
        token_ids=list(range(block_size)))
    src_blocks = [make_full_block() for _ in range(num_blocks - 1)]

    # Full blocks: one touched block per src block.
    assert dst_alloc.get_num_blocks_touched(src_blocks) == num_blocks - 1

    # Append one partially-filled block to the src chain.
    make_partial_block = TestNaiveBlockAllocator.create_allocate_lambda(
        "mutable", src_alloc, prev_block=src_blocks[-1], token_ids=[])
    partial = make_partial_block()
    partial.append_token_ids([0])
    src_blocks.append(partial)

    # Lookahead that fits in the partial block touches it once; once the
    # lookahead spills past the block boundary an extra block is needed.
    for lookahead, expected in (
        (1, num_blocks),
        (block_size - 1, num_blocks),
        (block_size, num_blocks + 1),
    ):
        assert dst_alloc.get_num_blocks_touched(
            src_blocks, num_lookahead_slots=lookahead) == expected

View File

@@ -315,6 +315,60 @@ class TestPrefixCachingBlockAllocator:
i)
allocator.free(block)
@staticmethod
@pytest.mark.parametrize("num_blocks", [4])
@pytest.mark.parametrize("block_size", [8])
def test_prefix_caching_block_get_num_blocks_touched(
        num_blocks, block_size):
    """Verify the prefix-caching allocator reports the number of blocks
    touched correctly when prefixes are cached and with different numbers
    of lookahead slots.
    """
    src_alloc = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                            block_size=block_size)
    dst_alloc = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                            block_size=block_size)

    # Token ids filling every block except the last one.
    token_ids = list(range((num_blocks - 1) * block_size))

    # Identical immutable chains in dst (the cache) and src (to swap in).
    dst_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
        block_size=block_size,
        token_ids=token_ids,
        allocator=dst_alloc,
    )
    src_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
        block_size=block_size,
        token_ids=token_ids,
        allocator=src_alloc,
    )

    # Every src block has a live cached copy in dst -> nothing to touch.
    assert dst_alloc.get_num_blocks_touched(src_chain) == 0

    # Freeing the first dst block makes it dangling (evictable), so a
    # swap-in would have to reclaim that one block.
    dst_alloc.free(dst_chain[0])
    assert dst_alloc.get_num_blocks_touched(src_chain) == 1

    # Append a partially-filled block on the src side.
    partial = src_alloc.allocate_mutable_block(src_chain[-1])
    partial.append_token_ids([0])
    src_chain.append(partial)

    # Small lookahead fits inside the partial block; a full-block
    # lookahead spills over and requires one more block.
    for lookahead, expected in ((1, 2), (block_size - 1, 2),
                                (block_size, 3)):
        assert dst_alloc.get_num_blocks_touched(
            src_chain, num_lookahead_slots=lookahead) == expected
@staticmethod
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("block_size", [16])

View File

@@ -15,13 +15,15 @@ def create_dummy_prompt(
lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False,
best_of: int = 1,
prompt_tokens: Optional[List[int]] = None,
) -> Tuple[Sequence, SequenceGroup]:
if not block_size:
block_size = prompt_length
# Create dummy prompt sequence with tokens 0...block_size-1
# and prompt "0 ... block_size".
prompt_tokens = list(range(prompt_length))
if prompt_tokens is None:
# Create dummy prompt sequence with tokens 0...block_size-1
# and prompt "0 ... block_size".
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id),
inputs={

View File

@@ -307,9 +307,8 @@ class NaiveBlockAllocator(BlockAllocator):
# TODO(cade): make sure the logic is correct and clean it up.
for block in blocks:
if not block.is_full and num_lookahead_slots != 0:
if block.num_empty_slots >= num_lookahead_slots:
new_block_count += 1
else:
new_block_count += 1
if num_lookahead_slots > block.num_empty_slots:
new_block_count += cdiv(
num_lookahead_slots - block.num_empty_slots,
self._block_size)

View File

@@ -579,14 +579,17 @@ class PrefixCachingBlockAllocator(BlockAllocator):
num_touched_blocks = 0
for block in blocks:
if not block.is_full:
if block.num_empty_slots >= num_lookahead_slots:
num_touched_blocks += 1
else:
num_touched_blocks += 1
if num_lookahead_slots > block.num_empty_slots:
num_touched_blocks += cdiv(
num_lookahead_slots - block.num_empty_slots,
self._block_size)
else:
if not self.is_block_cached(block):
# If the block has a match in the cache and the cached block
# is not referenced, then we still count it as a touched block
if not self.is_block_cached(block) or \
(block.content_hash is not None and \
self._cached_blocks[block.content_hash] in self.evictor):
num_touched_blocks += 1
return num_touched_blocks