diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc
index 4ad46e5e..28739be1 100644
--- a/csrc/punica/punica_ops.cc
+++ b/csrc/punica/punica_ops.cc
@@ -1,7 +1,7 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <torch/extension.h>
-
+#include <c10/cuda/CUDAGuard.h>
 #include <cstdint>
 
 #include "bgmv/bgmv_config.h"
@@ -91,6 +91,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(w.size(2), h_out);
   CHECK_EQ(indicies.size(0), x.size(0));
   CHECK_EQ(y.size(0), x.size(0));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
   if (h_in < 65536 && h_out < 65536) {
     // TODO: See if we can get rid of this massive nested switch
@@ -322,6 +323,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(w.size(2), h_out);
   CHECK_EQ(indicies.size(0), x.size(0));
   CHECK_EQ(y.size(0), x.size(0));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
   if (h_in < 65536 && h_out < 65536) {
     // TODO: See if we can get rid of this massive nested switch
diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py
index 6e05697f..2736a1c7 100644
--- a/tests/lora/test_punica.py
+++ b/tests/lora/test_punica.py
@@ -49,14 +49,18 @@ H1 = H2 = [
     32768, 33024
 ]
 SEED = [0xabcdabcd987]
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
 
 
 @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
 @pytest.mark.parametrize("h1", H1)
 @pytest.mark.parametrize("h2", H2)
 @pytest.mark.parametrize("seed", SEED)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_lora_correctness(dtype_str, h1, h2, seed):
+def test_lora_correctness(dtype_str, h1, h2, seed, device):
     torch.manual_seed(seed)
     num_loras = 4
     num_layers = 1
@@ -64,25 +68,15 @@ def test_lora_correctness(dtype_str, h1, h2, seed):
     bs = 32
     scale = 0.123
     dtype = getattr(torch, dtype_str)
-    device = torch.device("cuda")
+    torch.set_default_device(device)
 
-    wa_T_all = torch.randn(num_loras,
-                           num_layers,
-                           r,
-                           h1,
-                           dtype=dtype,
-                           device=device)
-    wb_T_all = torch.randn(num_loras,
-                           num_layers,
-                           h2,
-                           r,
-                           dtype=dtype,
-                           device=device)
-    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
+    wa_T_all = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wb_T_all = torch.randn(num_loras, num_layers, h2, r, dtype=dtype)
+    indices = torch.randint(num_loras, (bs, ), dtype=torch.long)
 
     for layer_idx in range(num_layers):
-        x = torch.randn(bs, h1, dtype=dtype, device=device)
-        y = torch.randn(bs, h2, dtype=dtype, device=device)
+        x = torch.randn(bs, h1, dtype=dtype)
+        y = torch.randn(bs, h2, dtype=dtype)
 
         y_ref = y.clone()
         _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale)
@@ -98,8 +92,9 @@
 @pytest.mark.parametrize("h1", H1)
 @pytest.mark.parametrize("h2", H2)
 @pytest.mark.parametrize("seed", SEED)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_lora_correctness_slice(dtype_str, h1, h2, seed):
+def test_lora_correctness_slice(dtype_str, h1, h2, seed, device):
     if h2 % 3 != 0 or h2 // 3 not in H1:
         pytest.skip("h2 must be divisible by 3 and in supported shapes")
     torch.manual_seed(seed)
@@ -109,50 +104,20 @@ def test_lora_correctness_slice(dtype_str, h1, h2, seed):
     bs = 32
     scale = 0.123
     dtype = getattr(torch, dtype_str)
-    device = torch.device("cuda")
+    torch.set_default_device(device)
 
-    wa_T_all_0 = torch.randn(num_loras,
-                             num_layers,
-                             r,
-                             h1,
-                             dtype=dtype,
-                             device=device)
-    wa_T_all_1 = torch.randn(num_loras,
-
-                             num_layers,
-                             r,
-                             h1,
-                             dtype=dtype,
-                             device=device)
-    wa_T_all_2 = torch.randn(num_loras,
-                             num_layers,
-                             r,
-                             h1,
-                             dtype=dtype,
-                             device=device)
-    wb_T_all_0 = torch.randn(num_loras,
-                             num_layers,
-                             h2 // 3,
-                             r,
-                             dtype=dtype,
-                             device=device)
-    wb_T_all_1 = torch.randn(num_loras,
-                             num_layers,
-                             h2 // 3,
-                             r,
-                             dtype=dtype,
-                             device=device)
-    wb_T_all_2 = torch.randn(num_loras,
-                             num_layers,
-                             h2 // 3,
-                             r,
-                             dtype=dtype,
-                             device=device)
+    wa_T_all_0 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wa_T_all_1 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wa_T_all_2 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wb_T_all_0 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
+    wb_T_all_1 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
+    wb_T_all_2 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
 
-    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
+    indices = torch.randint(num_loras, (bs, ), dtype=torch.long)
 
     for layer_idx in range(num_layers):
-        x = torch.randn(bs, h1, dtype=dtype, device=device)
-        y = torch.randn(bs, h2, dtype=dtype, device=device)
+        x = torch.randn(bs, h1, dtype=dtype)
+        y = torch.randn(bs, h2, dtype=dtype)
         s = h2 // 3
 
         y_ref = y.clone()
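
For context on the change above: a CUDA kernel launch always targets the thread's current device, not the device that owns its argument tensors. Before this patch, the punica `dispatch_bgmv`/`dispatch_bgmv_low_level` entry points launched kernels on whatever device happened to be current (usually `cuda:0`), so LoRA weights living on another GPU could trigger a wrong-device launch and a CUDA error. The added `at::cuda::OptionalCUDAGuard device_guard(device_of(x))` pins the current device to `x`'s device for the rest of the function, and the test changes exercise that path: `torch.set_default_device(device)` makes every factory call in the test allocate on the parametrized device, and `CUDA_DEVICES` includes `cuda:1` whenever more than one GPU is visible. Below is a minimal sketch of the guard pattern; `gather_op` and `copy_kernel` are illustrative names, not part of the patch:

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>

// Trivial kernel that touches x's memory, standing in for the BGMV kernels.
__global__ void copy_kernel(const float *in, float *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}

void gather_op(torch::Tensor x, torch::Tensor out) {
  // Without the guard, the launch below would go to the caller's current
  // device (often cuda:0) even when x and out live on cuda:1.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
  int n = static_cast<int>(x.numel());
  copy_kernel<<<(n + 255) / 256, 256, 0,
                at::cuda::getCurrentCUDAStream()>>>(
      x.data_ptr<float>(), out.data_ptr<float>(), n);
}  // the previous current device is restored when device_guard leaves scope

Note that `getCurrentCUDAStream()` is queried after the guard takes effect, so the kernel is enqueued on a stream belonging to `x`'s device rather than the caller's original device.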