diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 8a27d51b..4cae15c7 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -222,11 +222,12 @@ def test_reshape_and_cache_flash( random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) + torch.set_default_device(device) # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device) qkv = torch.randn(num_tokens, 3, @@ -245,6 +246,7 @@ def test_reshape_and_cache_flash( head_size, kv_cache_dtype, dtype, + device=device, ) key_cache, value_cache = key_caches[0], value_caches[0]