[CI/Test] fix swap test for multi gpu (#4689)

This commit is contained in:
youkaichao 2024-05-08 13:14:02 -07:00 committed by GitHub
parent 20cfcdec99
commit 230c4b38c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -222,11 +222,12 @@ def test_reshape_and_cache_flash(
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda')
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device)
qkv = torch.randn(num_tokens,
3,
@@ -245,6 +246,7 @@ def test_reshape_and_cache_flash(
head_size,
kv_cache_dtype,
dtype,
+            device=device,
)
key_cache, value_cache = key_caches[0], value_caches[0]