Revert the changes in test_cache (#2335)
This commit is contained in:
parent
74d8d77626
commit
941767127c
@ -49,13 +49,12 @@ def test_copy_blocks(
|
|||||||
src_blocks = random.sample(range(num_blocks), num_mappings)
|
src_blocks = random.sample(range(num_blocks), num_mappings)
|
||||||
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
||||||
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
|
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
|
||||||
copy_src = []
|
block_mapping = {}
|
||||||
copy_dst = []
|
|
||||||
for i in range(num_mappings):
|
for i in range(num_mappings):
|
||||||
copy_src.append(src_blocks[i])
|
src = src_blocks[i]
|
||||||
copy_dst.append(dst_blocks[2 * i])
|
dst1 = dst_blocks[2 * i]
|
||||||
copy_src.append(src_blocks[i])
|
dst2 = dst_blocks[2 * i + 1]
|
||||||
copy_dst.append(dst_blocks[2 * i + 1])
|
block_mapping[src] = [dst1, dst2]
|
||||||
|
|
||||||
# Create the KV caches.
|
# Create the KV caches.
|
||||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
|
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
|
||||||
@ -67,14 +66,15 @@ def test_copy_blocks(
|
|||||||
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
|
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
|
||||||
|
|
||||||
# Call the copy blocks kernel.
|
# Call the copy blocks kernel.
|
||||||
cache_ops.copy_blocks(key_caches, value_caches, copy_src, copy_dst)
|
cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
|
||||||
|
|
||||||
# Run the reference implementation.
|
# Run the reference implementation.
|
||||||
for src, dst in zip(copy_src, copy_dst):
|
for src, dsts in block_mapping.items():
|
||||||
for cloned_key_cache in cloned_key_caches:
|
for dst in dsts:
|
||||||
cloned_key_cache[dst].copy_(cloned_key_cache[src])
|
for cloned_key_cache in cloned_key_caches:
|
||||||
for cloned_value_cache in cloned_value_caches:
|
cloned_key_cache[dst].copy_(cloned_key_cache[src])
|
||||||
cloned_value_cache[dst].copy_(cloned_value_cache[src])
|
for cloned_value_cache in cloned_value_caches:
|
||||||
|
cloned_value_cache[dst].copy_(cloned_value_cache[src])
|
||||||
|
|
||||||
# Compare the results.
|
# Compare the results.
|
||||||
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
|
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user