Revert the changes in test_cache (#2335)

This commit is contained in:
Woosuk Kwon 2024-01-03 17:32:05 -08:00 committed by GitHub
parent 74d8d77626
commit 941767127c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -49,13 +49,12 @@ def test_copy_blocks(
src_blocks = random.sample(range(num_blocks), num_mappings) src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
copy_src = [] block_mapping = {}
copy_dst = []
for i in range(num_mappings): for i in range(num_mappings):
copy_src.append(src_blocks[i]) src = src_blocks[i]
copy_dst.append(dst_blocks[2 * i]) dst1 = dst_blocks[2 * i]
copy_src.append(src_blocks[i]) dst2 = dst_blocks[2 * i + 1]
copy_dst.append(dst_blocks[2 * i + 1]) block_mapping[src] = [dst1, dst2]
# Create the KV caches. # Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
@ -67,10 +66,11 @@ def test_copy_blocks(
cloned_value_caches = [value_cache.clone() for value_cache in value_caches] cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
# Call the copy blocks kernel. # Call the copy blocks kernel.
cache_ops.copy_blocks(key_caches, value_caches, copy_src, copy_dst) cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
# Run the reference implementation. # Run the reference implementation.
for src, dst in zip(copy_src, copy_dst): for src, dsts in block_mapping.items():
for dst in dsts:
for cloned_key_cache in cloned_key_caches: for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst].copy_(cloned_key_cache[src]) cloned_key_cache[dst].copy_(cloned_key_cache[src])
for cloned_value_cache in cloned_value_caches: for cloned_value_cache in cloned_value_caches: