[CI/Test] fix swap test for multi gpu (#4689)
parent 20cfcdec99
commit 230c4b38c1
@@ -222,11 +222,12 @@ def test_reshape_and_cache_flash(
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
 
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda')
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device)
 
     qkv = torch.randn(num_tokens,
                       3,
@@ -245,6 +246,7 @@ def test_reshape_and_cache_flash(
         head_size,
         kv_cache_dtype,
         dtype,
+        device=device,
     )
     key_cache, value_cache = key_caches[0], value_caches[0]
 
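The underlying issue: the bare string 'cuda' resolves to the current CUDA device (cuda:0 unless changed), so when the test is parameterized over other GPUs the slot mapping could land on a different device than the KV caches. Below is a minimal, self-contained sketch of the device-routing pattern the fix applies; make_inputs, its tensor shapes, and the loop are hypothetical illustrations, not the vLLM test itself.

```python
# Hypothetical sketch of the device-routing pattern used by the fix;
# not the vLLM test. Assumes PyTorch >= 2.0 for torch.set_default_device.
import torch


def make_inputs(num_tokens: int, num_slots: int, device: str):
    # Route every allocation through the parameterized device so nothing
    # silently lands on cuda:0 when the test targets another GPU.
    torch.set_default_device(device)
    # randperm returns int64 (torch.long), matching the slot-mapping dtype.
    slot_mapping = torch.randperm(num_slots, device=device)[:num_tokens]
    qkv = torch.randn(num_tokens, 3, 8)  # inherits the default device
    return slot_mapping, qkv


if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        dev = f"cuda:{i}"
        slots, qkv = make_inputs(16, 256, dev)
        # With device='cuda' hardcoded, this would fail for i > 0.
        assert slots.device == qkv.device == torch.device(dev)
```

Setting the default device in addition to passing device= explicitly is belt-and-suspenders: it also catches any allocation inside the test that omits a device argument.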